puma/ext/puma_http11/org/jruby/puma/MiniSSL.java

509 lines
17 KiB
Java

package org.jruby.puma;
import org.jruby.Ruby;
import org.jruby.RubyArray;
import org.jruby.RubyClass;
import org.jruby.RubyModule;
import org.jruby.RubyObject;
import org.jruby.RubyString;
import org.jruby.anno.JRubyMethod;
import org.jruby.exceptions.RaiseException;
import org.jruby.javasupport.JavaEmbedUtils;
import org.jruby.runtime.Block;
import org.jruby.runtime.ObjectAllocator;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.util.ByteList;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.X509TrustManager;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Map;
import java.util.function.Supplier;
import static javax.net.ssl.SSLEngineResult.Status;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus;
/**
 * JRuby implementation of {@code Puma::MiniSSL::Engine}: a wrapper around a server-mode
 * {@link SSLEngine} that performs no socket I/O itself.
 *
 * <p>Callers push encrypted bytes received from the peer with {@code inject}, pull decrypted
 * application data with {@code read}, hand plaintext to {@code write}, and pull encrypted bytes
 * to send with {@code extract}.
 */
public class MiniSSL extends RubyObject { // MiniSSL::Engine
  private static final ObjectAllocator ALLOCATOR = new ObjectAllocator() {
    @Override
    public IRubyObject allocate(Ruby runtime, RubyClass klass) {
      return new MiniSSL(runtime, klass);
    }
  };

  /**
   * Defines {@code Puma::MiniSSL}, {@code Puma::MiniSSL::SSLError} and
   * {@code Puma::MiniSSL::Engine} (backed by this class) in the given runtime.
   */
  public static void createMiniSSL(Ruby runtime) {
    RubyModule mPuma = runtime.defineModule("Puma");
    RubyModule ssl = mPuma.defineModuleUnder("MiniSSL");

    // Puma::MiniSSL::SSLError
    ssl.defineClassUnder("SSLError", runtime.getStandardError(), runtime.getStandardError().getAllocator());

    RubyClass eng = ssl.defineClassUnder("Engine", runtime.getObject(), ALLOCATOR);
    eng.defineAnnotatedMethods(MiniSSL.class);
  }

  /**
   * Fairly transparent wrapper around {@link java.nio.ByteBuffer} which adds the enhancements we need
   * (growing on demand and draining into a {@link ByteList}).
   */
  private static class MiniSSLBuffer {
    private ByteBuffer buffer;

    private MiniSSLBuffer(int capacity) { buffer = ByteBuffer.allocate(capacity); }
    private MiniSSLBuffer(byte[] initialContents) { buffer = ByteBuffer.wrap(initialContents); }

    public void clear() { buffer.clear(); }
    public void compact() { buffer.compact(); }
    // Cast to Buffer so the call links against Buffer.flip() regardless of the JDK compiled on.
    public void flip() { ((Buffer) buffer).flip(); }
    public boolean hasRemaining() { return buffer.hasRemaining(); }
    public int position() { return buffer.position(); }

    public ByteBuffer getRawBuffer() {
      return buffer;
    }

    /**
     * Writes bytes to the buffer after ensuring there's room.
     *
     * @param bytes  source array
     * @param offset start index in {@code bytes}
     * @param length number of bytes to copy
     */
    private void put(byte[] bytes, final int offset, final int length) {
      if (buffer.remaining() < length) {
        resize(buffer.limit() + length);
      }
      buffer.put(bytes, offset, length);
    }

    /**
     * Ensures that newCapacity bytes can be written to this buffer, only re-allocating if necessary.
     * When a larger backing buffer is allocated, the bytes before the current position are copied
     * over; otherwise only the limit is moved.
     */
    public void resize(int newCapacity) {
      if (newCapacity > buffer.capacity()) {
        ByteBuffer dstTmp = ByteBuffer.allocate(newCapacity);
        flip();
        dstTmp.put(buffer);
        buffer = dstTmp;
      } else {
        buffer.limit(newCapacity);
      }
    }

    /**
     * Drains the buffer to a ByteList, or returns null for an empty buffer.
     * In both cases the buffer is cleared afterwards, ready for writing again.
     */
    public ByteList asByteList() {
      flip();
      if (!buffer.hasRemaining()) {
        buffer.clear();
        return null;
      }

      byte[] bss = new byte[buffer.limit()];

      buffer.get(bss);
      buffer.clear();
      return new ByteList(bss, false);
    }

    @Override
    public String toString() { return buffer.toString(); }
  }

