1
0
Fork 0
mirror of https://github.com/puma/puma.git synced 2022-11-09 13:48:40 -05:00
puma--puma/ext/puma_http11/org/jruby/puma/MiniSSL.java
2015-07-17 11:48:47 -07:00

392 lines
12 KiB
Java

package org.jruby.puma;
import org.jruby.Ruby;
import org.jruby.RubyBoolean;
import org.jruby.RubyClass;
import org.jruby.RubyModule;
import org.jruby.RubyObject;
import org.jruby.RubyString;
import org.jruby.anno.JRubyMethod;
import org.jruby.runtime.Block;
import org.jruby.runtime.ObjectAllocator;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.util.ByteList;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import static javax.net.ssl.SSLEngineResult.Status;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus;
public class MiniSSL extends RubyObject {
private static ObjectAllocator ALLOCATOR = new ObjectAllocator() {
public IRubyObject allocate(Ruby runtime, RubyClass klass) {
return new MiniSSL(runtime, klass);
}
};
// set to true to switch on our low-fi trace logging
private static boolean DEBUG = false;
public static void createMiniSSL(Ruby runtime) {
RubyModule mPuma = runtime.defineModule("Puma");
RubyModule ssl = mPuma.defineModuleUnder("MiniSSL");
mPuma.defineClassUnder("SSLError",
runtime.getClass("IOError"),
runtime.getClass("IOError").getAllocator());
RubyClass eng = ssl.defineClassUnder("Engine",runtime.getObject(),ALLOCATOR);
eng.defineAnnotatedMethods(MiniSSL.class);
}
/**
* Fairly transparent wrapper around {@link java.nio.ByteBuffer} which adds the enhancements we need
*/
private static class MiniSSLBuffer {
ByteBuffer buffer;
private MiniSSLBuffer(int capacity) { buffer = ByteBuffer.allocate(capacity); }
private MiniSSLBuffer(byte[] initialContents) { buffer = ByteBuffer.wrap(initialContents); }
public void clear() { buffer.clear(); }
public void compact() { buffer.compact(); }
public void flip() { buffer.flip(); }
public boolean hasRemaining() { return buffer.hasRemaining(); }
public int position() { return buffer.position(); }
public ByteBuffer getRawBuffer() {
return buffer;
}
/**
* Writes bytes to the buffer after ensuring there's room
*/
public void put(byte[] bytes) {
if (buffer.remaining() < bytes.length) {
resize(buffer.limit() + bytes.length);
}
buffer.put(bytes);
}
/**
* Ensures that newCapacity bytes can be written to this buffer, only re-allocating if necessary
*/
public void resize(int newCapacity) {
if (newCapacity > buffer.capacity()) {
ByteBuffer dstTmp = ByteBuffer.allocate(newCapacity);
buffer.flip();
dstTmp.put(buffer);
buffer = dstTmp;
} else {
buffer.limit(newCapacity);
}
}
/**
* Drains the buffer to a ByteList, or returns null for an empty buffer
*/
public ByteList asByteList() {
buffer.flip();
if (!buffer.hasRemaining()) {
buffer.clear();
return null;
}
byte[] bss = new byte[buffer.limit()];
buffer.get(bss);
buffer.clear();
return new ByteList(bss);
}
@Override
public String toString() { return buffer.toString(); }
}
private SSLEngine engine;
private MiniSSLBuffer inboundNetData;
private MiniSSLBuffer outboundAppData;
private MiniSSLBuffer outboundNetData;
public MiniSSL(Ruby runtime, RubyClass klass) {
super(runtime, klass);
}
@JRubyMethod(meta = true)
public static IRubyObject server(ThreadContext context, IRubyObject recv, IRubyObject miniSSLContext) {
RubyClass klass = (RubyClass) recv;
return klass.newInstance(context,
new IRubyObject[] { miniSSLContext },
Block.NULL_BLOCK);
}
@JRubyMethod
public IRubyObject initialize(ThreadContext threadContext, IRubyObject miniSSLContext)
throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException, UnrecoverableKeyException, KeyManagementException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
char[] password = miniSSLContext.callMethod(threadContext, "keystore_pass").convertToString().asJavaString().toCharArray();
ks.load(new FileInputStream(miniSSLContext.callMethod(threadContext, "keystore").convertToString().asJavaString()),
password);
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
kmf.init(ks, password);
SSLContext sslCtx = SSLContext.getInstance("TLS");
sslCtx.init(kmf.getKeyManagers(), null, null);
engine = sslCtx.createSSLEngine();
String[] protocols = new String[] { "TLSv1", "TLSv1.1", "TLSv1.2" };
engine.setEnabledProtocols(protocols);
engine.setUseClientMode(false);
SSLSession session = engine.getSession();
inboundNetData = new MiniSSLBuffer(session.getPacketBufferSize());
outboundAppData = new MiniSSLBuffer(session.getApplicationBufferSize());
outboundAppData.flip();
outboundNetData = new MiniSSLBuffer(session.getPacketBufferSize());
return this;
}
@JRubyMethod
public IRubyObject inject(IRubyObject arg) {
try {
byte[] bytes = arg.convertToString().getBytes();
log("Net Data post pre-inject: " + inboundNetData);
inboundNetData.put(bytes);
log("Net Data post post-inject: " + inboundNetData);
log("inject(): " + bytes.length + " encrypted bytes from request");
return this;
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}
private enum SSLOperation {
WRAP,
UNWRAP
}
private SSLEngineResult doOp(SSLOperation sslOp, MiniSSLBuffer src, MiniSSLBuffer dst) throws SSLException {
SSLEngineResult res = null;
boolean retryOp = true;
while (retryOp) {
switch (sslOp) {
case WRAP:
res = engine.wrap(src.getRawBuffer(), dst.getRawBuffer());
break;
case UNWRAP:
res = engine.unwrap(src.getRawBuffer(), dst.getRawBuffer());
break;
default:
throw new IllegalStateException("Unknown SSLOperation: " + sslOp);
}
switch (res.getStatus()) {
case BUFFER_OVERFLOW:
log("SSLOp#doRun(): overflow");
log("SSLOp#doRun(): dst data at overflow: " + dst);
// increase the buffer size to accommodate the overflowing data
int newSize = Math.max(engine.getSession().getPacketBufferSize(), engine.getSession().getApplicationBufferSize());
dst.resize(newSize + dst.position());
// retry the operation
retryOp = true;
break;
case BUFFER_UNDERFLOW:
log("SSLOp#doRun(): underflow");
log("SSLOp#doRun(): src data at underflow: " + src);
// need to wait for more data to come in before we retry
retryOp = false;
break;
default:
// other cases are OK and CLOSED. We're done here.
retryOp = false;
}
}
// after each op, run any delegated tasks if needed
if(engine.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
Runnable runnable;
while ((runnable = engine.getDelegatedTask()) != null) {
runnable.run();
}
}
return res;
}
@JRubyMethod
public IRubyObject read() throws Exception {
try {
inboundNetData.flip();
if(!inboundNetData.hasRemaining()) {
return getRuntime().getNil();
}
log("read(): inboundNetData prepped for read: " + inboundNetData);
MiniSSLBuffer inboundAppData = new MiniSSLBuffer(engine.getSession().getApplicationBufferSize());
SSLEngineResult res = doOp(SSLOperation.UNWRAP, inboundNetData, inboundAppData);
log("read(): after initial unwrap", engine, res);
log("read(): Net Data post unwrap: " + inboundNetData);
HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
boolean done = false;
while (!done) {
switch (handshakeStatus) {
case NEED_WRAP:
res = doOp(SSLOperation.WRAP, inboundAppData, outboundNetData);
log("read(): after handshake wrap", engine, res);
break;
case NEED_UNWRAP:
res = doOp(SSLOperation.UNWRAP, inboundNetData, inboundAppData);
log("read(): after handshake unwrap", engine, res);
if (res.getStatus() == Status.BUFFER_UNDERFLOW) {
// need more data before we can shake more hands
done = true;
}
break;
default:
done = true;
}
handshakeStatus = engine.getHandshakeStatus();
}
if (inboundNetData.hasRemaining()) {
log("Net Data post pre-compact: " + inboundNetData);
inboundNetData.compact();
log("Net Data post post-compact: " + inboundNetData);
} else {
log("Net Data post pre-reset: " + inboundNetData);
inboundNetData.clear();
log("Net Data post post-reset: " + inboundNetData);
}
ByteList appDataByteList = inboundAppData.asByteList();
if (appDataByteList == null) {
return getRuntime().getNil();
}
RubyString str = getRuntime().newString("");
str.setValue(appDataByteList);
logPlain("\n");
log("read(): begin dump of request data >>>>\n");
if (str.asJavaString().getBytes().length < 1000) {
logPlain(str.asJavaString() + "\n");
}
logPlain("Num bytes: " + str.asJavaString().getBytes().length + "\n");
log("read(): end dump of request data <<<<\n");
return str;
} catch (Exception e) {
if (DEBUG) {
e.printStackTrace();
}
throw getRuntime().newEOFError(e.getMessage());
}
}
private static void log(String str, SSLEngine engine, SSLEngineResult result) {
if (DEBUG) {
log(str + " " + result.getStatus() + "/" + engine.getHandshakeStatus() +
"---bytes consumed: " + result.bytesConsumed() +
", bytes produced: " + result.bytesProduced());
}
}
private static void log(String str) {
if (DEBUG) {
System.out.println("MiniSSL.java: " + str);
}
}
private static void logPlain(String str) {
if (DEBUG) {
System.out.println(str);
}
}
@JRubyMethod
public IRubyObject write(IRubyObject arg) {
try {
log("write(): begin dump of response data >>>>\n");
logPlain("\n");
if (arg.asJavaString().getBytes().length < 1000) {
logPlain(arg.asJavaString() + "\n");
}
logPlain("Num bytes: " + arg.asJavaString().getBytes().length + "\n");
log("write(): end dump of response data <<<<\n");
byte[] bls = arg.convertToString().getBytes();
outboundAppData = new MiniSSLBuffer(bls);
return getRuntime().newFixnum(bls.length);
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}
@JRubyMethod
public IRubyObject extract() throws SSLException {
try {
ByteList dataByteList = outboundNetData.asByteList();
if (dataByteList != null) {
RubyString str = getRuntime().newString("");
str.setValue(dataByteList);
return str;
}
if (!outboundAppData.hasRemaining()) {
return getRuntime().getNil();
}
outboundNetData.clear();
SSLEngineResult res = doOp(SSLOperation.WRAP, outboundAppData, outboundNetData);
log("extract(): bytes consumed: " + res.bytesConsumed() + "\n");
log("extract(): bytes produced: " + res.bytesProduced() + "\n");
dataByteList = outboundNetData.asByteList();
if (dataByteList == null) {
return getRuntime().getNil();
}
RubyString str = getRuntime().newString("");
str.setValue(dataByteList);
log("extract(): " + dataByteList.getRealSize() + " encrypted bytes for response");
return str;
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}
@JRubyMethod
public IRubyObject peercert() {
return getRuntime().getNil();
}
}