package org.glassfish.grizzly.ssl;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.logging.Filter;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import org.glassfish.grizzly.Buffer;
import org.glassfish.grizzly.CompletionHandler;
import org.glassfish.grizzly.Connection;
import org.glassfish.grizzly.Context;
import org.glassfish.grizzly.FileTransfer;
import org.glassfish.grizzly.Grizzly;
import org.glassfish.grizzly.GrizzlyFuture;
import org.glassfish.grizzly.IOEvent;
import org.glassfish.grizzly.IOEventLifeCycleListener.Adapter;
import org.glassfish.grizzly.ProcessorExecutor;
import org.glassfish.grizzly.ReadResult;
import org.glassfish.grizzly.Transport;
import org.glassfish.grizzly.asyncqueue.MessageCloner;
import org.glassfish.grizzly.filterchain.BaseFilter;
import org.glassfish.grizzly.filterchain.FilterChain;
import org.glassfish.grizzly.filterchain.FilterChainContext;
import org.glassfish.grizzly.filterchain.FilterChainContext.Operation;
import org.glassfish.grizzly.filterchain.FilterChainContext.TransportContext;
import org.glassfish.grizzly.filterchain.FilterChainEvent;
import org.glassfish.grizzly.filterchain.NextAction;
import org.glassfish.grizzly.filterchain.TransportFilter;
import org.glassfish.grizzly.impl.FutureImpl;
import org.glassfish.grizzly.memory.Buffers;
import org.glassfish.grizzly.memory.CompositeBuffer;
import org.glassfish.grizzly.memory.MemoryManager;
import org.glassfish.grizzly.ssl.SSLConnectionContext.Allocator;
import org.glassfish.grizzly.ssl.SSLConnectionContext.SslResult;
import org.glassfish.grizzly.utils.Futures;
import static org.glassfish.grizzly.ssl.SSLUtils.*;
public class SSLBaseFilter extends BaseFilter {
private static final Logger LOGGER = Grizzly.logger(SSLBaseFilter.class);
protected static final MessageCloner<Buffer> COPY_CLONER = new OnWriteCopyCloner();
private static final Allocator MM_ALLOCATOR = new Allocator() {
@Override
@SuppressWarnings("unchecked")
public Buffer grow(final SSLConnectionContext sslCtx,
final Buffer oldBuffer, final int newSize) {
final MemoryManager mm = sslCtx.getConnection().getMemoryManager();
return oldBuffer == null ?
mm.allocate(newSize) :
mm.reallocate(oldBuffer, newSize);
}
};
private static final Allocator OUTPUT_BUFFER_ALLOCATOR =
new Allocator() {
@Override
public Buffer grow(final SSLConnectionContext sslCtx,
final Buffer oldBuffer, final int newSize) {
return allocateOutputBuffer(newSize);
}
};
private final SSLEngineConfigurator ;
private final boolean renegotiateOnClientAuthWant;
private volatile boolean renegotiationDisabled;
protected final Set<HandshakeListener> handshakeListeners =
Collections.newSetFromMap(new ConcurrentHashMap<>(2));
private long handshakeTimeoutMillis = -1;
private SSLTransportFilterWrapper optimizedTransportFilter;
public SSLBaseFilter() {
this(null);
}
public SSLBaseFilter(SSLEngineConfigurator serverSSLEngineConfigurator) {
this(serverSSLEngineConfigurator, true);
}
public SSLBaseFilter(SSLEngineConfigurator serverSSLEngineConfigurator,
boolean renegotiateOnClientAuthWant) {
this.renegotiateOnClientAuthWant = renegotiateOnClientAuthWant;
this.serverSSLEngineConfigurator =
((serverSSLEngineConfigurator != null)
? serverSSLEngineConfigurator
: new SSLEngineConfigurator(
SSLContextConfigurator.DEFAULT_CONFIG.createSSLContext(true),
false,
false,
false));
}
public boolean isRenegotiateOnClientAuthWant() {
return renegotiateOnClientAuthWant;
}
public SSLEngineConfigurator () {
return serverSSLEngineConfigurator;
}
public void addHandshakeListener(final HandshakeListener listener) {
handshakeListeners.add(listener);
}
@SuppressWarnings("unused")
public void removeHandshakeListener(final HandshakeListener listener) {
handshakeListeners.remove(listener);
}
@SuppressWarnings("unused")
public long getHandshakeTimeout(final TimeUnit timeUnit) {
if (handshakeTimeoutMillis < 0) {
return -1;
}
return timeUnit.convert(handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
}
@SuppressWarnings("unused")
public void setHandshakeTimeout(final long handshakeTimeout,
final TimeUnit timeUnit) {
if (handshakeTimeout < 0) {
handshakeTimeoutMillis = -1;
} else {
this.handshakeTimeoutMillis =
TimeUnit.MILLISECONDS.convert(handshakeTimeout, timeUnit);
}
}
public void setRenegotiationDisabled(boolean renegotiationDisabled) {
this.renegotiationDisabled = renegotiationDisabled;
}
protected SSLTransportFilterWrapper getOptimizedTransportFilter(
final TransportFilter childFilter) {
if (optimizedTransportFilter == null ||
optimizedTransportFilter.wrappedFilter != childFilter) {
optimizedTransportFilter = createOptimizedTransportFilter(childFilter);
}
return optimizedTransportFilter;
}
protected SSLTransportFilterWrapper createOptimizedTransportFilter(
final TransportFilter childFilter) {
return new SSLTransportFilterWrapper(childFilter, this);
}
@Override
public void onRemoved(final FilterChain filterChain) {
if (optimizedTransportFilter != null) {
final int sslTransportFilterIdx = filterChain.indexOf(optimizedTransportFilter);
if (sslTransportFilterIdx >= 0) {
SSLTransportFilterWrapper wrapper =
(SSLTransportFilterWrapper) filterChain.get(sslTransportFilterIdx);
filterChain.set(sslTransportFilterIdx, wrapper.wrappedFilter);
}
}
}
@Override
public void onAdded(FilterChain filterChain) {
final int sslTransportFilterIdx =
filterChain.indexOfType(SSLTransportFilterWrapper.class);
if (sslTransportFilterIdx == -1) {
final int transportFilterIdx =
filterChain.indexOfType(TransportFilter.class);
if (transportFilterIdx >= 0) {
filterChain.set(transportFilterIdx,
getOptimizedTransportFilter(
(TransportFilter) filterChain.get(transportFilterIdx)));
}
}
}
@Override
public NextAction handleEvent(final FilterChainContext ctx,
final FilterChainEvent event)
throws IOException {
if (event.type() == CertificateEvent.TYPE) {
final CertificateEvent ce = (CertificateEvent) event;
try {
return ctx.getSuspendAction();
} finally {
getPeerCertificateChain(obtainSslConnectionContext(ctx.getConnection()),
ctx,
ce.needClientAuth,
ce.certsFuture);
}
}
return ctx.getInvokeAction();
}
@Override
public NextAction handleRead(final FilterChainContext ctx)
throws IOException {
final Connection connection = ctx.getConnection();
final SSLConnectionContext sslCtx = obtainSslConnectionContext(connection);
SSLEngine sslEngine = sslCtx.getSslEngine();
if (sslEngine != null && !isHandshaking(sslEngine)) {
return unwrapAll(ctx, sslCtx);
} else {
if (sslEngine == null) {
sslEngine = serverSSLEngineConfigurator.createSSLEngine();
sslEngine.beginHandshake();
sslCtx.configure(sslEngine);
notifyHandshakeStart(connection);
}
final Buffer buffer;
buffer = ((handshakeTimeoutMillis >= 0)
? doHandshakeSync(sslCtx,
ctx,
(Buffer) ctx.getMessage(),
handshakeTimeoutMillis)
: makeInputRemainder(sslCtx,
ctx,
doHandshakeStep(sslCtx,
ctx,
(Buffer) ctx.getMessage())));
final boolean hasRemaining = buffer != null && buffer.hasRemaining();
final boolean isHandshaking = isHandshaking(sslEngine);
if (!isHandshaking) {
notifyHandshakeComplete(connection, sslEngine);
final FilterChain connectionFilterChain = sslCtx.getNewConnectionFilterChain();
sslCtx.setNewConnectionFilterChain(null);
if (connectionFilterChain != null) {
if (LOGGER.isLoggable(Level.FINE)) {
LOGGER.log(Level.FINE, "Applying new FilterChain after"
+ "SSLHandshake. Connection={0} filterchain={1}",
new Object[]{connection, connectionFilterChain});
}
connection.setProcessor(connectionFilterChain);
if (hasRemaining) {
NextAction suspendAction = ctx.getSuspendAction();
ctx.setMessage(buffer);
ctx.suspend();
final FilterChainContext newContext =
obtainProtocolChainContext(ctx, connectionFilterChain);
ProcessorExecutor.execute(newContext.getInternalContext());
return suspendAction;
} else {
return ctx.getStopAction();
}
}
if (hasRemaining) {
ctx.setMessage(buffer);
return unwrapAll(ctx, sslCtx);
}
}
return ctx.getStopAction(buffer);
}
}
@SuppressWarnings("unchecked")
@Override
public NextAction handleWrite(final FilterChainContext ctx) throws IOException {
if (ctx.getMessage() instanceof FileTransfer) {
throw new IllegalStateException("TLS operations not supported with SendFile messages");
}
final Connection connection = ctx.getConnection();
synchronized(connection) {
final Buffer output =
wrapAll(ctx, obtainSslConnectionContext(connection));
final TransportContext transportContext =
ctx.getTransportContext();
ctx.write(null, output,
transportContext.getCompletionHandler(),
transportContext.getPushBackHandler(),
COPY_CLONER,
transportContext.isBlocking());
return ctx.getStopAction();
}
}
protected NextAction unwrapAll(final FilterChainContext ctx,
final SSLConnectionContext sslCtx) throws SSLException {
Buffer input = ctx.getMessage();
Buffer output = null;
boolean isClosed = false;
_outter:
do {
final int len = getSSLPacketSize(input);
if (len == -1 || input.remaining() < len) {
break;
}
final SslResult result =
sslCtx.unwrap(len, input, output, MM_ALLOCATOR);
output = result.getOutput();
if (result.isError()) {
output.dispose();
throw result.getError();
}
if (isHandshaking(sslCtx.getSslEngine())) {
if (result.getSslEngineResult().getStatus() != Status.CLOSED) {
input = rehandshake(ctx, sslCtx);
} else {
input = silentRehandshake(ctx, sslCtx);
isClosed = true;
}
if (input == null) {
break;
}
}
switch (result.getSslEngineResult().getStatus()) {
case OK:
if (input.hasRemaining()) {
break;
}
break _outter;
case CLOSED:
isClosed = true;
break _outter;
default:
throw new IllegalStateException("Unexpected status: " +
result.getSslEngineResult().getStatus());
}
} while (true);
if (output != null) {
output.trim();
if (output.hasRemaining() || isClosed) {
ctx.setMessage(output);
return ctx.getInvokeAction(makeInputRemainder(sslCtx, ctx, input));
}
}
return ctx.getStopAction(makeInputRemainder(sslCtx, ctx, input));
}
@SuppressWarnings("MethodMayBeStatic")
protected Buffer wrapAll(final FilterChainContext ctx,
final SSLConnectionContext sslCtx) throws SSLException {
final Buffer input = ctx.getMessage();
final Buffer output = sslCtx.wrapAll(input, OUTPUT_BUFFER_ALLOCATOR);
input.tryDispose();
return output;
}
protected Buffer doHandshakeSync(final SSLConnectionContext sslCtx,
final FilterChainContext ctx,
Buffer inputBuffer,
final long timeoutMillis) throws IOException {
final Connection connection = ctx.getConnection();
final SSLEngine sslEngine = sslCtx.getSslEngine();
final Buffer tmpAppBuffer = allocateOutputBuffer(sslCtx.getAppBufferSize());
final long oldReadTimeout = connection.getReadTimeout(TimeUnit.MILLISECONDS);
try {
connection.setReadTimeout(timeoutMillis, TimeUnit.MILLISECONDS);
inputBuffer = makeInputRemainder(sslCtx, ctx,
doHandshakeStep(sslCtx, ctx, inputBuffer, tmpAppBuffer));
while (isHandshaking(sslEngine)) {
final ReadResult rr = ctx.read();
final Buffer newBuf = (Buffer) rr.getMessage();
inputBuffer = Buffers.appendBuffers(ctx.getMemoryManager(),
inputBuffer, newBuf);
inputBuffer = makeInputRemainder(sslCtx, ctx,
doHandshakeStep(sslCtx, ctx, inputBuffer, tmpAppBuffer));
}
} finally {
tmpAppBuffer.dispose();
connection.setReadTimeout(oldReadTimeout, TimeUnit.MILLISECONDS);
}
return inputBuffer;
}
protected Buffer doHandshakeStep(final SSLConnectionContext sslCtx,
final FilterChainContext ctx,
Buffer inputBuffer) throws IOException {
return doHandshakeStep(sslCtx, ctx, inputBuffer, null);
}
protected Buffer doHandshakeStep(final SSLConnectionContext sslCtx,
final FilterChainContext ctx,
Buffer inputBuffer,
final Buffer tmpAppBuffer0)
throws IOException {
final SSLEngine sslEngine = sslCtx.getSslEngine();
final Connection connection = ctx.getConnection();
final boolean isLoggingFinest = LOGGER.isLoggable(Level.FINEST);
Buffer tmpInputToDispose = null;
Buffer tmpNetBuffer = null;
Buffer tmpAppBuffer = tmpAppBuffer0;
try {
HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
_exitWhile:
while (true) {
if (isLoggingFinest) {
LOGGER.log(Level.FINEST, "Loop Engine: {0} handshakeStatus={1}",
new Object[]{sslEngine, sslEngine.getHandshakeStatus()});
}
switch (handshakeStatus) {
case NEED_UNWRAP: {
if (isLoggingFinest) {
LOGGER.log(Level.FINEST, "NEED_UNWRAP Engine: {0}", sslEngine);
}
if (inputBuffer == null || !inputBuffer.hasRemaining()) {
break _exitWhile;
}
final int expectedLength = getSSLPacketSize(inputBuffer);
if (expectedLength == -1
|| inputBuffer.remaining() < expectedLength) {
break _exitWhile;
}
if (tmpAppBuffer == null) {
tmpAppBuffer = allocateOutputBuffer(sslCtx.getAppBufferSize());
}
final SSLEngineResult sslEngineResult =
handshakeUnwrap(expectedLength, sslCtx, inputBuffer, tmpAppBuffer);
if (!inputBuffer.hasRemaining()) {
tmpInputToDispose = inputBuffer;
inputBuffer = null;
}
final Status status = sslEngineResult.getStatus();
if (status == Status.BUFFER_UNDERFLOW ||
status == Status.BUFFER_OVERFLOW) {
throw new SSLException("SSL unwrap error: " + status);
}
handshakeStatus = sslEngine.getHandshakeStatus();
break;
}
case NEED_WRAP: {
if (isLoggingFinest) {
LOGGER.log(Level.FINEST, "NEED_WRAP Engine: {0}", sslEngine);
}
tmpNetBuffer = handshakeWrap(
connection, sslCtx, tmpNetBuffer);
handshakeStatus = sslEngine.getHandshakeStatus();
break;
}
case NEED_TASK: {
if (isLoggingFinest) {
LOGGER.log(Level.FINEST, "NEED_TASK Engine: {0}", sslEngine);
}
executeDelegatedTask(sslEngine);
handshakeStatus = sslEngine.getHandshakeStatus();
break;
}
case FINISHED:
case NOT_HANDSHAKING: {
break _exitWhile;
}
}
if (handshakeStatus == HandshakeStatus.FINISHED) {
break;
}
}
} catch (IOException ioe) {
notifyHandshakeFailed(connection, ioe);
throw ioe;
} finally {
if (tmpAppBuffer0 == null && tmpAppBuffer != null) {
tmpAppBuffer.dispose();
}
if (tmpInputToDispose != null) {
tmpInputToDispose.tryDispose();
inputBuffer = null;
} else if (inputBuffer != null) {
inputBuffer.shrink();
}
if (tmpNetBuffer != null) {
if (inputBuffer != null) {
inputBuffer = makeInputRemainder(sslCtx, ctx, inputBuffer);
}
ctx.write(tmpNetBuffer);
}
}
return inputBuffer;
}
protected void renegotiate(final SSLConnectionContext sslCtx,
final FilterChainContext context)
throws IOException {
if (renegotiationDisabled) {
return;
}
final SSLEngine sslEngine = sslCtx.getSslEngine();
if (sslEngine.getWantClientAuth() && !renegotiateOnClientAuthWant) {
return;
}
final boolean authConfigured =
(sslEngine.getWantClientAuth()
|| sslEngine.getNeedClientAuth());
if (!authConfigured) {
sslEngine.setNeedClientAuth(true);
}
sslEngine.getSession().invalidate();
try {
sslEngine.beginHandshake();
} catch (SSLHandshakeException e) {
if (e.toString().toLowerCase().contains("insecure renegotiation")) {
if (LOGGER.isLoggable(Level.SEVERE)) {
LOGGER.severe("Secure SSL/TLS renegotiation is not "
+ "supported by the peer. This is most likely due"
+ " to the peer using an older SSL/TLS "
+ "implementation that does not implement RFC 5746.");
}
}
throw e;
}
try {
rehandshake(context, sslCtx);
} finally {
if (!authConfigured) {
sslEngine.setNeedClientAuth(false);
}
}
}
private Buffer silentRehandshake(final FilterChainContext context,
final SSLConnectionContext sslCtx) throws SSLException {
try {
return doHandshakeSync(
sslCtx, context, null, handshakeTimeoutMillis);
} catch (Throwable t) {
if (LOGGER.isLoggable(Level.FINE)) {
LOGGER.log(Level.FINE, "Error during graceful ssl connection close", t);
}
if (t instanceof SSLException) {
throw (SSLException) t;
}
throw new SSLException("Error during re-handshaking", t);
}
}
private Buffer rehandshake(final FilterChainContext context,
final SSLConnectionContext sslCtx) throws SSLException {
final Connection c = context.getConnection();
notifyHandshakeStart(c);
try {
final Buffer buffer = doHandshakeSync(
sslCtx, context, null, handshakeTimeoutMillis);
notifyHandshakeComplete(c, sslCtx.getSslEngine());
return buffer;
} catch (Throwable t) {
notifyHandshakeFailed(c, t);
if (LOGGER.isLoggable(Level.FINE)) {
LOGGER.log(Level.FINE, "Error during re-handshaking", t);
}
if (t instanceof SSLException) {
throw (SSLException) t;
}
throw new SSLException("Error during re-handshaking", t);
}
}
protected void getPeerCertificateChain(final SSLConnectionContext sslCtx,
final FilterChainContext context,
final boolean needClientAuth,
final FutureImpl<Object[]> certFuture) {
Certificate[] certs = getPeerCertificates(sslCtx);
if (certs != null) {
certFuture.result(certs);
return;
}
if (needClientAuth) {
final Transport transport = context.getConnection().getTransport();
ExecutorService threadPool = transport.getWorkerThreadPool();
if (threadPool == null) {
threadPool = transport.getKernelThreadPool();
}
threadPool.submit(new Runnable() {
@Override
public void run() {
try {
try {
renegotiate(sslCtx, context);
} catch (IOException ioe) {
certFuture.failure(ioe);
return;
}
Certificate[] certs = getPeerCertificates(sslCtx);
if (certs == null) {
certFuture.result(null);
return;
}
X509Certificate[] x509Certs = extractX509Certs(certs);
if (x509Certs == null || x509Certs.length < 1) {
certFuture.result(null);
return;
}
certFuture.result(x509Certs);
} finally {
context.resume(context.getStopAction());
}
}
});
}
}
protected SSLConnectionContext obtainSslConnectionContext(
final Connection connection) {
SSLConnectionContext sslCtx = SSL_CTX_ATTR.get(connection);
if (sslCtx == null) {
sslCtx = createSslConnectionContext(connection);
SSL_CTX_ATTR.set(connection, sslCtx);
}
return sslCtx;
}
@SuppressWarnings("MethodMayBeStatic")
protected SSLConnectionContext createSslConnectionContext(
final Connection connection) {
return new SSLConnectionContext(connection);
}
private static FilterChainContext obtainProtocolChainContext(
final FilterChainContext ctx,
final FilterChain completeProtocolFilterChain) {
final FilterChainContext newFilterChainContext =
completeProtocolFilterChain.obtainFilterChainContext(
ctx.getConnection(),
ctx.getStartIdx(),
completeProtocolFilterChain.size(),
ctx.getFilterIdx());
newFilterChainContext.setAddressHolder(ctx.getAddressHolder());
newFilterChainContext.setMessage(ctx.getMessage());
newFilterChainContext.getInternalContext().setIoEvent(IOEvent.READ);
newFilterChainContext.getInternalContext().addLifeCycleListener(
new InternalProcessingHandler(ctx));
return newFilterChainContext;
}
private static X509Certificate[] (final Certificate[] certs) {
final X509Certificate[] x509Certs = new X509Certificate[certs.length];
for(int i = 0, len = certs.length; i < len; i++) {
if( certs[i] instanceof X509Certificate ) {
x509Certs[i] = (X509Certificate)certs[i];
} else {
try {
final byte [] buffer = certs[i].getEncoded();
final CertificateFactory cf =
CertificateFactory.getInstance("X.509");
ByteArrayInputStream stream = new ByteArrayInputStream(buffer);
x509Certs[i] = (X509Certificate)
cf.generateCertificate(stream);
} catch(Exception ex) {
LOGGER.log(Level.INFO,
"Error translating cert " + certs[i],
ex);
return null;
}
}
if (LOGGER.isLoggable(Level.FINE)) {
LOGGER.log(Level.FINE, "Cert #{0} = {1}", new Object[] {i, x509Certs[i]});
}
}
return x509Certs;
}
private static Certificate[] getPeerCertificates(final SSLConnectionContext sslCtx) {
try {
return sslCtx.getSslEngine().getSession().getPeerCertificates();
} catch( Throwable t ) {
if (LOGGER.isLoggable(Level.FINE)) {
LOGGER.log(Level.FINE,"Error getting client certs", t);
}
return null;
}
}
protected void notifyHandshakeStart(final Connection connection) {
if (!handshakeListeners.isEmpty()) {
for (final HandshakeListener listener : handshakeListeners) {
listener.onStart(connection);
}
}
}
protected void notifyHandshakeComplete(final Connection<?> connection,
final SSLEngine sslEngine) {
if (!handshakeListeners.isEmpty()) {
for (final HandshakeListener listener : handshakeListeners) {
listener.onComplete(connection);
}
}
}
protected void notifyHandshakeFailed(final Connection connection,
final Throwable t) {
if (!handshakeListeners.isEmpty()) {
for (final HandshakeListener listener : handshakeListeners) {
listener.onFailure(connection, t);
}
}
}
public static class CertificateEvent implements FilterChainEvent {
static final String TYPE = "CERT_EVENT";
final FutureImpl<Object[]> certsFuture;
final boolean needClientAuth;
public CertificateEvent(final boolean needClientAuth) {
this.needClientAuth = needClientAuth;
certsFuture = Futures.createSafeFuture();
}
@Override
public final Object type() {
return TYPE;
}
public GrizzlyFuture<Object[]> trigger(final FilterChainContext ctx) {
ctx.getFilterChain().fireEventDownstream(ctx.getConnection(),
this,
null);
return certsFuture;
}
}
private static class InternalProcessingHandler extends Adapter {
private final FilterChainContext parentContext;
private InternalProcessingHandler(final FilterChainContext parentContext) {
this.parentContext = parentContext;
}
@Override
public void onComplete(final Context context, Object data) throws IOException {
parentContext.resume(parentContext.getStopAction());
}
}
public interface HandshakeListener {
void onStart(Connection connection);
void onComplete(Connection connection);
void onFailure(Connection connection, Throwable t);
}
protected static class SSLTransportFilterWrapper extends TransportFilter {
protected final TransportFilter wrappedFilter;
protected final SSLBaseFilter sslBaseFilter;
public SSLTransportFilterWrapper(final TransportFilter transportFilter,
final SSLBaseFilter sslBaseFilter) {
this.wrappedFilter = transportFilter;
this.sslBaseFilter = sslBaseFilter;
}
@Override
public NextAction handleAccept(FilterChainContext ctx) throws IOException {
return wrappedFilter.handleAccept(ctx);
}
@Override
public NextAction handleConnect(FilterChainContext ctx) throws IOException {
return wrappedFilter.handleConnect(ctx);
}
@Override
public NextAction handleRead(final FilterChainContext ctx) throws IOException {
final Connection connection = ctx.getConnection();
final SSLConnectionContext sslCtx =
sslBaseFilter.obtainSslConnectionContext(connection);
if (sslCtx.getSslEngine() == null) {
final SSLEngine sslEngine = sslBaseFilter.serverSSLEngineConfigurator.createSSLEngine();
sslEngine.beginHandshake();
sslCtx.configure(sslEngine);
sslBaseFilter.notifyHandshakeStart(connection);
}
ctx.setMessage(allowDispose(allocateInputBuffer(sslCtx)));
return wrappedFilter.handleRead(ctx);
}
@Override
public NextAction handleWrite(FilterChainContext ctx) throws IOException {
return wrappedFilter.handleWrite(ctx);
}
@Override
public NextAction handleEvent(FilterChainContext ctx, FilterChainEvent event) throws IOException {
return wrappedFilter.handleEvent(ctx, event);
}
@Override
public NextAction handleClose(FilterChainContext ctx) throws IOException {
return wrappedFilter.handleClose(ctx);
}
@Override
public void onAdded(FilterChain filterChain) {
wrappedFilter.onAdded(filterChain);
}
@Override
public void onFilterChainChanged(FilterChain filterChain) {
wrappedFilter.onFilterChainChanged(filterChain);
}
@Override
public void onRemoved(FilterChain filterChain) {
wrappedFilter.onRemoved(filterChain);
}
@Override
public void exceptionOccurred(FilterChainContext ctx, Throwable error) {
wrappedFilter.exceptionOccurred(ctx, error);
}
@Override
public FilterChainContext createContext(Connection connection, Operation operation) {
return wrappedFilter.createContext(connection, operation);
}
}
private static final class OnWriteCopyCloner implements MessageCloner<Buffer> {
@Override
public Buffer clone(final Connection connection,
final Buffer originalMessage) {
final SSLConnectionContext sslCtx = getSslConnectionContext(connection);
final int copyThreshold = sslCtx.getNetBufferSize() / 2;
final Buffer lastOutputBuffer = sslCtx.resetLastOutputBuffer();
final int totalRemaining = originalMessage.remaining();
if (totalRemaining < copyThreshold) {
return move(connection.getMemoryManager(),
originalMessage);
}
if (lastOutputBuffer.remaining() < copyThreshold) {
final Buffer tmpBuf =
copy(connection.getMemoryManager(),
originalMessage);
if (originalMessage.isComposite()) {
((CompositeBuffer) originalMessage).replace(
lastOutputBuffer, tmpBuf);
} else {
assert originalMessage == lastOutputBuffer;
}
lastOutputBuffer.tryDispose();
return tmpBuf;
}
return originalMessage;
}
}
}