  private SSLEngine engine;
  // Set once an SSL operation reports Status.CLOSED.
  private boolean closed;
  // Set once an SSL operation reports HandshakeStatus.FINISHED.
  private boolean handshake;
  // Encrypted bytes injected from the peer, awaiting unwrap.
  private MiniSSLBuffer inboundNetData;
  // Plaintext handed to write(), awaiting wrap.
  private MiniSSLBuffer outboundAppData;
  // Encrypted bytes produced by wrap, awaiting extract().
  private MiniSSLBuffer outboundNetData;

  public MiniSSL(Ruby runtime, RubyClass klass) {
    super(runtime, klass);
  }

  // Factories prepared by Engine.server, keyed by keystore / truststore file path, and looked up
  // again in #initialize.
  private static final Map<String, KeyManagerFactory> keyManagerFactoryMap = new ConcurrentHashMap<String, KeyManagerFactory>();
  private static final Map<String, TrustManagerFactory> trustManagerFactoryMap = new ConcurrentHashMap<String, TrustManagerFactory>();

  /**
   * {@code Engine.server(ctx)}: loads the keystore (and truststore, unless it is {@code :default})
   * described by the MiniSSL context, caches the resulting factories, and returns a new engine.
   */
  @JRubyMethod(meta = true) // Engine.server
  public static synchronized IRubyObject server(ThreadContext context, IRubyObject recv, IRubyObject miniSSLContext)
      throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException, UnrecoverableKeyException {
    // Create the KeyManagerFactory and TrustManagerFactory for this server
    String keystoreFile = asStringValue(miniSSLContext.callMethod(context, "keystore"), null);
    char[] keystorePass = asStringValue(miniSSLContext.callMethod(context, "keystore_pass"), null).toCharArray();
    String keystoreType = asStringValue(miniSSLContext.callMethod(context, "keystore_type"), KeyStore::getDefaultType);

    String truststoreFile;
    char[] truststorePass;
    String truststoreType;
    IRubyObject truststore = miniSSLContext.callMethod(context, "truststore");
    if (truststore.isNil()) {
      // No explicit truststore: trust the entries of the keystore itself.
      truststoreFile = keystoreFile;
      truststorePass = keystorePass;
      truststoreType = keystoreType;
    } else if (!isDefaultSymbol(context, truststore)) {
      truststoreFile = truststore.convertToString().asJavaString();
      IRubyObject pass = miniSSLContext.callMethod(context, "truststore_pass");
      if (pass.isNil()) {
        truststorePass = null;
      } else {
        truststorePass = asStringValue(pass, null).toCharArray();
      }
      truststoreType = asStringValue(miniSSLContext.callMethod(context, "truststore_type"), KeyStore::getDefaultType);
    } else { // self.truststore = :default
      truststoreFile = null;
      truststorePass = null;
      truststoreType = null;
    }

    KeyStore ks = loadKeyStore(keystoreType, keystoreFile, keystorePass);
    KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
    kmf.init(ks, keystorePass);
    keyManagerFactoryMap.put(keystoreFile, kmf);

    if (truststoreFile != null) {
      KeyStore ts = loadKeyStore(truststoreType, truststoreFile, truststorePass);
      TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
      tmf.init(ts);
      trustManagerFactoryMap.put(truststoreFile, tmf);
    }

    RubyClass klass = (RubyClass) recv;
    return klass.newInstance(context, miniSSLContext, Block.NULL_BLOCK);
  }

  /**
   * Loads a key store of the given type from a file. The file stream is always closed, and a
   * failure while closing never hides an exception thrown by {@link KeyStore#load}.
   */
  private static KeyStore loadKeyStore(String type, String file, char[] password)
      throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException {
    KeyStore store = KeyStore.getInstance(type);
    try (InputStream is = new FileInputStream(file)) {
      store.load(is, password);
    }
    return store;
  }

  /**
   * Converts a Ruby value to a Java string. When {@code defaultValue} is given and the value is
   * nil, the supplier's result is returned instead; otherwise {@code to_str} conversion applies.
   */
  private static String asStringValue(IRubyObject value, Supplier<String> defaultValue) {
    if (defaultValue != null && value.isNil()) return defaultValue.get();
    return value.convertToString().asJavaString();
  }

  /** Returns true when {@code truststore} is the Ruby symbol {@code :default}. */
  private static boolean isDefaultSymbol(ThreadContext context, IRubyObject truststore) {
    return context.runtime.newSymbol("default").equals(truststore);
  }

  /**
   * Builds the server-mode SSLEngine from the factories cached by {@code Engine.server} and applies
   * the protocol, client-auth and cipher-suite settings of the MiniSSL context.
   *
   * @throws KeyStoreException if {@code Engine.server} did not prepare a factory for this keystore
   */
  @JRubyMethod
  public IRubyObject initialize(ThreadContext context, IRubyObject miniSSLContext)
      throws KeyStoreException, NoSuchAlgorithmException, KeyManagementException {

    String keystoreFile = miniSSLContext.callMethod(context, "keystore").convertToString().asJavaString();
    KeyManagerFactory kmf = keyManagerFactoryMap.get(keystoreFile);
    IRubyObject truststore = miniSSLContext.callMethod(context, "truststore");
    // Mirror the keys used in Engine.server; "" is never stored, so :default yields no factory.
    String truststoreFile = isDefaultSymbol(context, truststore) ? "" : asStringValue(truststore, () -> keystoreFile);
    TrustManagerFactory tmf = trustManagerFactoryMap.get(truststoreFile); // null if self.truststore = :default
    if (kmf == null) {
      throw new KeyStoreException("Could not find KeyManagerFactory for keystore: " + keystoreFile + " truststore: " + truststoreFile);
    }

    SSLContext sslCtx = SSLContext.getInstance("TLS");

    sslCtx.init(kmf.getKeyManagers(), getTrustManagers(tmf), null);
    closed = false;
    handshake = false;
    engine = sslCtx.createSSLEngine();

    String[] enabledProtocols;
    IRubyObject protocols = miniSSLContext.callMethod(context, "protocols");
    if (protocols.isNil()) {
      if (miniSSLContext.callMethod(context, "no_tlsv1").isTrue()) {
        enabledProtocols = new String[] { "TLSv1.1", "TLSv1.2", "TLSv1.3" };
      } else {
        enabledProtocols = new String[] { "TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3" };
      }

      if (miniSSLContext.callMethod(context, "no_tlsv1_1").isTrue()) {
        enabledProtocols = new String[] { "TLSv1.2", "TLSv1.3" };
      }
    } else if (protocols instanceof RubyArray) {
      enabledProtocols = (String[]) ((RubyArray) protocols).toArray(new String[0]);
    } else {
      throw context.runtime.newTypeError(protocols, context.runtime.getArray());
    }
    engine.setEnabledProtocols(enabledProtocols);

    engine.setUseClientMode(false);

    long verify_mode = miniSSLContext.callMethod(context, "verify_mode").convertToInteger("to_i").getLongValue();
    if ((verify_mode & 0x1) != 0) { // 'peer'
      engine.setWantClientAuth(true);
    }
    if ((verify_mode & 0x2) != 0) { // 'force_peer'
      engine.setNeedClientAuth(true);
    }

    IRubyObject cipher_suites = miniSSLContext.callMethod(context, "cipher_suites");
    if (cipher_suites instanceof RubyArray) {
      engine.setEnabledCipherSuites((String[]) ((RubyArray) cipher_suites).toArray(new String[0]));
    } else if (!cipher_suites.isNil()) {
      throw context.runtime.newTypeError(cipher_suites, context.runtime.getArray());
    }

    SSLSession session = engine.getSession();
    inboundNetData = new MiniSSLBuffer(session.getPacketBufferSize());
    outboundAppData = new MiniSSLBuffer(session.getApplicationBufferSize());
    // Start in "read" mode with nothing pending, so extract() sees no application data.
    outboundAppData.flip();
    outboundNetData = new MiniSSLBuffer(session.getPacketBufferSize());

    return this;
  }

  /**
   * Returns the factory's trust managers with every X509TrustManager wrapped so the client
   * certificate it checks is recorded; returns null (JDK defaults) when there is no factory.
   */
  private TrustManager[] getTrustManagers(TrustManagerFactory factory) {
    if (factory == null) return null; // use JDK trust defaults
    final TrustManager[] tms = factory.getTrustManagers();
    if (tms != null) {
      for (int i=0; i<tms.length; i++) {
        final TrustManager tm = tms[i];
        if (tm instanceof X509TrustManager) {
          tms[i] = new TrustManagerWrapper((X509TrustManager) tm);
        }
      }
    }
    return tms;
  }

  // Leaf certificate last passed to checkClientTrusted; fallback for peercert when the session
  // does not report a verified peer.
  private volatile transient X509Certificate lastCheckedCert0;

  /**
   * Delegating X509TrustManager that remembers the client leaf certificate it was asked to check.
   * Deliberately non-static: it writes {@code lastCheckedCert0} on the owning engine.
   */
  private class TrustManagerWrapper implements X509TrustManager {

    private final X509TrustManager delegate;

    TrustManagerWrapper(X509TrustManager delegate) {
      this.delegate = delegate;
    }

    @Override
    public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
      // Record before delegating so the certificate is available even if the check fails.
      lastCheckedCert0 = chain.length > 0 ? chain[0] : null;
      delegate.checkClientTrusted(chain, authType);
    }

    @Override
    public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
      delegate.checkServerTrusted(chain, authType);
    }

    @Override
    public X509Certificate[] getAcceptedIssuers() {
      return delegate.getAcceptedIssuers();
    }

  }

  /** Appends encrypted bytes received from the peer to the inbound network buffer; returns self. */
  @JRubyMethod
  public IRubyObject inject(IRubyObject arg) {
    ByteList bytes = arg.convertToString().getByteList();
    inboundNetData.put(bytes.unsafeBytes(), bytes.getBegin(), bytes.getRealSize());
    return this;
  }

  private enum SSLOperation {
    WRAP,
    UNWRAP
  }

  /**
   * Runs a single wrap/unwrap from {@code src} into {@code dst}. On BUFFER_OVERFLOW the
   * destination is grown and the operation is retried. CLOSED sets {@code closed}, and a FINISHED
   * handshake status sets {@code handshake}.
   *
   * @return the result of the last engine call
   */
  private SSLEngineResult doOp(SSLOperation sslOp, MiniSSLBuffer src, MiniSSLBuffer dst) throws SSLException {
    SSLEngineResult res = null;
    boolean retryOp = true;
    while (retryOp) {
      switch (sslOp) {
        case WRAP:
          res = engine.wrap(src.getRawBuffer(), dst.getRawBuffer());
          break;
        case UNWRAP:
          res = engine.unwrap(src.getRawBuffer(), dst.getRawBuffer());
          break;
        default:
          throw new AssertionError("Unknown SSLOperation: " + sslOp);
      }

      switch (res.getStatus()) {
        case BUFFER_OVERFLOW:
          // increase the buffer size to accommodate the overflowing data
          int newSize = Math.max(engine.getSession().getPacketBufferSize(), engine.getSession().getApplicationBufferSize());
          dst.resize(newSize + dst.position());
          // retry the operation
          retryOp = true;
          break;
        case BUFFER_UNDERFLOW:
          // need to wait for more data to come in before we retry
          retryOp = false;
          break;
        case CLOSED:
          closed = true;
          retryOp = false;
          break;
        default:
          // other case is OK.  We're done here.
          retryOp = false;
      }
      if (res.getHandshakeStatus() == HandshakeStatus.FINISHED) {
        handshake = true;
      }
    }

    return res;
  }

  /**
   * Unwraps the pending inbound network data and drives any handshake steps the engine requests.
   *
   * @return the decrypted application data, or nil when there is none yet
   * @throws RaiseException wrapping the SSLException as {@code Puma::MiniSSL::SSLError}
   */
  @JRubyMethod
  public IRubyObject read() {
    try {
      inboundNetData.flip();

      if(!inboundNetData.hasRemaining()) {
        // Nothing buffered: put the buffer back into write mode so the next inject() appends normally.
        inboundNetData.clear();
        return getRuntime().getNil();
      }

      MiniSSLBuffer inboundAppData = new MiniSSLBuffer(engine.getSession().getApplicationBufferSize());
      doOp(SSLOperation.UNWRAP, inboundNetData, inboundAppData);

      HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
      boolean done = false;
      while (!done) {
        SSLEngineResult res;
        switch (handshakeStatus) {
          case NEED_WRAP:
            // NOTE(review): inboundAppData is passed as the wrap source; presumably this relies on the
            // engine not consuming application data while it emits handshake messages — TODO confirm.
            res = doOp(SSLOperation.WRAP, inboundAppData, outboundNetData);
            handshakeStatus = res.getHandshakeStatus();
            break;
          case NEED_UNWRAP:
            res = doOp(SSLOperation.UNWRAP, inboundNetData, inboundAppData);
            if (res.getStatus() == Status.BUFFER_UNDERFLOW) {
              // need more data before we can shake more hands
              done = true;
            }
            handshakeStatus = res.getHandshakeStatus();
            break;
          case NEED_TASK:
            Runnable runnable;
            while ((runnable = engine.getDelegatedTask()) != null) {
              runnable.run();
            }
            handshakeStatus = engine.getHandshakeStatus();
            break;
          default:
            done = true;
        }
      }

      // Keep any unconsumed network bytes for the next read; otherwise reset for writing.
      if (inboundNetData.hasRemaining()) {
        inboundNetData.compact();
      } else {
        inboundNetData.clear();
      }

      ByteList appDataByteList = inboundAppData.asByteList();
      if (appDataByteList == null) {
        return getRuntime().getNil();
      }

      return RubyString.newString(getRuntime(), appDataByteList);
    } catch (SSLException e) {
      throw newSSLError(getRuntime(), e);
    }
  }

  /**
   * Queues plaintext to be encrypted by subsequent {@code extract} calls. This replaces any
   * previously queued plaintext.
   *
   * @return the number of bytes queued
   */
  @JRubyMethod
  public IRubyObject write(IRubyObject arg) {
    byte[] bls = arg.convertToString().getBytes();
    outboundAppData = new MiniSSLBuffer(bls);

    return getRuntime().newFixnum(bls.length);
  }

  /**
   * Returns encrypted bytes ready to be sent to the peer. Already-produced network data (e.g.
   * handshake records from {@code read}) is returned first. Otherwise the next chunk of queued
   * plaintext is wrapped.
   *
   * @return a String of encrypted bytes, or nil when nothing is pending
   * @throws RaiseException wrapping the SSLException as {@code Puma::MiniSSL::SSLError}
   */
  @JRubyMethod
  public IRubyObject extract(ThreadContext context) {
    try {
      ByteList dataByteList = outboundNetData.asByteList();
      if (dataByteList != null) {
        return RubyString.newString(context.runtime, dataByteList);
      }

      if (!outboundAppData.hasRemaining()) {
        return context.nil;
      }

      // asByteList() above already cleared outboundNetData, so it is ready to receive wrap output.
      doOp(SSLOperation.WRAP, outboundAppData, outboundNetData);
      dataByteList = outboundNetData.asByteList();
      if (dataByteList == null) {
        return context.nil;
      }

      return RubyString.newString(context.runtime, dataByteList);
    } catch (SSLException e) {
      throw newSSLError(context.runtime, e);
    }
  }

  /**
   * Returns the DER encoding of the peer's leaf certificate. If the session has no verified peer,
   * the certificate last seen by the trust manager is used instead. Returns nil when neither is
   * available.
   */
  @JRubyMethod
  public IRubyObject peercert(ThreadContext context) throws CertificateEncodingException {
    Certificate peerCert;
    try {
      peerCert = engine.getSession().getPeerCertificates()[0];
    } catch (SSLPeerUnverifiedException e) {
      peerCert = lastCheckedCert0; // null if trust check did not happen
    }
    return peerCert == null ? context.nil : JavaEmbedUtils.javaToRuby(context.runtime, peerCert.getEncoded());
  }

  /** {@code init?}: true until a handshake has been observed to finish. */
  @JRubyMethod(name = "init?")
  public IRubyObject isInit(ThreadContext context) {
    return context.runtime.newBoolean(!handshake);
  }

  /**
   * Reports whether the connection is shut down: either an operation saw CLOSED, or both engine
   * directions are done. In that case the outbound side is closed too, if it is not already.
   *
   * @return true when shut down, false otherwise
   */
  @JRubyMethod
  public IRubyObject shutdown() {
    if (closed || engine.isInboundDone() && engine.isOutboundDone()) {
      // Only meaningful while the outbound side is still open (the previous guard was inverted
      // and called closeOutbound() only once outbound was already done).
      if (!engine.isOutboundDone()) {
        engine.closeOutbound();
      }
      return getRuntime().getTrue();
    }
    return getRuntime().getFalse();
  }

  /** Looks up {@code Puma::MiniSSL::SSLError} in the given runtime. */
  private static RubyClass getSSLError(Ruby runtime) {
    return (RubyClass) ((RubyModule) runtime.getModule("Puma").getConstantAt("MiniSSL")).getConstantAt("SSLError");
  }

  /** Wraps an SSLException as a Ruby {@code Puma::MiniSSL::SSLError}, keeping it as the Java cause. */
  private static RaiseException newSSLError(Ruby runtime, SSLException cause) {
    return newError(runtime, getSSLError(runtime), cause.toString(), cause);
  }

  private static RaiseException newError(Ruby runtime, RubyClass errorClass, String message, Throwable cause) {
    RaiseException ex = new RaiseException(runtime, errorClass, message, true);
    ex.initCause(cause);
    return ex;
  }

}