/**
 * Handles an inbound PUBLISH message by forwarding it to every channel that {@code MemPool}
 * reports for the message's topic.
 */
public class PublishProcesser implements Processer {
  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(PublishProcesser.class);

  /** Shared reply returned when the sending channel has no known client id. */
  private static final DisconnectMessage DISCONNECT = new DisconnectMessage();

  /**
   * Forwards the given PUBLISH message to all channels registered for its topic.
   *
   * @param msg the inbound message; it is cast to {@link PublishMessage}
   * @param ctx the context of the channel the message arrived on
   * @return the shared disconnect message when {@code MemPool} has no client id for the channel,
   *     otherwise {@code null}
   */
  public Message proc(Message msg, ChannelHandlerContext ctx) {
    String clientId = MemPool.getClientId(ctx.channel());
    if (clientId == null) {
      return DISCONNECT;
    }

    PublishMessage pm = (PublishMessage) msg;
    Set<ChannelEntity> channelEntitys = MemPool.getChannelByTopics(pm.getTopic());
    if (channelEntitys == null) {
      return null;
    }

    // The log line is identical for every subscriber, so build it once, and only when debug
    // logging is enabled: converting the payload to a string is wasted work otherwise.
    boolean debugEnabled = logger.isDebugEnabled();
    String debugLine =
        debugEnabled
            ? "PUBLISH to ChannelEntity topic = "
                + pm.getTopic()
                + " payload = "
                + pm.getDataAsString()
            : null;

    for (ChannelEntity channelEntity : channelEntitys) {
      if (debugEnabled) {
        logger.debug(debugLine);
      }
      channelEntity.write(pm);
    }
    return null;
  }
}
/**
 * Applies the system-wide configuration: logging setup, Netty system properties and database
 * helper initialization.
 *
 * @throws Exception if any of the initialization steps fails
 */
public void initSystem() throws Exception {
  PropertyConfigurator.configure("Log4j.properties");
  InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
  log.info(System.getProperty("file.encoding"));

  // System.setProperty rejects a null value with a bare NullPointerException, so only forward
  // the recycler capacity when it is actually configured.
  String recyclerKey = "io.netty.recycler.maxCapacity.default";
  String recyclerCapacity = PropertyUtil.getProperty(recyclerKey);
  if (recyclerCapacity != null) {
    System.setProperty(recyclerKey, recyclerCapacity);
  } else {
    log.warn(recyclerKey + " is not configured; the system property is left unchanged");
  }

  System.setProperty("io.netty.leakDetectionLevel", "paranoid");
  DbHelper.init();
}
/** * Abstract base class for {@link EventLoopGroup} implementations that handle their tasks with * multiple threads at the same time. */ public abstract class MultithreadEventLoopGroup extends MultithreadEventExecutorGroup implements EventLoopGroup { private static final InternalLogger logger = InternalLoggerFactory.getInstance(MultithreadEventLoopGroup.class); private static final int DEFAULT_EVENT_LOOP_THREADS; static { DEFAULT_EVENT_LOOP_THREADS = Math.max( 1, SystemPropertyUtil.getInt( "io.netty.eventLoopThreads", Runtime.getRuntime().availableProcessors() * 2)); if (logger.isDebugEnabled()) { logger.debug("-Dio.netty.eventLoopThreads: {}", DEFAULT_EVENT_LOOP_THREADS); } } /** * @see {@link MultithreadEventExecutorGroup#MultithreadEventExecutorGroup(int, Executor, * Object...)} */ protected MultithreadEventLoopGroup(int nEventLoops, Executor executor, Object... args) { super(nEventLoops == 0 ? DEFAULT_EVENT_LOOP_THREADS : nEventLoops, executor, args); } /** * @see {@link MultithreadEventExecutorGroup#MultithreadEventExecutorGroup(int, ExecutorFactory, * Object...)} */ protected MultithreadEventLoopGroup( int nEventLoops, ExecutorFactory executorFactory, Object... args) { super(nEventLoops == 0 ? DEFAULT_EVENT_LOOP_THREADS : nEventLoops, executorFactory, args); } @Override public EventLoop next() { return (EventLoop) super.next(); } @Override protected abstract EventLoop newChild(Executor executor, Object... args) throws Exception; @Override public ChannelFuture register(Channel channel) { return next().register(channel); } @Override public ChannelFuture register(Channel channel, ChannelPromise promise) { return next().register(channel, promise); } }
/** * 单一回复编码类 * * @className SingleResponseEncoder * @author jy * @date 2015年08月02日 * @since JDK 1.6 * @see com.foxinmy.weixin4j.server.response.SingleResponse */ public class SingleResponseEncoder extends MessageToMessageEncoder<SingleResponse> { private final InternalLogger logger = InternalLoggerFactory.getInstance(getClass()); @Override protected void encode(ChannelHandlerContext ctx, SingleResponse response, List<Object> out) throws WeixinException { String content = response.toContent(); ctx.writeAndFlush(HttpUtil.createHttpResponse(content, ServerToolkits.CONTENTTYPE$TEXT_PLAIN)); logger.debug("encode single response:{}", content); } }
/** * 微信回复编码类 * * @className WeixinResponseEncoder * @author jy * @date 2014年11月13日 * @since JDK 1.6 * @see <a href="http://mp.weixin.qq.com/wiki/0/61c3a8b9d50ac74f18bdf2e54ddfc4e0.html">加密接入指引</a> * @see com.foxinmy.weixin4j.response.WeixinResponse */ public class WeixinResponseEncoder extends MessageToMessageEncoder<WeixinResponse> { private final InternalLogger logger = InternalLoggerFactory.getInstance(getClass()); @Override protected void encode(ChannelHandlerContext ctx, WeixinResponse response, List<Object> out) throws WeixinException { WeixinMessageTransfer messageTransfer = ctx.channel().attr(ServerToolkits.MESSAGE_TRANSFER_KEY).get(); EncryptType encryptType = messageTransfer.getEncryptType(); StringBuilder content = new StringBuilder(); content.append("<xml>"); content.append( String.format( "<ToUserName><![CDATA[%s]]></ToUserName>", messageTransfer.getFromUserName())); content.append( String.format( "<FromUserName><![CDATA[%s]]></FromUserName>", messageTransfer.getToUserName())); content.append( String.format( "<CreateTime><![CDATA[%d]]></CreateTime>", System.currentTimeMillis() / 1000l)); content.append(String.format("<MsgType><![CDATA[%s]]></MsgType>", response.getMsgType())); content.append(response.toContent()); content.append("</xml>"); if (encryptType == EncryptType.AES) { AesToken aesToken = messageTransfer.getAesToken(); String nonce = ServerToolkits.generateRandomString(32); String timestamp = Long.toString(System.currentTimeMillis() / 1000l); String encrtypt = MessageUtil.aesEncrypt(aesToken.getWeixinId(), aesToken.getAesKey(), content.toString()); String msgSignature = MessageUtil.signature(aesToken.getToken(), nonce, timestamp, encrtypt); content.delete(0, content.length()); content.append("<xml>"); content.append(String.format("<Nonce><![CDATA[%s]]></Nonce>", nonce)); content.append(String.format("<TimeStamp><![CDATA[%s]]></TimeStamp>", timestamp)); content.append(String.format("<MsgSignature><![CDATA[%s]]></MsgSignature>", msgSignature)); 
content.append(String.format("<Encrypt><![CDATA[%s]]></Encrypt>", encrtypt)); content.append("</xml>"); } ctx.writeAndFlush( HttpUtil.createHttpResponse( content.toString(), OK, ServerToolkits.CONTENTTYPE$APPLICATION_XML)); logger.info("{} encode weixin response:{}", encryptType, content); } }
public final class ResourceLeakDetector<T> { private static final String PROP_LEVEL = "io.netty.leakDetectionLevel"; private static final Level DEFAULT_LEVEL = Level.SIMPLE; private static Level level; public static enum Level { DISABLED, SIMPLE, ADVANCED, PARANOID; private Level() {} } private static final InternalLogger logger = InternalLoggerFactory.getInstance(ResourceLeakDetector.class); private static final int DEFAULT_SAMPLING_INTERVAL = 113; static { boolean disabled; if (SystemPropertyUtil.get("io.netty.noResourceLeakDetection") != null) { boolean disabled = SystemPropertyUtil.getBoolean("io.netty.noResourceLeakDetection", false); logger.debug("-Dio.netty.noResourceLeakDetection: {}", Boolean.valueOf(disabled)); logger.warn( "-Dio.netty.noResourceLeakDetection is deprecated. Use '-D{}={}' instead.", "io.netty.leakDetectionLevel", DEFAULT_LEVEL.name().toLowerCase()); } else { disabled = false; } Level defaultLevel = disabled ? Level.DISABLED : DEFAULT_LEVEL; String levelStr = SystemPropertyUtil.get("io.netty.leakDetectionLevel", defaultLevel.name()) .trim() .toUpperCase(); Level level = DEFAULT_LEVEL; for (Level l : EnumSet.allOf(Level.class)) { if ((levelStr.equals(l.name())) || (levelStr.equals(String.valueOf(l.ordinal())))) { level = l; } } level = level; if (logger.isDebugEnabled()) { logger.debug("-D{}: {}", "io.netty.leakDetectionLevel", level.name().toLowerCase()); } } @Deprecated public static void setEnabled(boolean enabled) { setLevel(enabled ? 
Level.SIMPLE : Level.DISABLED); } public static boolean isEnabled() { return getLevel().ordinal() > Level.DISABLED.ordinal(); } public static void setLevel(Level level) { if (level == null) { throw new NullPointerException("level"); } level = level; } public static Level getLevel() { return level; } private final ResourceLeakDetector<T>.DefaultResourceLeak head = new DefaultResourceLeak(null); private final ResourceLeakDetector<T>.DefaultResourceLeak tail = new DefaultResourceLeak(null); private final ReferenceQueue<Object> refQueue = new ReferenceQueue(); private final ConcurrentMap<String, Boolean> reportedLeaks = PlatformDependent.newConcurrentHashMap(); private final String resourceType; private final int samplingInterval; private final long maxActive; private long active; private final AtomicBoolean loggedTooManyActive = new AtomicBoolean(); private long leakCheckCnt; public ResourceLeakDetector(Class<?> resourceType) { this(StringUtil.simpleClassName(resourceType)); } public ResourceLeakDetector(String resourceType) { this(resourceType, 113, Long.MAX_VALUE); } public ResourceLeakDetector(Class<?> resourceType, int samplingInterval, long maxActive) { this(StringUtil.simpleClassName(resourceType), samplingInterval, maxActive); } public ResourceLeakDetector(String resourceType, int samplingInterval, long maxActive) { if (resourceType == null) { throw new NullPointerException("resourceType"); } if (samplingInterval <= 0) { throw new IllegalArgumentException( "samplingInterval: " + samplingInterval + " (expected: 1+)"); } if (maxActive <= 0L) { throw new IllegalArgumentException("maxActive: " + maxActive + " (expected: 1+)"); } this.resourceType = resourceType; this.samplingInterval = samplingInterval; this.maxActive = maxActive; this.head.next = this.tail; this.tail.prev = this.head; } public ResourceLeak open(T obj) { Level level = level; if (level == Level.DISABLED) { return null; } if (level.ordinal() < Level.PARANOID.ordinal()) { if (this.leakCheckCnt++ % 
this.samplingInterval == 0L) { reportLeak(level); return new DefaultResourceLeak(obj); } return null; } reportLeak(level); return new DefaultResourceLeak(obj); } private void reportLeak(Level level) { if (!logger.isErrorEnabled()) { for (; ; ) { ResourceLeakDetector<T>.DefaultResourceLeak ref = (DefaultResourceLeak) this.refQueue.poll(); if (ref == null) { break; } ref.close(); } return; } int samplingInterval = level == Level.PARANOID ? 1 : this.samplingInterval; if ((this.active * samplingInterval > this.maxActive) && (this.loggedTooManyActive.compareAndSet(false, true))) { logger.error( "LEAK: You are creating too many " + this.resourceType + " instances. " + this.resourceType + " is a shared resource that must be reused across the JVM," + "so that only a few instances are created."); } for (; ; ) { ResourceLeakDetector<T>.DefaultResourceLeak ref = (DefaultResourceLeak) this.refQueue.poll(); if (ref == null) { break; } ref.clear(); if (ref.close()) { String records = ref.toString(); if (this.reportedLeaks.putIfAbsent(records, Boolean.TRUE) == null) { if (records.isEmpty()) { logger.error( "LEAK: {}.release() was not called before it's garbage-collected. Enable advanced leak reporting to find out where the leak occurred. 
To enable advanced leak reporting, specify the JVM option '-D{}={}' or call {}.setLevel()", new Object[] { this.resourceType, "io.netty.leakDetectionLevel", Level.ADVANCED.name().toLowerCase(), StringUtil.simpleClassName(this) }); } else { logger.error( "LEAK: {}.release() was not called before it's garbage-collected.{}", this.resourceType, records); } } } } } private final class DefaultResourceLeak extends PhantomReference<Object> implements ResourceLeak { private static final int MAX_RECORDS = 4; private final String creationRecord; private final Deque<String> lastRecords = new ArrayDeque(); private final AtomicBoolean freed; private ResourceLeakDetector<T>.DefaultResourceLeak prev; private ResourceLeakDetector<T>.DefaultResourceLeak next; DefaultResourceLeak(Object referent) { super(referent != null ? ResourceLeakDetector.this.refQueue : null); ResourceLeakDetector.Level level; if (referent != null) { level = ResourceLeakDetector.getLevel(); if (level.ordinal() >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { this.creationRecord = ResourceLeakDetector.newRecord(3); } else { this.creationRecord = null; } synchronized (ResourceLeakDetector.this.head) { this.prev = ResourceLeakDetector.this.head; this.next = ResourceLeakDetector.this.head.next; ResourceLeakDetector.this.head.next.prev = this; ResourceLeakDetector.this.head.next = this; ResourceLeakDetector.access$408(ResourceLeakDetector.this); } this.freed = new AtomicBoolean(); } else { this.creationRecord = null; this.freed = new AtomicBoolean(true); } } public void record() { if (this.creationRecord != null) { String value = ResourceLeakDetector.newRecord(2); synchronized (this.lastRecords) { int size = this.lastRecords.size(); if ((size == 0) || (!((String) this.lastRecords.getLast()).equals(value))) { this.lastRecords.add(value); } if (size > 4) { this.lastRecords.removeFirst(); } } } } public boolean close() { if (this.freed.compareAndSet(false, true)) { synchronized (ResourceLeakDetector.this.head) { 
ResourceLeakDetector.access$410(ResourceLeakDetector.this); this.prev.next = this.next; this.next.prev = this.prev; this.prev = null; this.next = null; } return true; } return false; } public String toString() { if (this.creationRecord == null) { return ""; } Object[] array; synchronized (this.lastRecords) { array = this.lastRecords.toArray(); } StringBuilder buf = new StringBuilder(16384); buf.append(StringUtil.NEWLINE); buf.append("Recent access records: "); buf.append(array.length); buf.append(StringUtil.NEWLINE); if (array.length > 0) { for (int i = array.length - 1; i >= 0; i--) { buf.append('#'); buf.append(i + 1); buf.append(':'); buf.append(StringUtil.NEWLINE); buf.append(array[i]); } } buf.append("Created at:"); buf.append(StringUtil.NEWLINE); buf.append(this.creationRecord); buf.setLength(buf.length() - StringUtil.NEWLINE.length()); return buf.toString(); } } private static final String[] STACK_TRACE_ELEMENT_EXCLUSIONS = { "io.netty.buffer.AbstractByteBufAllocator.toLeakAwareBuffer(" }; static String newRecord(int recordsToSkip) { StringBuilder buf = new StringBuilder(4096); StackTraceElement[] array = new Throwable().getStackTrace(); for (StackTraceElement e : array) { if (recordsToSkip > 0) { recordsToSkip--; } else { String estr = e.toString(); boolean excluded = false; for (String exclusion : STACK_TRACE_ELEMENT_EXCLUSIONS) { if (estr.startsWith(exclusion)) { excluded = true; break; } } if (!excluded) { buf.append('\t'); buf.append(estr); buf.append(StringUtil.NEWLINE); } } } return buf.toString(); } }
/** Logs the exception as a warning through the test's logger; nothing is rethrown. */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
  final Class<?> owner = SocketConnectionAttemptTest.class;
  InternalLoggerFactory.getInstance(owner).warn("Unexpected exception:", cause);
}
static { InternalLogger logger = InternalLoggerFactory.getInstance(SocketConnectionAttemptTest.class); logger.debug("-Dio.netty.testsuite.badHost: {}", BAD_HOST); logger.debug("-Dio.netty.testsuite.badPort: {}", BAD_PORT); }
/** A {@link SocketChannel} which is using Old-Blocking-IO */ public class OioSocketChannel extends OioByteStreamChannel implements SocketChannel { private static final InternalLogger logger = InternalLoggerFactory.getInstance(OioSocketChannel.class); private final Socket socket; private final OioSocketChannelConfig config; /** Create a new instance with an new {@link Socket} */ public OioSocketChannel(EventLoop eventLoop) { this(eventLoop, new Socket()); } /** * Create a new instance from the given {@link Socket} * * @param socket the {@link Socket} which is used by this instance */ public OioSocketChannel(EventLoop eventLoop, Socket socket) { this(null, eventLoop, socket); } /** * Create a new instance from the given {@link Socket} * * @param parent the parent {@link Channel} which was used to create this instance. This can be * null if the {@link} has no parent as it was created by your self. * @param socket the {@link Socket} which is used by this instance */ public OioSocketChannel(Channel parent, EventLoop eventLoop, Socket socket) { super(parent, eventLoop); this.socket = socket; config = new DefaultOioSocketChannelConfig(this, socket); boolean success = false; try { if (socket.isConnected()) { activate(socket.getInputStream(), socket.getOutputStream()); } socket.setSoTimeout(SO_TIMEOUT); success = true; } catch (Exception e) { throw new ChannelException("failed to initialize a socket", e); } finally { if (!success) { try { socket.close(); } catch (IOException e) { logger.warn("Failed to close a socket.", e); } } } } @Override public ServerSocketChannel parent() { return (ServerSocketChannel) super.parent(); } @Override public OioSocketChannelConfig config() { return config; } @Override public boolean isOpen() { return !socket.isClosed(); } @Override public boolean isActive() { return !socket.isClosed() && socket.isConnected(); } @Override public boolean isInputShutdown() { return super.isInputShutdown(); } @Override public boolean isOutputShutdown() { 
return socket.isOutputShutdown() || !isActive(); } @Override public ChannelFuture shutdownOutput() { return shutdownOutput(newPromise()); } @Override protected int doReadBytes(ByteBuf buf) throws Exception { if (socket.isClosed()) { return -1; } try { return super.doReadBytes(buf); } catch (SocketTimeoutException e) { return 0; } } @Override public ChannelFuture shutdownOutput(final ChannelPromise future) { EventLoop loop = eventLoop(); if (loop.inEventLoop()) { try { socket.shutdownOutput(); future.setSuccess(); } catch (Throwable t) { future.setFailure(t); } } else { loop.execute( new Runnable() { @Override public void run() { shutdownOutput(future); } }); } return future; } @Override public InetSocketAddress localAddress() { return (InetSocketAddress) super.localAddress(); } @Override public InetSocketAddress remoteAddress() { return (InetSocketAddress) super.remoteAddress(); } @Override protected SocketAddress localAddress0() { return socket.getLocalSocketAddress(); } @Override protected SocketAddress remoteAddress0() { return socket.getRemoteSocketAddress(); } @Override protected void doBind(SocketAddress localAddress) throws Exception { socket.bind(localAddress); } @Override protected void doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { if (localAddress != null) { socket.bind(localAddress); } boolean success = false; try { socket.connect(remoteAddress, config().getConnectTimeoutMillis()); activate(socket.getInputStream(), socket.getOutputStream()); success = true; } catch (SocketTimeoutException e) { ConnectTimeoutException cause = new ConnectTimeoutException("connection timed out: " + remoteAddress); cause.setStackTrace(e.getStackTrace()); throw cause; } finally { if (!success) { doClose(); } } } @Override protected void doDisconnect() throws Exception { doClose(); } @Override protected void doClose() throws Exception { socket.close(); } @Override protected boolean checkInputShutdown() { if (isInputShutdown()) { try { 
Thread.sleep(config().getSoTimeout()); } catch (Throwable e) { // ignore } return true; } return false; } }
static {
  // Track every opened resource so that leaks surface in the log output.
  ResourceLeakDetector.setLevel(Level.PARANOID);
  // Route Netty's internal logging through SLF4J.
  // NOTE(review): this runs after the ResourceLeakDetector call above; presumably classes that
  // obtained their logger earlier keep the previously active factory - confirm the order is
  // intended.
  InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
  // TODO: create a test logger that searches the log output for leaks, fails the test and
  // aborts the run.
}
/** * Adds <a href="http://en.wikipedia.org/wiki/Transport_Layer_Security">SSL · TLS</a> and * StartTLS support to a {@link Channel}. Please refer to the <strong>"SecureChat"</strong> example * in the distribution or the web site for the detailed usage. * * <h3>Beginning the handshake</h3> * * <p>You must make sure not to write a message while the handshake is in progress unless you are * renegotiating. You will be notified by the {@link Future} which is returned by the {@link * #handshakeFuture()} method when the handshake process succeeds or fails. * * <p>Besides using the handshake {@link ChannelFuture} to get notified about the completion of the * handshake, it's also possible to detect it by implementing the {@link * ChannelHandler#userEventTriggered(ChannelHandlerContext, Object)} method and checking for a {@link * SslHandshakeCompletionEvent}. * * <h3>Handshake</h3> * * <p>The handshake will be automatically issued for you once the {@link Channel} is active and {@link * SSLEngine#getUseClientMode()} returns {@code true}, so there is no need to start it yourself. * * <h3>Closing the session</h3> * * <p>To close the SSL session, the {@link #close()} method should be called to send the {@code * close_notify} message to the remote peer. One exception is when you close the {@link Channel} - * {@link SslHandler} intercepts the close request and sends the {@code close_notify} message before * the channel closure automatically. Once the SSL session is closed, it is not reusable, and * consequently you should create a new {@link SslHandler} with a new {@link SSLEngine} as explained * in the following section. * * <h3>Restarting the session</h3> * * <p>To restart the SSL session, you must remove the existing closed {@link SslHandler} from the * {@link ChannelPipeline}, insert a new {@link SslHandler} with a new {@link SSLEngine} into the * pipeline, and start the handshake process as described in the first section. 
* * <h3>Implementing StartTLS</h3> * * <p><a href="http://en.wikipedia.org/wiki/STARTTLS">StartTLS</a> is the communication pattern that * secures the wire in the middle of the plaintext connection. Please note that it is different from * SSL · TLS, that secures the wire from the beginning of the connection. Typically, StartTLS * is composed of three steps: * * <ol> * <li>Client sends a StartTLS request to server. * <li>Server sends a StartTLS response to client. * <li>Client begins SSL handshake. * </ol> * * If you implement a server, you need to: * * <ol> * <li>create a new {@link SslHandler} instance with {@code startTls} flag set to {@code true}, * <li>insert the {@link SslHandler} to the {@link ChannelPipeline}, and * <li>write a StartTLS response. * </ol> * * Please note that you must insert {@link SslHandler} <em>before</em> sending the StartTLS * response. Otherwise the client can send begin SSL handshake before {@link SslHandler} is inserted * to the {@link ChannelPipeline}, causing data corruption. * * <p>The client-side implementation is much simpler. * * <ol> * <li>Write a StartTLS request, * <li>wait for the StartTLS response, * <li>create a new {@link SslHandler} instance with {@code startTls} flag set to {@code false}, * <li>insert the {@link SslHandler} to the {@link ChannelPipeline}, and * <li>Initiate SSL handshake. * </ol> * * <h3>Known issues</h3> * * <p>Because of a known issue with the current implementation of the SslEngine that comes with Java * it may be possible that you see blocked IO-Threads while a full GC is done. 
* * <p>So if you are affected you can workaround this problem by adjust the cache settings like shown * below: * * <pre> * SslContext context = ...; * context.getServerSessionContext().setSessionCacheSize(someSaneSize); * context.getServerSessionContext().setSessionTime(someSameTimeout); * </pre> * * <p>What values to use here depends on the nature of your application and should be set based on * monitoring and debugging of it. For more details see <a * href="https://github.com/netty/netty/issues/832">#832</a> in our issue tracker. */ public class SslHandler extends ByteToMessageDecoder { private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslHandler.class); private static final Pattern IGNORABLE_CLASS_IN_STACK = Pattern.compile("^.*(?:Socket|Datagram|Sctp|Udt)Channel.*$"); private static final Pattern IGNORABLE_ERROR_MESSAGE = Pattern.compile( "^.*(?:connection.*(?:reset|closed|abort|broken)|broken.*pipe).*$", Pattern.CASE_INSENSITIVE); /** * Used in {@link #unwrapNonAppData(ChannelHandlerContext)} as input for {@link * #unwrap(ChannelHandlerContext, ByteBuf, int, int)}. Using this static instance reduce object * creation as {@link Unpooled#EMPTY_BUFFER#nioBuffer()} creates a new {@link ByteBuffer} * everytime. 
*/ private static final SSLException SSLENGINE_CLOSED = new SSLException("SSLEngine closed already"); private static final SSLException HANDSHAKE_TIMED_OUT = new SSLException("handshake timed out"); private static final ClosedChannelException CHANNEL_CLOSED = new ClosedChannelException(); static { SSLENGINE_CLOSED.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); HANDSHAKE_TIMED_OUT.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); CHANNEL_CLOSED.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); } private volatile ChannelHandlerContext ctx; private final SSLEngine engine; private final int maxPacketBufferSize; /** * Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} and {@link * SSLEngine#unwrap(ByteBuffer, ByteBuffer[])} should be called with a {@link ByteBuf} that is * only backed by one {@link ByteBuffer} to reduce the object creation. */ private final ByteBuffer[] singleBuffer = new ByteBuffer[1]; // BEGIN Platform-dependent flags /** {@code true} if and only if {@link SSLEngine} expects a direct buffer. */ private final boolean wantsDirectBuffer; /** * {@code true} if and only if {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} requires the output * buffer to be always as large as {@link #maxPacketBufferSize} even if the input buffer contains * small amount of data. * * <p>If this flag is {@code false}, we allocate a smaller output buffer. */ private final boolean wantsLargeOutboundNetworkBuffer; /** * {@code true} if and only if {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} expects a heap * buffer rather than a direct buffer. For an unknown reason, JDK8 SSLEngine causes JVM to crash * when its cipher suite uses Galois Counter Mode (GCM). 
*/ private boolean wantsInboundHeapBuffer; // END Platform-dependent flags private final boolean startTls; private boolean sentFirstMessage; private boolean flushedBeforeHandshake; private boolean readDuringHandshake; private PendingWriteQueue pendingUnencryptedWrites; private Promise<Channel> handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); /** * Set by wrap*() methods when something is produced. {@link * #channelReadComplete(ChannelHandlerContext)} will check this flag, clear it, and call * ctx.flush(). */ private boolean needsFlush; private int packetLength; /** * This flag is used to determine if we need to call {@link ChannelHandlerContext#read()} to * consume more data when {@link ChannelConfig#isAutoRead()} is {@code false}. */ private boolean firedChannelRead; private volatile long handshakeTimeoutMillis = 10000; private volatile long closeNotifyTimeoutMillis = 3000; /** * Creates a new instance. * * @param engine the {@link SSLEngine} this handler will use */ public SslHandler(SSLEngine engine) { this(engine, false); } /** * Creates a new instance. * * @param engine the {@link SSLEngine} this handler will use * @param startTls {@code true} if the first write request shouldn't be encrypted by the {@link * SSLEngine} */ public SslHandler(SSLEngine engine, boolean startTls) { if (engine == null) { throw new NullPointerException("engine"); } this.engine = engine; this.startTls = startTls; maxPacketBufferSize = engine.getSession().getPacketBufferSize(); boolean opensslEngine = engine instanceof OpenSslEngine; wantsDirectBuffer = opensslEngine; wantsLargeOutboundNetworkBuffer = !opensslEngine; /** * When using JDK {@link SSLEngine}, we use {@link #MERGE_CUMULATOR} because it works only with * one {@link ByteBuffer}. 
* * <p>When using {@link OpenSslEngine}, we can use {@link #COMPOSITE_CUMULATOR} because it has * {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} which works with multiple {@link * ByteBuffer}s and which does not need to do extra memory copies. */ setCumulator(opensslEngine ? COMPOSITE_CUMULATOR : MERGE_CUMULATOR); } public long getHandshakeTimeoutMillis() { return handshakeTimeoutMillis; } public void setHandshakeTimeout(long handshakeTimeout, TimeUnit unit) { if (unit == null) { throw new NullPointerException("unit"); } setHandshakeTimeoutMillis(unit.toMillis(handshakeTimeout)); } public void setHandshakeTimeoutMillis(long handshakeTimeoutMillis) { if (handshakeTimeoutMillis < 0) { throw new IllegalArgumentException( "handshakeTimeoutMillis: " + handshakeTimeoutMillis + " (expected: >= 0)"); } this.handshakeTimeoutMillis = handshakeTimeoutMillis; } public long getCloseNotifyTimeoutMillis() { return closeNotifyTimeoutMillis; } public void setCloseNotifyTimeout(long closeNotifyTimeout, TimeUnit unit) { if (unit == null) { throw new NullPointerException("unit"); } setCloseNotifyTimeoutMillis(unit.toMillis(closeNotifyTimeout)); } public void setCloseNotifyTimeoutMillis(long closeNotifyTimeoutMillis) { if (closeNotifyTimeoutMillis < 0) { throw new IllegalArgumentException( "closeNotifyTimeoutMillis: " + closeNotifyTimeoutMillis + " (expected: >= 0)"); } this.closeNotifyTimeoutMillis = closeNotifyTimeoutMillis; } /** Returns the {@link SSLEngine} which is used by this handler. */ public SSLEngine engine() { return engine; } /** * Returns the name of the current application-level protocol. 
* * @return the protocol name or {@code null} if application-level protocol has not been negotiated */ public String applicationProtocol() { SSLSession sess = engine().getSession(); if (!(sess instanceof ApplicationProtocolAccessor)) { return null; } return ((ApplicationProtocolAccessor) sess).getApplicationProtocol(); } /** * Returns a {@link Future} that will get notified once the current TLS handshake completes. * * @return the {@link Future} for the iniital TLS handshake if {@link #renegotiate()} was not * invoked. The {@link Future} for the most recent {@linkplain #renegotiate() TLS * renegotiation} otherwise. */ public Future<Channel> handshakeFuture() { return handshakePromise; } /** * Sends an SSL {@code close_notify} message to the specified channel and destroys the underlying * {@link SSLEngine}. */ public ChannelFuture close() { return close(ctx.newPromise()); } /** See {@link #close()} */ public ChannelFuture close(final ChannelPromise future) { final ChannelHandlerContext ctx = this.ctx; ctx.executor() .execute( new Runnable() { @Override public void run() { engine.closeOutbound(); try { write(ctx, Unpooled.EMPTY_BUFFER, future); flush(ctx); } catch (Exception e) { if (!future.tryFailure(e)) { logger.warn("{} flush() raised a masked exception.", ctx.channel(), e); } } } }); return future; } /** * Return the {@link Future} that will get notified if the inbound of the {@link SSLEngine} is * closed. * * <p>This method will return the same {@link Future} all the time. 
* * @see SSLEngine */ public Future<Channel> sslCloseFuture() { return sslCloseFuture; } @Override public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { if (!pendingUnencryptedWrites.isEmpty()) { // Check if queue is not empty first because create a new ChannelException is expensive pendingUnencryptedWrites.removeAndFailAll( new ChannelException("Pending write on removal of SslHandler")); } } @Override public void disconnect(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { closeOutboundAndChannel(ctx, promise, true); } @Override public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { closeOutboundAndChannel(ctx, promise, false); } @Override public void read(ChannelHandlerContext ctx) throws Exception { if (!handshakePromise.isDone()) { readDuringHandshake = true; } ctx.read(); } @Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (!(msg instanceof ByteBuf)) { promise.setFailure(new UnsupportedMessageTypeException(msg, ByteBuf.class)); return; } pendingUnencryptedWrites.add(msg, promise); } @Override public void flush(ChannelHandlerContext ctx) throws Exception { // Do not encrypt the first write request if this handler is // created with startTLS flag turned on. if (startTls && !sentFirstMessage) { sentFirstMessage = true; pendingUnencryptedWrites.removeAndWriteAll(); ctx.flush(); return; } if (pendingUnencryptedWrites.isEmpty()) { // It's important to NOT use a voidPromise here as the user // may want to add a ChannelFutureListener to the ChannelPromise later. 
// // See https://github.com/netty/netty/issues/3364 pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER, ctx.newPromise()); } if (!handshakePromise.isDone()) { flushedBeforeHandshake = true; } wrap(ctx, false); ctx.flush(); } private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { ByteBuf out = null; ChannelPromise promise = null; ByteBufAllocator alloc = ctx.alloc(); try { for (; ; ) { Object msg = pendingUnencryptedWrites.current(); if (msg == null) { break; } ByteBuf buf = (ByteBuf) msg; if (out == null) { out = allocateOutNetBuf(ctx, buf.readableBytes()); } SSLEngineResult result = wrap(alloc, engine, buf, out); if (!buf.isReadable()) { promise = pendingUnencryptedWrites.remove(); } else { promise = null; } if (result.getStatus() == Status.CLOSED) { // SSLEngine has been closed already. // Any further write attempts should be denied. pendingUnencryptedWrites.removeAndFailAll(SSLENGINE_CLOSED); return; } else { switch (result.getHandshakeStatus()) { case NEED_TASK: runDelegatedTasks(); break; case FINISHED: setHandshakeSuccess(); // deliberate fall-through case NOT_HANDSHAKING: setHandshakeSuccessIfStillHandshaking(); // deliberate fall-through case NEED_WRAP: finishWrap(ctx, out, promise, inUnwrap); promise = null; out = null; break; case NEED_UNWRAP: return; default: throw new IllegalStateException( "Unknown handshake status: " + result.getHandshakeStatus()); } } } } catch (SSLException e) { setHandshakeFailure(ctx, e); throw e; } finally { finishWrap(ctx, out, promise, inUnwrap); } } private void finishWrap( ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap) { if (out == null) { out = Unpooled.EMPTY_BUFFER; } else if (!out.isReadable()) { out.release(); out = Unpooled.EMPTY_BUFFER; } if (promise != null) { ctx.write(out, promise); } else { ctx.write(out); } if (inUnwrap) { needsFlush = true; } } private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { ByteBuf out 
= null; ByteBufAllocator alloc = ctx.alloc(); try { for (; ; ) { if (out == null) { out = allocateOutNetBuf(ctx, 0); } SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out); if (result.bytesProduced() > 0) { ctx.write(out); if (inUnwrap) { needsFlush = true; } out = null; } switch (result.getHandshakeStatus()) { case FINISHED: setHandshakeSuccess(); break; case NEED_TASK: runDelegatedTasks(); break; case NEED_UNWRAP: if (!inUnwrap) { unwrapNonAppData(ctx); } break; case NEED_WRAP: break; case NOT_HANDSHAKING: setHandshakeSuccessIfStillHandshaking(); // Workaround for TLS False Start problem reported at: // https://github.com/netty/netty/issues/1108#issuecomment-14266970 if (!inUnwrap) { unwrapNonAppData(ctx); } break; default: throw new IllegalStateException( "Unknown handshake status: " + result.getHandshakeStatus()); } if (result.bytesProduced() == 0) { break; } // It should not consume empty buffers when it is not handshaking // Fix for Android, where it was encrypting empty buffers even when not handshaking if (result.bytesConsumed() == 0 && result.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING) { break; } } } catch (SSLException e) { setHandshakeFailure(ctx, e); throw e; } finally { if (out != null) { out.release(); } } } private SSLEngineResult wrap(ByteBufAllocator alloc, SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException { ByteBuf newDirectIn = null; try { int readerIndex = in.readerIndex(); int readableBytes = in.readableBytes(); // We will call SslEngine.wrap(ByteBuffer[], ByteBuffer) to allow efficient handling of // CompositeByteBuf without force an extra memory copy when CompositeByteBuffer.nioBuffer() is // called. final ByteBuffer[] in0; if (in.isDirect() || !wantsDirectBuffer) { // As CompositeByteBuf.nioBufferCount() can be expensive (as it needs to check all composed // ByteBuf // to calculate the count) we will just assume a CompositeByteBuf contains more then 1 // ByteBuf. 
// The worst that can happen is that we allocate an extra ByteBuffer[] in // CompositeByteBuf.nioBuffers() // which is better then walking the composed ByteBuf in most cases. if (!(in instanceof CompositeByteBuf) && in.nioBufferCount() == 1) { in0 = singleBuffer; // We know its only backed by 1 ByteBuffer so use internalNioBuffer to keep object // allocation // to a minimum. in0[0] = in.internalNioBuffer(readerIndex, readableBytes); } else { in0 = in.nioBuffers(); } } else { // We could even go further here and check if its a CompositeByteBuf and if so try to // decompose it and // only replace the ByteBuffer that are not direct. At the moment we just will replace the // whole // CompositeByteBuf to keep the complexity to a minimum newDirectIn = alloc.directBuffer(readableBytes); newDirectIn.writeBytes(in, readerIndex, readableBytes); in0 = singleBuffer; in0[0] = newDirectIn.internalNioBuffer(0, readableBytes); } for (; ; ) { ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes()); SSLEngineResult result = engine.wrap(in0, out0); in.skipBytes(result.bytesConsumed()); out.writerIndex(out.writerIndex() + result.bytesProduced()); switch (result.getStatus()) { case BUFFER_OVERFLOW: out.ensureWritable(maxPacketBufferSize); break; default: return result; } } } finally { // Null out to allow GC of ByteBuffer singleBuffer[0] = null; if (newDirectIn != null) { newDirectIn.release(); } } } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { // Make sure to release SSLEngine, // and notify the handshake future if the connection has been closed during handshake. setHandshakeFailure(ctx, CHANNEL_CLOSED); super.channelInactive(ctx); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { if (ignoreException(cause)) { // It is safe to ignore the 'connection reset by peer' or // 'broken pipe' error after sending close_notify. 
if (logger.isDebugEnabled()) { logger.debug( "{} Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred " + "while writing close_notify in response to the peer's close_notify", ctx.channel(), cause); } // Close the connection explicitly just in case the transport // did not close the connection automatically. if (ctx.channel().isActive()) { ctx.close(); } } else { ctx.fireExceptionCaught(cause); } } /** * Checks if the given {@link Throwable} can be ignore and just "swallowed" * * <p>When an ssl connection is closed a close_notify message is sent. After that the peer also * sends close_notify however, it's not mandatory to receive the close_notify. The party who sent * the initial close_notify can close the connection immediately then the peer will get connection * reset error. */ private boolean ignoreException(Throwable t) { if (!(t instanceof SSLException) && t instanceof IOException && sslCloseFuture.isDone()) { String message = String.valueOf(t.getMessage()).toLowerCase(); // first try to match connection reset / broke peer based on the regex. This is the fastest // way // but may fail on different jdk impls or OS's if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) { return true; } // Inspect the StackTraceElements to see if it was a connection reset / broken pipe or not StackTraceElement[] elements = t.getStackTrace(); for (StackTraceElement element : elements) { String classname = element.getClassName(); String methodname = element.getMethodName(); // skip all classes that belong to the io.netty package if (classname.startsWith("io.netty.")) { continue; } // check if the method name is read if not skip it if (!"read".equals(methodname)) { continue; } // This will also match against SocketInputStream which is used by openjdk 7 and maybe // also others if (IGNORABLE_CLASS_IN_STACK.matcher(classname).matches()) { return true; } try { // No match by now.. Try to load the class via classloader and inspect it. 
// This is mainly done as other JDK implementations may differ in name of // the impl. Class<?> clazz = PlatformDependent.getClassLoader(getClass()).loadClass(classname); if (SocketChannel.class.isAssignableFrom(clazz) || DatagramChannel.class.isAssignableFrom(clazz)) { return true; } // also match against SctpChannel via String matching as it may not present. if (PlatformDependent.javaVersion() >= 7 && "com.sun.nio.sctp.SctpChannel".equals(clazz.getSuperclass().getName())) { return true; } } catch (ClassNotFoundException e) { // This should not happen just ignore } } } return false; } /** * Returns {@code true} if the given {@link ByteBuf} is encrypted. Be aware that this method will * not increase the readerIndex of the given {@link ByteBuf}. * * @param buffer The {@link ByteBuf} to read from. Be aware that it must have at least 5 bytes to * read, otherwise it will throw an {@link IllegalArgumentException}. * @return encrypted {@code true} if the {@link ByteBuf} is encrypted, {@code false} otherwise. * @throws IllegalArgumentException Is thrown if the given {@link ByteBuf} has not at least 5 * bytes to read. */ public static boolean isEncrypted(ByteBuf buffer) { if (buffer.readableBytes() < 5) { throw new IllegalArgumentException("buffer must have at least 5 readable bytes"); } return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1; } /** * Return how much bytes can be read out of the encrypted data. Be aware that this method will not * increase the readerIndex of the given {@link ByteBuf}. * * @param buffer The {@link ByteBuf} to read from. Be aware that it must have at least 5 bytes to * read, otherwise it will throw an {@link IllegalArgumentException}. * @return length The length of the encrypted packet that is included in the buffer. This will * return {@code -1} if the given {@link ByteBuf} is not encrypted at all. * @throws IllegalArgumentException Is thrown if the given {@link ByteBuf} has not at least 5 * bytes to read. 
*/ private static int getEncryptedPacketLength(ByteBuf buffer, int offset) { int packetLength = 0; // SSLv3 or TLS - Check ContentType boolean tls; switch (buffer.getUnsignedByte(offset)) { case 20: // change_cipher_spec case 21: // alert case 22: // handshake case 23: // application_data tls = true; break; default: // SSLv2 or bad data tls = false; } if (tls) { // SSLv3 or TLS - Check ProtocolVersion int majorVersion = buffer.getUnsignedByte(offset + 1); if (majorVersion == 3) { // SSLv3 or TLS packetLength = buffer.getUnsignedShort(offset + 3) + 5; if (packetLength <= 5) { // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) tls = false; } } else { // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) tls = false; } } if (!tls) { // SSLv2 or bad data - Check the version boolean sslv2 = true; int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3; int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1); if (majorVersion == 2 || majorVersion == 3) { // SSLv2 if (headerLength == 2) { packetLength = (buffer.getShort(offset) & 0x7FFF) + 2; } else { packetLength = (buffer.getShort(offset) & 0x3FFF) + 3; } if (packetLength <= headerLength) { sslv2 = false; } } else { sslv2 = false; } if (!sslv2) { return -1; } } return packetLength; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException { final int startOffset = in.readerIndex(); final int endOffset = in.writerIndex(); int offset = startOffset; int totalLength = 0; // If we calculated the length of the current SSL record before, use that information. 
if (packetLength > 0) { if (endOffset - startOffset < packetLength) { return; } else { offset += packetLength; totalLength = packetLength; packetLength = 0; } } boolean nonSslRecord = false; while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) { final int readableBytes = endOffset - offset; if (readableBytes < 5) { break; } final int packetLength = getEncryptedPacketLength(in, offset); if (packetLength == -1) { nonSslRecord = true; break; } assert packetLength > 0; if (packetLength > readableBytes) { // wait until the whole packet can be read this.packetLength = packetLength; break; } int newTotalLength = totalLength + packetLength; if (newTotalLength > OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) { // Don't read too much. break; } // We have a whole packet. // Increment the offset to handle the next packet. offset += packetLength; totalLength = newTotalLength; } if (totalLength > 0) { boolean decoded = false; // The buffer contains one or more full SSL records. // Slice out the whole packet so unwrap will only be called with complete packets. // Also directly reset the packetLength. This is needed as unwrap(..) may trigger // decode(...) again via: // 1) unwrap(..) is called // 2) wrap(...) is called from within unwrap(...) // 3) wrap(...) calls unwrapLater(...) // 4) unwrapLater(...) calls decode(...) // // See https://github.com/netty/netty/issues/1534 in.skipBytes(totalLength); // If SSLEngine expects a heap buffer for unwrapping, do the conversion. if (in.isDirect() && wantsInboundHeapBuffer) { ByteBuf copy = ctx.alloc().heapBuffer(totalLength); try { copy.writeBytes(in, startOffset, totalLength); decoded = unwrap(ctx, copy, 0, totalLength); } finally { copy.release(); } } else { decoded = unwrap(ctx, in, startOffset, totalLength); } if (!firedChannelRead) { // Check first if firedChannelRead is not set yet as it may have been set in a // previous decode(...) call. 
firedChannelRead = decoded; } } if (nonSslRecord) { // Not an SSL/TLS packet NotSslRecordException e = new NotSslRecordException("not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); in.skipBytes(in.readableBytes()); ctx.fireExceptionCaught(e); setHandshakeFailure(ctx, e); } } @Override public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { // Discard bytes of the cumulation buffer if needed. discardSomeReadBytes(); if (needsFlush) { needsFlush = false; ctx.flush(); } // If handshake is not finished yet, we need more data. if (!ctx.channel().config().isAutoRead() && (!firedChannelRead || !handshakePromise.isDone())) { // No auto-read used and no message passed through the ChannelPipeline or the handhshake was // not complete // yet, which means we need to trigger the read to ensure we not encounter any stalls. ctx.read(); } firedChannelRead = false; ctx.fireChannelReadComplete(); } /** * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle * handshakes, etc. */ private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException { unwrap(ctx, Unpooled.EMPTY_BUFFER, 0, 0); } /** Unwraps inbound SSL records. 
*/ private boolean unwrap(ChannelHandlerContext ctx, ByteBuf packet, int offset, int length) throws SSLException { boolean decoded = false; boolean wrapLater = false; boolean notifyClosure = false; ByteBuf decodeOut = allocate(ctx, length); try { for (; ; ) { final SSLEngineResult result = unwrap(engine, packet, offset, length, decodeOut); final Status status = result.getStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); final int produced = result.bytesProduced(); final int consumed = result.bytesConsumed(); // Update indexes for the next iteration offset += consumed; length -= consumed; switch (status) { case BUFFER_OVERFLOW: int readableBytes = decodeOut.readableBytes(); if (readableBytes > 0) { decoded = true; ctx.fireChannelRead(decodeOut); } else { decodeOut.release(); } // Allocate a new buffer which can hold all the rest data and loop again. // TODO: We may want to reconsider how we calculate the length here as we may // have more then one ssl message to decode. decodeOut = allocate(ctx, engine.getSession().getApplicationBufferSize() - readableBytes); continue; case CLOSED: // notify about the CLOSED state of the SSLEngine. See #137 notifyClosure = true; break; default: break; } switch (handshakeStatus) { case NEED_UNWRAP: break; case NEED_WRAP: wrapNonAppData(ctx, true); break; case NEED_TASK: runDelegatedTasks(); break; case FINISHED: setHandshakeSuccess(); wrapLater = true; continue; case NOT_HANDSHAKING: if (setHandshakeSuccessIfStillHandshaking()) { wrapLater = true; continue; } if (flushedBeforeHandshake) { // We need to call wrap(...) in case there was a flush done before the handshake // completed. 
// // See https://github.com/netty/netty/pull/2437 flushedBeforeHandshake = false; wrapLater = true; } break; default: throw new IllegalStateException("unknown handshake status: " + handshakeStatus); } if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) { break; } } if (wrapLater) { wrap(ctx, true); } if (notifyClosure) { sslCloseFuture.trySuccess(ctx.channel()); } } catch (SSLException e) { setHandshakeFailure(ctx, e); throw e; } finally { if (decodeOut.isReadable()) { decoded = true; ctx.fireChannelRead(decodeOut); } else { decodeOut.release(); } } return decoded; } private SSLEngineResult unwrap( SSLEngine engine, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException { int nioBufferCount = in.nioBufferCount(); int writerIndex = out.writerIndex(); final SSLEngineResult result; if (engine instanceof OpenSslEngine && nioBufferCount > 1) { /** * If {@link OpenSslEngine} is in use, we can use a special {@link * OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method that accepts multiple {@link * ByteBuffer}s without additional memory copies. */ OpenSslEngine opensslEngine = (OpenSslEngine) engine; try { singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes()); result = opensslEngine.unwrap(in.nioBuffers(readerIndex, len), singleBuffer); out.writerIndex(writerIndex + result.bytesProduced()); } finally { singleBuffer[0] = null; } } else { result = engine.unwrap( toByteBuffer(in, readerIndex, len), toByteBuffer(out, writerIndex, out.writableBytes())); } out.writerIndex(writerIndex + result.bytesProduced()); return result; } private static ByteBuffer toByteBuffer(ByteBuf out, int index, int len) { return out.nioBufferCount() == 1 ? out.internalNioBuffer(index, len) : out.nioBuffer(index, len); } /** * Fetches all delegated tasks from the {@link SSLEngine} and runs them by invoking them directly. 
*/ private void runDelegatedTasks() { for (; ; ) { Runnable task = engine.getDelegatedTask(); if (task == null) { break; } task.run(); } } /** * Works around some Android {@link SSLEngine} implementations that skip {@link * HandshakeStatus#FINISHED} and go straight into {@link HandshakeStatus#NOT_HANDSHAKING} when * handshake is finished. * * @return {@code true} if and only if the workaround has been applied and thus {@link * #handshakeFuture} has been marked as success by this method */ private boolean setHandshakeSuccessIfStillHandshaking() { if (!handshakePromise.isDone()) { setHandshakeSuccess(); return true; } return false; } /** Notify all the handshake futures about the successfully handshake */ private void setHandshakeSuccess() { // Work around the JVM crash which occurs when a cipher suite with GCM enabled. final String cipherSuite = String.valueOf(engine.getSession().getCipherSuite()); if (!wantsDirectBuffer && (cipherSuite.contains("_GCM_") || cipherSuite.contains("-GCM-"))) { wantsInboundHeapBuffer = true; } handshakePromise.trySuccess(ctx.channel()); if (logger.isDebugEnabled()) { logger.debug("{} HANDSHAKEN: {}", ctx.channel(), engine.getSession().getCipherSuite()); } ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); if (readDuringHandshake && !ctx.channel().config().isAutoRead()) { readDuringHandshake = false; ctx.read(); } } /** Notify all the handshake futures about the failure during the handshake. */ private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) { // Release all resources such as internal buffers that SSLEngine // is managing. engine.closeOutbound(); try { engine.closeInbound(); } catch (SSLException e) { // only log in debug mode as it most likely harmless and latest chrome still trigger // this all the time. 
// // See https://github.com/netty/netty/issues/1340 String msg = e.getMessage(); if (msg == null || !msg.contains("possible truncation attack")) { logger.debug("{} SSLEngine.closeInbound() raised an exception.", ctx.channel(), e); } } notifyHandshakeFailure(cause); pendingUnencryptedWrites.removeAndFailAll(cause); } private void notifyHandshakeFailure(Throwable cause) { if (handshakePromise.tryFailure(cause)) { ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); ctx.close(); } } private void closeOutboundAndChannel( final ChannelHandlerContext ctx, final ChannelPromise promise, boolean disconnect) throws Exception { if (!ctx.channel().isActive()) { if (disconnect) { ctx.disconnect(promise); } else { ctx.close(promise); } return; } engine.closeOutbound(); ChannelPromise closeNotifyFuture = ctx.newPromise(); write(ctx, Unpooled.EMPTY_BUFFER, closeNotifyFuture); flush(ctx); safeClose(ctx, closeNotifyFuture, promise); } @Override public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { this.ctx = ctx; pendingUnencryptedWrites = new PendingWriteQueue(ctx); if (ctx.channel().isActive() && engine.getUseClientMode()) { // Begin the initial handshake. // channelActive() event has been fired already, which means this.channelActive() will // not be invoked. We have to initialize here instead. handshake(null); } else { // channelActive() event has not been fired yet. this.channelOpen() will be invoked // and initialization will occur there. } } /** Performs TLS renegotiation. */ public Future<Channel> renegotiate() { ChannelHandlerContext ctx = this.ctx; if (ctx == null) { throw new IllegalStateException(); } return renegotiate(ctx.executor().<Channel>newPromise()); } /** Performs TLS renegotiation. 
*/ public Future<Channel> renegotiate(final Promise<Channel> promise) { if (promise == null) { throw new NullPointerException("promise"); } ChannelHandlerContext ctx = this.ctx; if (ctx == null) { throw new IllegalStateException(); } EventExecutor executor = ctx.executor(); if (!executor.inEventLoop()) { executor.execute( new OneTimeTask() { @Override public void run() { handshake(promise); } }); return promise; } handshake(promise); return promise; } /** * Performs TLS (re)negotiation. * * @param newHandshakePromise if {@code null}, use the existing {@link #handshakePromise}, * assuming that the current negotiation has not been finished. Currently, {@code null} is * expected only for the initial handshake. */ private void handshake(final Promise<Channel> newHandshakePromise) { final Promise<Channel> p; if (newHandshakePromise != null) { final Promise<Channel> oldHandshakePromise = handshakePromise; if (!oldHandshakePromise.isDone()) { // There's no need to handshake because handshake is in progress already. // Merge the new promise into the old one. oldHandshakePromise.addListener( new FutureListener<Channel>() { @Override public void operationComplete(Future<Channel> future) throws Exception { if (future.isSuccess()) { newHandshakePromise.setSuccess(future.getNow()); } else { newHandshakePromise.setFailure(future.cause()); } } }); return; } handshakePromise = p = newHandshakePromise; } else { // Forced to reuse the old handshake. p = handshakePromise; assert !p.isDone(); } // Begin handshake. final ChannelHandlerContext ctx = this.ctx; try { engine.beginHandshake(); wrapNonAppData(ctx, false); ctx.flush(); } catch (Exception e) { notifyHandshakeFailure(e); } // Set timeout if necessary. 
final long handshakeTimeoutMillis = this.handshakeTimeoutMillis; if (handshakeTimeoutMillis <= 0 || p.isDone()) { return; } final ScheduledFuture<?> timeoutFuture = ctx.executor() .schedule( new Runnable() { @Override public void run() { if (p.isDone()) { return; } notifyHandshakeFailure(HANDSHAKE_TIMED_OUT); } }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); // Cancel the handshake timeout when handshake is finished. p.addListener( new FutureListener<Channel>() { @Override public void operationComplete(Future<Channel> f) throws Exception { timeoutFuture.cancel(false); } }); } /** Issues an initial TLS handshake once connected when used in client-mode */ @Override public void channelActive(final ChannelHandlerContext ctx) throws Exception { if (!startTls && engine.getUseClientMode()) { // Begin the initial handshake handshake(null); } ctx.fireChannelActive(); } private void safeClose( final ChannelHandlerContext ctx, ChannelFuture flushFuture, final ChannelPromise promise) { if (!ctx.channel().isActive()) { ctx.close(promise); return; } final ScheduledFuture<?> timeoutFuture; if (closeNotifyTimeoutMillis > 0) { // Force-close the connection if close_notify is not fully sent in time. timeoutFuture = ctx.executor() .schedule( new Runnable() { @Override public void run() { logger.warn( "{} Last write attempt timed out; force-closing the connection.", ctx.channel()); // We notify the promise in the TryNotifyListener as there is a "race" where // the close(...) call // by the timeoutFuture and the close call in the flushFuture listener will be // called. Because of // this we need to use trySuccess() and tryFailure(...) as otherwise we can // cause an // IllegalStateException. ctx.close(ctx.newPromise()).addListener(new ChannelPromiseNotifier(promise)); } }, closeNotifyTimeoutMillis, TimeUnit.MILLISECONDS); } else { timeoutFuture = null; } // Close the connection if close_notify is sent in time. 
flushFuture.addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture f) throws Exception { if (timeoutFuture != null) { timeoutFuture.cancel(false); } // Trigger the close in all cases to make sure the promise is notified // See https://github.com/netty/netty/issues/2358 // // We notify the promise in the ChannelPromiseNotifier as there is a "race" where the // close(...) call // by the timeoutFuture and the close call in the flushFuture listener will be called. // Because of // this we need to use trySuccess() and tryFailure(...) as otherwise we can cause an // IllegalStateException. ctx.close(ctx.newPromise()).addListener(new ChannelPromiseNotifier(promise)); } }); } /** * Always prefer a direct buffer when it's pooled, so that we reduce the number of memory copies * in {@link OpenSslEngine}. */ private ByteBuf allocate(ChannelHandlerContext ctx, int capacity) { ByteBufAllocator alloc = ctx.alloc(); if (wantsDirectBuffer) { return alloc.directBuffer(capacity); } else { return alloc.buffer(capacity); } } /** * Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which * can encrypt the specified amount of pending bytes. */ private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) { if (wantsLargeOutboundNetworkBuffer) { return allocate(ctx, maxPacketBufferSize); } else { return allocate( ctx, Math.min( pendingBytes + OpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH, maxPacketBufferSize)); } } private final class LazyChannelPromise extends DefaultPromise<Channel> { @Override protected EventExecutor executor() { if (ctx == null) { throw new IllegalStateException(); } return ctx.executor(); } @Override protected void checkDeadLock() { if (ctx == null) { // If ctx is null the handlerAdded(...) callback was not called, in this case the // checkDeadLock() // method was called from another Thread then the one that is used by ctx.executor(). 
We // need to // guard against this as a user can see a race if handshakeFuture().sync() is called but the // handlerAdded(..) method was not yet as it is called from the EventExecutor of the // ChannelHandlerContext. If we not guard against this super.checkDeadLock() would cause an // IllegalStateException when trying to call executor(). return; } super.checkDeadLock(); } } }
/**
 * A {@link Timer} optimized for approximated I/O timeout scheduling.
 *
 * <h3>Tick Duration</h3>
 *
 * As described with 'approximated', this timer does not execute the scheduled {@link TimerTask} on
 * time. {@link HashedWheelTimer}, on every tick, will check if there are any {@link TimerTask}s
 * behind the schedule and execute them.
 *
 * <p>You can increase or decrease the accuracy of the execution timing by specifying smaller or
 * larger tick duration in the constructor. In most network applications, I/O timeout does not need
 * to be accurate. Therefore, the default tick duration is 100 milliseconds and you will not need to
 * try different configurations in most cases.
 *
 * <h3>Ticks per Wheel (Wheel Size)</h3>
 *
 * {@link HashedWheelTimer} maintains a data structure called 'wheel'. Put simply, a wheel is a
 * hash table of {@link TimerTask}s whose hash function is 'deadline of the task'. The default
 * number of ticks per wheel (i.e. the size of the wheel) is 512. You could specify a larger value
 * if you are going to schedule a lot of timeouts.
 *
 * <h3>Do not create many instances.</h3>
 *
 * {@link HashedWheelTimer} creates a new thread whenever it is instantiated and started. Therefore,
 * you should make sure to create only one instance and share it across your application. One of the
 * common mistakes, that makes your application unresponsive, is to create a new instance for every
 * connection.
 *
 * <h3>Implementation Details</h3>
 *
 * {@link HashedWheelTimer} is based on <a href="http://cseweb.ucsd.edu/users/varghese/">George
 * Varghese</a> and Tony Lauck's paper, <a
 * href="http://cseweb.ucsd.edu/users/varghese/PAPERS/twheel.ps.Z">'Hashed and Hierarchical Timing
 * Wheels: data structures to efficiently implement a timer facility'</a>. More comprehensive slides
 * are located <a href="http://www.cse.wustl.edu/~cdgill/courses/cs6874/TimingWheels.ppt">here</a>.
*/ public class HashedWheelTimer implements Timer { static final InternalLogger logger = InternalLoggerFactory.getInstance(HashedWheelTimer.class); private static final ResourceLeakDetector<HashedWheelTimer> leakDetector = new ResourceLeakDetector<HashedWheelTimer>( HashedWheelTimer.class, 1, Runtime.getRuntime().availableProcessors() * 4); private final ResourceLeak leak = leakDetector.open(this); private final Worker worker = new Worker(); final Thread workerThread; final AtomicInteger workerState = new AtomicInteger(); // 0 - init, 1 - started, 2 - shut down private final long roundDuration; final long tickDuration; final Set<HashedWheelTimeout>[] wheel; final int mask; final ReadWriteLock lock = new ReentrantReadWriteLock(); volatile int wheelCursor; /** * Creates a new timer with the default thread factory ({@link Executors#defaultThreadFactory()}), * default tick duration, and default number of ticks per wheel. */ public HashedWheelTimer() { this(Executors.defaultThreadFactory()); } /** * Creates a new timer with the default thread factory ({@link Executors#defaultThreadFactory()}) * and default number of ticks per wheel. * * @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @throws NullPointerException if {@code unit} is {@code null} * @throws IllegalArgumentException if {@code tickDuration} is <= 0 */ public HashedWheelTimer(long tickDuration, TimeUnit unit) { this(Executors.defaultThreadFactory(), tickDuration, unit); } /** * Creates a new timer with the default thread factory ({@link Executors#defaultThreadFactory()}). 
* * @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @param ticksPerWheel the size of the wheel * @throws NullPointerException if {@code unit} is {@code null} * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is * <= 0 */ public HashedWheelTimer(long tickDuration, TimeUnit unit, int ticksPerWheel) { this(Executors.defaultThreadFactory(), tickDuration, unit, ticksPerWheel); } /** * Creates a new timer with the default tick duration and default number of ticks per wheel. * * @param threadFactory a {@link ThreadFactory} that creates a background {@link Thread} which is * dedicated to {@link TimerTask} execution. * @throws NullPointerException if {@code threadFactory} is {@code null} */ public HashedWheelTimer(ThreadFactory threadFactory) { this(threadFactory, 100, TimeUnit.MILLISECONDS); } /** * Creates a new timer with the default number of ticks per wheel. * * @param threadFactory a {@link ThreadFactory} that creates a background {@link Thread} which is * dedicated to {@link TimerTask} execution. * @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code * null} * @throws IllegalArgumentException if {@code tickDuration} is <= 0 */ public HashedWheelTimer(ThreadFactory threadFactory, long tickDuration, TimeUnit unit) { this(threadFactory, tickDuration, unit, 512); } /** * Creates a new timer. * * @param threadFactory a {@link ThreadFactory} that creates a background {@link Thread} which is * dedicated to {@link TimerTask} execution. 
* @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @param ticksPerWheel the size of the wheel * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code * null} * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is * <= 0 */ public HashedWheelTimer( ThreadFactory threadFactory, long tickDuration, TimeUnit unit, int ticksPerWheel) { if (threadFactory == null) { throw new NullPointerException("threadFactory"); } if (unit == null) { throw new NullPointerException("unit"); } if (tickDuration <= 0) { throw new IllegalArgumentException("tickDuration must be greater than 0: " + tickDuration); } if (ticksPerWheel <= 0) { throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel); } // Normalize ticksPerWheel to power of two and initialize the wheel. wheel = createWheel(ticksPerWheel); mask = wheel.length - 1; // Convert tickDuration to milliseconds. this.tickDuration = tickDuration = unit.toMillis(tickDuration); // Prevent overflow. 
if (tickDuration == Long.MAX_VALUE || tickDuration >= Long.MAX_VALUE / wheel.length) { throw new IllegalArgumentException("tickDuration is too long: " + tickDuration + ' ' + unit); } roundDuration = tickDuration * wheel.length; workerThread = threadFactory.newThread(worker); } @SuppressWarnings("unchecked") private static Set<HashedWheelTimeout>[] createWheel(int ticksPerWheel) { if (ticksPerWheel <= 0) { throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel); } if (ticksPerWheel > 1073741824) { throw new IllegalArgumentException( "ticksPerWheel may not be greater than 2^30: " + ticksPerWheel); } ticksPerWheel = normalizeTicksPerWheel(ticksPerWheel); Set<HashedWheelTimeout>[] wheel = new Set[ticksPerWheel]; for (int i = 0; i < wheel.length; i++) { wheel[i] = Collections.newSetFromMap( PlatformDependent.<HashedWheelTimeout, Boolean>newConcurrentHashMap()); } return wheel; } private static int normalizeTicksPerWheel(int ticksPerWheel) { int normalizedTicksPerWheel = 1; while (normalizedTicksPerWheel < ticksPerWheel) { normalizedTicksPerWheel <<= 1; } return normalizedTicksPerWheel; } /** * Starts the background thread explicitly. The background thread will start automatically on * demand even if you did not call this method. * * @throws IllegalStateException if this timer has been {@linkplain #stop() stopped} already */ public void start() { switch (workerState.get()) { case 0: if (workerState.compareAndSet(0, 1)) { workerThread.start(); } break; case 1: break; case 2: throw new IllegalStateException("cannot be started once stopped"); default: throw new Error(); } } @Override public Set<Timeout> stop() { if (Thread.currentThread() == workerThread) { throw new IllegalStateException( HashedWheelTimer.class.getSimpleName() + ".stop() cannot be called from " + TimerTask.class.getSimpleName()); } if (!workerState.compareAndSet(1, 2)) { // workerState can be 0 or 2 at this moment - let it always be 2. 
workerState.set(2); return Collections.emptySet(); } boolean interrupted = false; while (workerThread.isAlive()) { workerThread.interrupt(); try { workerThread.join(100); } catch (InterruptedException e) { interrupted = true; } } if (interrupted) { Thread.currentThread().interrupt(); } leak.close(); Set<Timeout> unprocessedTimeouts = new HashSet<Timeout>(); for (Set<HashedWheelTimeout> bucket : wheel) { unprocessedTimeouts.addAll(bucket); bucket.clear(); } return Collections.unmodifiableSet(unprocessedTimeouts); } @Override public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit) { final long currentTime = System.currentTimeMillis(); if (task == null) { throw new NullPointerException("task"); } if (unit == null) { throw new NullPointerException("unit"); } start(); delay = unit.toMillis(delay); HashedWheelTimeout timeout = new HashedWheelTimeout(task, currentTime + delay); scheduleTimeout(timeout, delay); return timeout; } void scheduleTimeout(HashedWheelTimeout timeout, long delay) { // delay must be equal to or greater than tickDuration so that the // worker thread never misses the timeout. if (delay < tickDuration) { delay = tickDuration; } // Prepare the required parameters to schedule the timeout object. final long lastRoundDelay = delay % roundDuration; final long lastTickDelay = delay % tickDuration; final long relativeIndex = lastRoundDelay / tickDuration + (lastTickDelay != 0 ? 1 : 0); final long remainingRounds = delay / roundDuration - (delay % roundDuration == 0 ? 1 : 0); // Add the timeout to the wheel. 
lock.readLock().lock(); try { int stopIndex = (int) (wheelCursor + relativeIndex & mask); timeout.stopIndex = stopIndex; timeout.remainingRounds = remainingRounds; wheel[stopIndex].add(timeout); } finally { lock.readLock().unlock(); } } private final class Worker implements Runnable { private long startTime; private long tick; Worker() {} @Override public void run() { List<HashedWheelTimeout> expiredTimeouts = new ArrayList<HashedWheelTimeout>(); startTime = System.currentTimeMillis(); tick = 1; while (workerState.get() == 1) { final long deadline = waitForNextTick(); if (deadline > 0) { fetchExpiredTimeouts(expiredTimeouts, deadline); notifyExpiredTimeouts(expiredTimeouts); } } } private void fetchExpiredTimeouts(List<HashedWheelTimeout> expiredTimeouts, long deadline) { // Find the expired timeouts and decrease the round counter // if necessary. Note that we don't send the notification // immediately to make sure the listeners are called without // an exclusive lock. lock.writeLock().lock(); try { int newWheelCursor = wheelCursor = wheelCursor + 1 & mask; fetchExpiredTimeouts(expiredTimeouts, wheel[newWheelCursor].iterator(), deadline); } finally { lock.writeLock().unlock(); } } private void fetchExpiredTimeouts( List<HashedWheelTimeout> expiredTimeouts, Iterator<HashedWheelTimeout> i, long deadline) { List<HashedWheelTimeout> slipped = null; while (i.hasNext()) { HashedWheelTimeout timeout = i.next(); if (timeout.remainingRounds <= 0) { i.remove(); if (timeout.deadline <= deadline) { expiredTimeouts.add(timeout); } else { // Handle the case where the timeout is put into a wrong // place, usually one tick earlier. For now, just add // it to a temporary list - we will reschedule it in a // separate loop. if (slipped == null) { slipped = new ArrayList<HashedWheelTimeout>(); } slipped.add(timeout); } } else { timeout.remainingRounds--; } } // Reschedule the slipped timeouts. 
if (slipped != null) { for (HashedWheelTimeout timeout : slipped) { scheduleTimeout(timeout, timeout.deadline - deadline); } } } private void notifyExpiredTimeouts(List<HashedWheelTimeout> expiredTimeouts) { // Notify the expired timeouts. for (int i = expiredTimeouts.size() - 1; i >= 0; i--) { expiredTimeouts.get(i).expire(); } // Clean up the temporary list. expiredTimeouts.clear(); } private long waitForNextTick() { long deadline = startTime + tickDuration * tick; for (; ; ) { final long currentTime = System.currentTimeMillis(); long sleepTime = tickDuration * tick - (currentTime - startTime); // Check if we run on windows, as if thats the case we will need // to round the sleepTime as workaround for a bug that only affect // the JVM if it runs on windows. // // See https://github.com/netty/netty/issues/356 if (PlatformDependent.isWindows()) { sleepTime = sleepTime / 10 * 10; } if (sleepTime <= 0) { break; } try { Thread.sleep(sleepTime); } catch (InterruptedException e) { if (workerState.get() != 1) { return -1; } } } // Increase the tick. 
tick++; return deadline; } } private final class HashedWheelTimeout implements Timeout { private static final int ST_INIT = 0; private static final int ST_CANCELLED = 1; private static final int ST_EXPIRED = 2; private final TimerTask task; final long deadline; volatile int stopIndex; volatile long remainingRounds; private final AtomicInteger state = new AtomicInteger(ST_INIT); HashedWheelTimeout(TimerTask task, long deadline) { this.task = task; this.deadline = deadline; } @Override public Timer timer() { return HashedWheelTimer.this; } @Override public TimerTask task() { return task; } @Override public boolean cancel() { if (!state.compareAndSet(ST_INIT, ST_CANCELLED)) { return false; } wheel[stopIndex].remove(this); return true; } @Override public boolean isCancelled() { return state.get() == ST_CANCELLED; } @Override public boolean isExpired() { return state.get() != ST_INIT; } public void expire() { if (!state.compareAndSet(ST_INIT, ST_EXPIRED)) { return; } try { task.run(this); } catch (Throwable t) { if (logger.isWarnEnabled()) { logger.warn("An exception was thrown by " + TimerTask.class.getSimpleName() + '.', t); } } } @Override public String toString() { long currentTime = System.currentTimeMillis(); long remaining = deadline - currentTime; StringBuilder buf = new StringBuilder(192); buf.append(getClass().getSimpleName()); buf.append('('); buf.append("deadline: "); if (remaining > 0) { buf.append(remaining); buf.append(" ms later, "); } else if (remaining < 0) { buf.append(-remaining); buf.append(" ms ago, "); } else { buf.append("now, "); } if (isCancelled()) { buf.append(", cancelled"); } return buf.append(')').toString(); } } }
/**
 * A {@link ChannelHandler} that collects an {@link HttpMessage} together with the {@link
 * HttpContent}s that follow it and emits them as one {@link FullHttpRequest} or {@link
 * FullHttpResponse} (depending on whether it is handling requests or responses), with no trailing
 * {@link HttpContent}s. This is handy when you do not want to deal with HTTP messages whose
 * transfer encoding is 'chunked'. Place this handler after the {@link HttpObjectDecoder} in the
 * {@link ChannelPipeline}:
 *
 * <pre>
 * {@link ChannelPipeline} p = ...;
 * ...
 * p.addLast("encoder", new {@link HttpResponseEncoder}());
 * p.addLast("decoder", new {@link HttpRequestDecoder}());
 * p.addLast("aggregator", <b>new {@link HttpObjectAggregator}(1048576)</b>);
 * ...
 * p.addLast("handler", new HttpRequestHandler());
 * </pre>
 *
 * Note that the {@link HttpResponseEncoder} or {@link HttpRequestEncoder} has to come before the
 * {@link HttpObjectAggregator} in the {@link ChannelPipeline}.
 */
public class HttpObjectAggregator
    extends MessageAggregator<HttpObject, HttpMessage, HttpContent, FullHttpMessage> {

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(HttpObjectAggregator.class);

  /** Template for the "100 Continue" reply; always handed out as a retained duplicate. */
  private static final FullHttpResponse CONTINUE =
      new DefaultFullHttpResponse(
          HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER);

  /** Template for the "417 Expectation Failed" reply. */
  private static final FullHttpResponse EXPECTATION_FAILED =
      new DefaultFullHttpResponse(
          HttpVersion.HTTP_1_1, HttpResponseStatus.EXPECTATION_FAILED, Unpooled.EMPTY_BUFFER);

  /** Template for the "413 Request Entity Too Large" reply. */
  private static final FullHttpResponse TOO_LARGE =
      new DefaultFullHttpResponse(
          HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER);

  static {
    // Both error replies carry an explicit zero-length body.
    EXPECTATION_FAILED.headers().set(CONTENT_LENGTH, "0");
    TOO_LARGE.headers().set(CONTENT_LENGTH, "0");
  }

  /** Whether the connection is closed after answering with an expectation failure. */
  private final boolean closeOnExpectationFailed;

  /**
   * Creates a new instance that keeps the connection open after an expectation failure.
   *
   * @param maxContentLength the upper bound, in bytes, for the aggregated content. When the
   *     aggregated content grows beyond it, {@link #handleOversizedMessage(ChannelHandlerContext,
   *     HttpMessage)} is invoked.
   */
  public HttpObjectAggregator(int maxContentLength) {
    this(maxContentLength, false);
  }

  /**
   * Creates a new instance.
   *
   * @param maxContentLength the upper bound, in bytes, for the aggregated content. When the
   *     aggregated content grows beyond it, {@link #handleOversizedMessage(ChannelHandlerContext,
   *     HttpMessage)} is invoked.
   * @param closeOnExpectationFailed what to do when a 100-continue request announces content that
   *     is too large: {@code true} closes the connection; {@code false} leaves it open and the
   *     data is consumed and discarded until the next request arrives.
   */
  public HttpObjectAggregator(int maxContentLength, boolean closeOnExpectationFailed) {
    super(maxContentLength);
    this.closeOnExpectationFailed = closeOnExpectationFailed;
  }

  @Override
  protected boolean isStartMessage(HttpObject msg) throws Exception {
    return msg instanceof HttpMessage;
  }

  @Override
  protected boolean isContentMessage(HttpObject msg) throws Exception {
    return msg instanceof HttpContent;
  }

  @Override
  protected boolean isLastContentMessage(HttpContent msg) throws Exception {
    return msg instanceof LastHttpContent;
  }

  @Override
  protected boolean isAggregated(HttpObject msg) throws Exception {
    return msg instanceof FullHttpMessage;
  }

  /** A start message is invalid when its declared content length exceeds the maximum. */
  @Override
  protected boolean isContentLengthInvalid(HttpMessage start, int maxContentLength) {
    return getContentLength(start, -1) > maxContentLength;
  }

  /**
   * Builds the reply to an 'Expect: 100-continue' start message.
   *
   * @return {@code null} when no continue was requested, a "100 Continue" when the declared
   *     length fits, otherwise a "417 Expectation Failed" (after firing {@link
   *     HttpExpectationFailedEvent} through the pipeline)
   */
  @Override
  protected Object newContinueResponse(
      HttpMessage start, int maxContentLength, ChannelPipeline pipeline) {
    if (!HttpUtil.is100ContinueExpected(start)) {
      return null;
    }

    if (getContentLength(start, -1) <= maxContentLength) {
      return CONTINUE.duplicate().retain();
    }

    pipeline.fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE);
    return EXPECTATION_FAILED.duplicate().retain();
  }

  @Override
  protected boolean closeAfterContinueResponse(Object msg) {
    if (!closeOnExpectationFailed) {
      return false;
    }
    return ignoreContentAfterContinueResponse(msg);
  }

  /** Content is ignored only after an "Expectation Failed" response was sent. */
  @Override
  protected boolean ignoreContentAfterContinueResponse(Object msg) {
    if (!(msg instanceof HttpResponse)) {
      return false;
    }
    int sentCode = ((HttpResponse) msg).status().code();
    return sentCode == HttpResponseStatus.EXPECTATION_FAILED.code();
  }

  /**
   * Wraps the start message and the content buffer into an aggregated request or response, after
   * clearing the chunked transfer encoding on the start message.
   */
  @Override
  protected FullHttpMessage beginAggregation(HttpMessage start, ByteBuf content) throws Exception {
    assert !(start instanceof FullHttpMessage);

    HttpUtil.setTransferEncodingChunked(start, false);

    if (start instanceof HttpRequest) {
      return new AggregatedFullHttpRequest((HttpRequest) start, content, null);
    }
    if (start instanceof HttpResponse) {
      return new AggregatedFullHttpResponse((HttpResponse) start, content, null);
    }
    throw new Error();
  }

  @Override
  protected void aggregate(FullHttpMessage aggregated, HttpContent content) throws Exception {
    if (!(content instanceof LastHttpContent)) {
      return;
    }
    // The last chunk carries the trailing headers; merge them into the aggregated message.
    HttpHeaders trailers = ((LastHttpContent) content).trailingHeaders();
    ((AggregatedFullHttpMessage) aggregated).setTrailingHeaders(trailers);
  }

  @Override
  protected void finishAggregation(FullHttpMessage aggregated) throws Exception {
    // Only fill in 'Content-Length' when it is absent. An existing value must be kept: a HEAD
    // response advertises the size a GET would have transferred, which does not match the
    // (empty) aggregated body.
    //
    // See rfc2616 14.13 Content-Length
    if (HttpUtil.isContentLengthSet(aggregated)) {
      return;
    }
    String bodyLength = String.valueOf(aggregated.content().readableBytes());
    aggregated.headers().set(HttpHeaderNames.CONTENT_LENGTH, bodyLength);
  }

  /**
   * Reacts to a message that exceeds the maximum content length: requests are answered with a
   * 413, responses close the channel and raise {@link TooLongFrameException}.
   *
   * @throws TooLongFrameException if {@code oversized} is a response
   * @throws IllegalStateException if {@code oversized} is neither a request nor a response
   */
  @Override
  protected void handleOversizedMessage(final ChannelHandlerContext ctx, HttpMessage oversized)
      throws Exception {
    if (!(oversized instanceof HttpRequest)) {
      if (oversized instanceof HttpResponse) {
        ctx.close();
        throw new TooLongFrameException("Response entity too large: " + oversized);
      }
      throw new IllegalStateException();
    }

    // Answer with a 413; if even that write fails, give up on the connection.
    ChannelFutureListener closeWhenWriteFails =
        new ChannelFutureListener() {
          @Override
          public void operationComplete(ChannelFuture written) throws Exception {
            if (written.isSuccess()) {
              return;
            }
            logger.debug("Failed to send a 413 Request Entity Too Large.", written.cause());
            ctx.close();
          }
        };
    ChannelFuture future =
        ctx.writeAndFlush(TOO_LARGE.duplicate().retain()).addListener(closeWhenWriteFails);

    // Close when the client has already started sending data, because recovery is impossible,
    // and also when neither keep-alive nor 'Expect: 100-continue' gives a reason to keep the
    // connection open.
    boolean closeAfterReply =
        oversized instanceof FullHttpMessage
            || !(HttpUtil.is100ContinueExpected(oversized) || HttpUtil.isKeepAlive(oversized));
    if (closeAfterReply) {
      future.addListener(ChannelFutureListener.CLOSE);
    }

    // When the oversized request was dealt with and the connection survives (i.e. a rejected
    // 100-continue), the decoder has to be made ready for the next message.
    HttpObjectDecoder httpDecoder = ctx.pipeline().get(HttpObjectDecoder.class);
    if (httpDecoder != null) {
      httpDecoder.reset();
    }
  }

  /**
   * Shared base of the aggregated request/response: header-related calls are delegated to the
   * wrapped start message, reference counting to the content buffer.
   */
  private abstract static class AggregatedFullHttpMessage
      implements ByteBufHolder, FullHttpMessage {
    protected final HttpMessage message;
    private final ByteBuf content;
    private HttpHeaders trailingHeaders;

    AggregatedFullHttpMessage(HttpMessage message, ByteBuf content, HttpHeaders trailingHeaders) {
      this.message = message;
      this.content = content;
      this.trailingHeaders = trailingHeaders;
    }

    /** Returns the trailing headers, or the shared empty instance when none were set. */
    @Override
    public HttpHeaders trailingHeaders() {
      HttpHeaders current = this.trailingHeaders;
      return current == null ? EmptyHttpHeaders.INSTANCE : current;
    }

    void setTrailingHeaders(HttpHeaders trailingHeaders) {
      this.trailingHeaders = trailingHeaders;
    }

    @Override
    public HttpVersion protocolVersion() {
      return message.protocolVersion();
    }

    @Override
    public FullHttpMessage setProtocolVersion(HttpVersion version) {
      message.setProtocolVersion(version);
      return this;
    }

    @Override
    public HttpHeaders headers() {
      return message.headers();
    }

    @Override
    public DecoderResult decoderResult() {
      return message.decoderResult();
    }

    @Override
    public void setDecoderResult(DecoderResult result) {
      message.setDecoderResult(result);
    }

    @Override
    public ByteBuf content() {
      return content;
    }

    @Override
    public int refCnt() {
      return content.refCnt();
    }

    @Override
    public FullHttpMessage retain() {
      content.retain();
      return this;
    }

    @Override
    public FullHttpMessage retain(int increment) {
      content.retain(increment);
      return this;
    }

    @Override
    public FullHttpMessage touch(Object hint) {
      content.touch(hint);
      return this;
    }

    @Override
    public FullHttpMessage touch() {
      content.touch();
      return this;
    }

    @Override
    public boolean release() {
      return content.release();
    }

    @Override
    public boolean release(int decrement) {
      return content.release(decrement);
    }

    @Override
    public abstract FullHttpMessage copy();

    @Override
    public abstract FullHttpMessage duplicate();
  }

  /** Aggregated view of an {@link HttpRequest} plus its content. */
  private static final class AggregatedFullHttpRequest extends AggregatedFullHttpMessage
      implements FullHttpRequest {

    AggregatedFullHttpRequest(HttpRequest request, ByteBuf content, HttpHeaders trailingHeaders) {
      super(request, content, trailingHeaders);
    }

    /**
     * Creates a standalone copy of this request.
     *
     * @param copyContent {@code true} to copy this object's {@link #content()}; {@code false} to
     *     use {@code newContent} instead
     * @param newContent the content for the copy when {@code copyContent} is {@code false}; a
     *     {@code null} value selects a default buffer of size 0
     * @return the copy, with headers and trailing headers set from this object
     */
    private FullHttpRequest copy(boolean copyContent, ByteBuf newContent) {
      final ByteBuf body;
      if (copyContent) {
        body = content().copy();
      } else if (newContent == null) {
        body = Unpooled.buffer(0);
      } else {
        body = newContent;
      }

      DefaultFullHttpRequest result =
          new DefaultFullHttpRequest(protocolVersion(), method(), uri(), body);
      result.headers().set(headers());
      result.trailingHeaders().set(trailingHeaders());
      return result;
    }

    @Override
    public FullHttpRequest copy(ByteBuf newContent) {
      return copy(false, newContent);
    }

    @Override
    public FullHttpRequest copy() {
      return copy(true, null);
    }

    /** Like {@link #copy()}, but built on a duplicate of the content buffer. */
    @Override
    public FullHttpRequest duplicate() {
      ByteBuf sharedBody = content().duplicate();
      DefaultFullHttpRequest result =
          new DefaultFullHttpRequest(protocolVersion(), method(), uri(), sharedBody);
      result.headers().set(headers());
      result.trailingHeaders().set(trailingHeaders());
      return result;
    }

    @Override
    public FullHttpRequest retain(int increment) {
      super.retain(increment);
      return this;
    }

    @Override
    public FullHttpRequest retain() {
      super.retain();
      return this;
    }

    @Override
    public FullHttpRequest touch() {
      super.touch();
      return this;
    }

    @Override
    public FullHttpRequest touch(Object hint) {
      super.touch(hint);
      return this;
    }

    @Override
    public FullHttpRequest setMethod(HttpMethod method) {
      ((HttpRequest) message).setMethod(method);
      return this;
    }

    @Override
    public FullHttpRequest setUri(String uri) {
      ((HttpRequest) message).setUri(uri);
      return this;
    }

    @Override
    public HttpMethod method() {
      return ((HttpRequest) message).method();
    }

    @Override
    public String uri() {
      return ((HttpRequest) message).uri();
    }

    @Override
    public FullHttpRequest setProtocolVersion(HttpVersion version) {
      super.setProtocolVersion(version);
      return this;
    }

    @Override
    public String toString() {
      StringBuilder out = new StringBuilder(256);
      return HttpMessageUtil.appendFullRequest(out, this).toString();
    }
  }

  /** Aggregated view of an {@link HttpResponse} plus its content. */
  private static final class AggregatedFullHttpResponse extends AggregatedFullHttpMessage
      implements FullHttpResponse {

    AggregatedFullHttpResponse(HttpResponse message, ByteBuf content, HttpHeaders trailingHeaders) {
      super(message, content, trailingHeaders);
    }

    /**
     * Creates a standalone copy of this response.
     *
     * @param copyContent {@code true} to copy this object's {@link #content()}; {@code false} to
     *     use {@code newContent} instead
     * @param newContent the content for the copy when {@code copyContent} is {@code false}; a
     *     {@code null} value selects a default buffer of size 0
     * @return the copy, with headers and trailing headers set from this object
     */
    private FullHttpResponse copy(boolean copyContent, ByteBuf newContent) {
      final ByteBuf body;
      if (copyContent) {
        body = content().copy();
      } else if (newContent == null) {
        body = Unpooled.buffer(0);
      } else {
        body = newContent;
      }

      DefaultFullHttpResponse result =
          new DefaultFullHttpResponse(protocolVersion(), status(), body);
      result.headers().set(headers());
      result.trailingHeaders().set(trailingHeaders());
      return result;
    }

    @Override
    public FullHttpResponse copy(ByteBuf newContent) {
      return copy(false, newContent);
    }

    @Override
    public FullHttpResponse copy() {
      return copy(true, null);
    }

    /** Like {@link #copy()}, but built on a duplicate of the content buffer. */
    @Override
    public FullHttpResponse duplicate() {
      ByteBuf sharedBody = content().duplicate();
      DefaultFullHttpResponse result =
          new DefaultFullHttpResponse(protocolVersion(), status(), sharedBody);
      result.headers().set(headers());
      result.trailingHeaders().set(trailingHeaders());
      return result;
    }

    @Override
    public FullHttpResponse setStatus(HttpResponseStatus status) {
      ((HttpResponse) message).setStatus(status);
      return this;
    }

    @Override
    public HttpResponseStatus status() {
      return ((HttpResponse) message).status();
    }

    @Override
    public FullHttpResponse setProtocolVersion(HttpVersion version) {
      super.setProtocolVersion(version);
      return this;
    }

    @Override
    public FullHttpResponse retain(int increment) {
      super.retain(increment);
      return this;
    }

    @Override
    public FullHttpResponse retain() {
      super.retain();
      return this;
    }

    @Override
    public FullHttpResponse touch(Object hint) {
      super.touch(hint);
      return this;
    }

    @Override
    public FullHttpResponse touch() {
      super.touch();
      return this;
    }

    @Override
    public String toString() {
      StringBuilder out = new StringBuilder(256);
      return HttpMessageUtil.appendFullResponse(out, this).toString();
    }
  }
}
public class SpdySessionHandlerTest { private static final InternalLogger logger = InternalLoggerFactory.getInstance(SpdySessionHandlerTest.class); private static final int closeSignal = SpdyCodecUtil.SPDY_SETTINGS_MAX_ID; private static final SpdySettingsFrame closeMessage = new DefaultSpdySettingsFrame(); static { closeMessage.setValue(closeSignal, 0); } private static void assertDataFrame(Object msg, int streamId, boolean last) { assertNotNull(msg); assertTrue(msg instanceof SpdyDataFrame); SpdyDataFrame spdyDataFrame = (SpdyDataFrame) msg; assertEquals(spdyDataFrame.getStreamId(), streamId); assertEquals(spdyDataFrame.isLast(), last); } private static void assertSynReply(Object msg, int streamId, boolean last, SpdyHeaders headers) { assertNotNull(msg); assertTrue(msg instanceof SpdySynReplyFrame); assertHeaders(msg, streamId, last, headers); } private static void assertRstStream(Object msg, int streamId, SpdyStreamStatus status) { assertNotNull(msg); assertTrue(msg instanceof SpdyRstStreamFrame); SpdyRstStreamFrame spdyRstStreamFrame = (SpdyRstStreamFrame) msg; assertEquals(spdyRstStreamFrame.getStreamId(), streamId); assertEquals(spdyRstStreamFrame.getStatus(), status); } private static void assertPing(Object msg, int id) { assertNotNull(msg); assertTrue(msg instanceof SpdyPingFrame); SpdyPingFrame spdyPingFrame = (SpdyPingFrame) msg; assertEquals(spdyPingFrame.getId(), id); } private static void assertGoAway(Object msg, int lastGoodStreamId) { assertNotNull(msg); assertTrue(msg instanceof SpdyGoAwayFrame); SpdyGoAwayFrame spdyGoAwayFrame = (SpdyGoAwayFrame) msg; assertEquals(spdyGoAwayFrame.getLastGoodStreamId(), lastGoodStreamId); } private static void assertHeaders(Object msg, int streamId, boolean last, SpdyHeaders headers) { assertNotNull(msg); assertTrue(msg instanceof SpdyHeadersFrame); SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg; assertEquals(spdyHeadersFrame.getStreamId(), streamId); assertEquals(spdyHeadersFrame.isLast(), last); for 
(String name : headers.names()) { List<String> expectedValues = headers.getAll(name); List<String> receivedValues = spdyHeadersFrame.headers().getAll(name); assertTrue(receivedValues.containsAll(expectedValues)); receivedValues.removeAll(expectedValues); assertTrue(receivedValues.isEmpty()); spdyHeadersFrame.headers().remove(name); } assertTrue(spdyHeadersFrame.headers().isEmpty()); } private static void testSpdySessionHandler(SpdyVersion version, boolean server) { EmbeddedChannel sessionHandler = new EmbeddedChannel( new SpdySessionHandler(version, server), new EchoHandler(closeSignal, server)); while (sessionHandler.readOutbound() != null) { continue; } int localStreamId = server ? 1 : 2; int remoteStreamId = server ? 2 : 1; SpdySynStreamFrame spdySynStreamFrame = new DefaultSpdySynStreamFrame(localStreamId, 0, (byte) 0); spdySynStreamFrame.headers().set("Compression", "test"); SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(localStreamId); spdyDataFrame.setLast(true); // Check if session handler returns INVALID_STREAM if it receives // a data frame for a Stream-ID that is not open sessionHandler.writeInbound(new DefaultSpdyDataFrame(localStreamId)); assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.INVALID_STREAM); assertNull(sessionHandler.readOutbound()); // Check if session handler returns PROTOCOL_ERROR if it receives // a data frame for a Stream-ID before receiving a SYN_REPLY frame sessionHandler.writeInbound(new DefaultSpdyDataFrame(remoteStreamId)); assertRstStream(sessionHandler.readOutbound(), remoteStreamId, SpdyStreamStatus.PROTOCOL_ERROR); assertNull(sessionHandler.readOutbound()); remoteStreamId += 2; // Check if session handler returns PROTOCOL_ERROR if it receives // multiple SYN_REPLY frames for the same active Stream-ID sessionHandler.writeInbound(new DefaultSpdySynReplyFrame(remoteStreamId)); assertNull(sessionHandler.readOutbound()); sessionHandler.writeInbound(new DefaultSpdySynReplyFrame(remoteStreamId)); 
assertRstStream(sessionHandler.readOutbound(), remoteStreamId, SpdyStreamStatus.STREAM_IN_USE); assertNull(sessionHandler.readOutbound()); remoteStreamId += 2; // Check if frame codec correctly compresses/uncompresses headers sessionHandler.writeInbound(spdySynStreamFrame); assertSynReply( sessionHandler.readOutbound(), localStreamId, false, spdySynStreamFrame.headers()); assertNull(sessionHandler.readOutbound()); SpdyHeadersFrame spdyHeadersFrame = new DefaultSpdyHeadersFrame(localStreamId); spdyHeadersFrame.headers().add("HEADER", "test1"); spdyHeadersFrame.headers().add("HEADER", "test2"); sessionHandler.writeInbound(spdyHeadersFrame); assertHeaders(sessionHandler.readOutbound(), localStreamId, false, spdyHeadersFrame.headers()); assertNull(sessionHandler.readOutbound()); localStreamId += 2; // Check if session handler closed the streams using the number // of concurrent streams and that it returns REFUSED_STREAM // if it receives a SYN_STREAM frame it does not wish to accept spdySynStreamFrame.setStreamId(localStreamId); spdySynStreamFrame.setLast(true); spdySynStreamFrame.setUnidirectional(true); sessionHandler.writeInbound(spdySynStreamFrame); assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.REFUSED_STREAM); assertNull(sessionHandler.readOutbound()); // Check if session handler rejects HEADERS for closed streams int testStreamId = spdyDataFrame.getStreamId(); sessionHandler.writeInbound(spdyDataFrame); assertDataFrame(sessionHandler.readOutbound(), testStreamId, spdyDataFrame.isLast()); assertNull(sessionHandler.readOutbound()); spdyHeadersFrame.setStreamId(testStreamId); sessionHandler.writeInbound(spdyHeadersFrame); assertRstStream(sessionHandler.readOutbound(), testStreamId, SpdyStreamStatus.INVALID_STREAM); assertNull(sessionHandler.readOutbound()); // Check if session handler drops active streams if it receives // a RST_STREAM frame for that Stream-ID sessionHandler.writeInbound(new DefaultSpdyRstStreamFrame(remoteStreamId, 
3)); assertNull(sessionHandler.readOutbound()); remoteStreamId += 2; // Check if session handler honors UNIDIRECTIONAL streams spdySynStreamFrame.setLast(false); sessionHandler.writeInbound(spdySynStreamFrame); assertNull(sessionHandler.readOutbound()); spdySynStreamFrame.setUnidirectional(false); // Check if session handler returns PROTOCOL_ERROR if it receives // multiple SYN_STREAM frames for the same active Stream-ID sessionHandler.writeInbound(spdySynStreamFrame); assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.PROTOCOL_ERROR); assertNull(sessionHandler.readOutbound()); localStreamId += 2; // Check if session handler returns PROTOCOL_ERROR if it receives // a SYN_STREAM frame with an invalid Stream-ID spdySynStreamFrame.setStreamId(localStreamId - 1); sessionHandler.writeInbound(spdySynStreamFrame); assertRstStream( sessionHandler.readOutbound(), localStreamId - 1, SpdyStreamStatus.PROTOCOL_ERROR); assertNull(sessionHandler.readOutbound()); spdySynStreamFrame.setStreamId(localStreamId); // Check if session handler returns PROTOCOL_ERROR if it receives // an invalid HEADERS frame spdyHeadersFrame.setStreamId(localStreamId); spdyHeadersFrame.setInvalid(); sessionHandler.writeInbound(spdyHeadersFrame); assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.PROTOCOL_ERROR); assertNull(sessionHandler.readOutbound()); sessionHandler.finish(); } private static void testSpdySessionHandlerPing(SpdyVersion version, boolean server) { EmbeddedChannel sessionHandler = new EmbeddedChannel( new SpdySessionHandler(version, server), new EchoHandler(closeSignal, server)); while (sessionHandler.readOutbound() != null) { continue; } int localStreamId = server ? 1 : 2; int remoteStreamId = server ? 
2 : 1; SpdyPingFrame localPingFrame = new DefaultSpdyPingFrame(localStreamId); SpdyPingFrame remotePingFrame = new DefaultSpdyPingFrame(remoteStreamId); // Check if session handler returns identical local PINGs sessionHandler.writeInbound(localPingFrame); assertPing(sessionHandler.readOutbound(), localPingFrame.getId()); assertNull(sessionHandler.readOutbound()); // Check if session handler ignores un-initiated remote PINGs sessionHandler.writeInbound(remotePingFrame); assertNull(sessionHandler.readOutbound()); sessionHandler.finish(); } private static void testSpdySessionHandlerGoAway(SpdyVersion version, boolean server) { EmbeddedChannel sessionHandler = new EmbeddedChannel( new SpdySessionHandler(version, server), new EchoHandler(closeSignal, server)); while (sessionHandler.readOutbound() != null) { continue; } int localStreamId = server ? 1 : 2; SpdySynStreamFrame spdySynStreamFrame = new DefaultSpdySynStreamFrame(localStreamId, 0, (byte) 0); spdySynStreamFrame.headers().set("Compression", "test"); SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(localStreamId); spdyDataFrame.setLast(true); // Send an initial request sessionHandler.writeInbound(spdySynStreamFrame); assertSynReply( sessionHandler.readOutbound(), localStreamId, false, spdySynStreamFrame.headers()); assertNull(sessionHandler.readOutbound()); sessionHandler.writeInbound(spdyDataFrame); assertDataFrame(sessionHandler.readOutbound(), localStreamId, true); assertNull(sessionHandler.readOutbound()); // Check if session handler sends a GOAWAY frame when closing sessionHandler.writeInbound(closeMessage); assertGoAway(sessionHandler.readOutbound(), localStreamId); assertNull(sessionHandler.readOutbound()); localStreamId += 2; // Check if session handler returns REFUSED_STREAM if it receives // SYN_STREAM frames after sending a GOAWAY frame spdySynStreamFrame.setStreamId(localStreamId); sessionHandler.writeInbound(spdySynStreamFrame); assertRstStream(sessionHandler.readOutbound(), localStreamId, 
SpdyStreamStatus.REFUSED_STREAM); assertNull(sessionHandler.readOutbound()); // Check if session handler ignores Data frames after sending // a GOAWAY frame spdyDataFrame.setStreamId(localStreamId); sessionHandler.writeInbound(spdyDataFrame); assertNull(sessionHandler.readOutbound()); sessionHandler.finish(); } @Test public void testSpdyClientSessionHandler() { logger.info("Running: testSpdyClientSessionHandler v3.1"); testSpdySessionHandler(SpdyVersion.SPDY_3_1, false); } @Test public void testSpdyClientSessionHandlerPing() { logger.info("Running: testSpdyClientSessionHandlerPing v3.1"); testSpdySessionHandlerPing(SpdyVersion.SPDY_3_1, false); } @Test public void testSpdyClientSessionHandlerGoAway() { logger.info("Running: testSpdyClientSessionHandlerGoAway v3.1"); testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3_1, false); } @Test public void testSpdyServerSessionHandler() { logger.info("Running: testSpdyServerSessionHandler v3.1"); testSpdySessionHandler(SpdyVersion.SPDY_3_1, true); } @Test public void testSpdyServerSessionHandlerPing() { logger.info("Running: testSpdyServerSessionHandlerPing v3.1"); testSpdySessionHandlerPing(SpdyVersion.SPDY_3_1, true); } @Test public void testSpdyServerSessionHandlerGoAway() { logger.info("Running: testSpdyServerSessionHandlerGoAway v3.1"); testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3_1, true); } // Echo Handler opens 4 half-closed streams on session connection // and then sets the number of concurrent streams to 1 private static class EchoHandler extends ChannelHandlerAdapter { private final int closeSignal; private final boolean server; EchoHandler(int closeSignal, boolean server) { this.closeSignal = closeSignal; this.server = server; } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { // Initiate 4 new streams int streamId = server ? 
2 : 1; SpdySynStreamFrame spdySynStreamFrame = new DefaultSpdySynStreamFrame(streamId, 0, (byte) 0); spdySynStreamFrame.setLast(true); ctx.writeAndFlush(spdySynStreamFrame); spdySynStreamFrame.setStreamId(spdySynStreamFrame.getStreamId() + 2); ctx.writeAndFlush(spdySynStreamFrame); spdySynStreamFrame.setStreamId(spdySynStreamFrame.getStreamId() + 2); ctx.writeAndFlush(spdySynStreamFrame); spdySynStreamFrame.setStreamId(spdySynStreamFrame.getStreamId() + 2); ctx.writeAndFlush(spdySynStreamFrame); // Limit the number of concurrent streams to 1 SpdySettingsFrame spdySettingsFrame = new DefaultSpdySettingsFrame(); spdySettingsFrame.setValue(SpdySettingsFrame.SETTINGS_MAX_CONCURRENT_STREAMS, 1); ctx.writeAndFlush(spdySettingsFrame); } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof SpdySynStreamFrame) { SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg; if (!spdySynStreamFrame.isUnidirectional()) { int streamId = spdySynStreamFrame.getStreamId(); SpdySynReplyFrame spdySynReplyFrame = new DefaultSpdySynReplyFrame(streamId); spdySynReplyFrame.setLast(spdySynStreamFrame.isLast()); for (Map.Entry<String, String> entry : spdySynStreamFrame.headers()) { spdySynReplyFrame.headers().add(entry.getKey(), entry.getValue()); } ctx.writeAndFlush(spdySynReplyFrame); } return; } if (msg instanceof SpdySynReplyFrame) { return; } if (msg instanceof SpdyDataFrame || msg instanceof SpdyPingFrame || msg instanceof SpdyHeadersFrame) { ctx.writeAndFlush(msg); return; } if (msg instanceof SpdySettingsFrame) { SpdySettingsFrame spdySettingsFrame = (SpdySettingsFrame) msg; if (spdySettingsFrame.isSet(closeSignal)) { ctx.close(); } } } } }
/** * 微信请求处理类 * * @className WeixinRequestHandler * @author jy * @date 2014年11月16日 * @since JDK 1.7 * @see com.foxinmy.weixin4j.dispatcher.WeixinMessageDispatcher */ public class WeixinRequestHandler extends SimpleChannelInboundHandler<WeixinRequest> { private final InternalLogger logger = InternalLoggerFactory.getInstance(getClass()); private final WeixinMessageDispatcher messageDispatcher; public WeixinRequestHandler(WeixinMessageDispatcher messageDispatcher) { this.messageDispatcher = messageDispatcher; } public void channelReadComplete(ChannelHandlerContext ctx) { ctx.flush(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { ctx.close(); logger.error("catch the exception:", cause); } @Override protected void channelRead0(ChannelHandlerContext ctx, WeixinRequest request) throws WeixinException { final AesToken aesToken = request.getAesToken(); if (aesToken == null || (StringUtil.isBlank(request.getSignature()) && StringUtil.isBlank(request.getMsgSignature()))) { ctx.writeAndFlush(HttpUtil.createHttpResponse(BAD_REQUEST)) .addListener(ChannelFutureListener.CLOSE); return; } /** 公众平台:无论Get,Post都带signature参数,当开启aes模式时带msg_signature参数 企业号:无论Get,Post都带msg_signature参数 */ if (request.getMethod().equals(HttpMethod.GET.name())) { if (!StringUtil.isBlank(request.getSignature()) && MessageUtil.signature(aesToken.getToken(), request.getTimeStamp(), request.getNonce()) .equals(request.getSignature())) { ctx.write(new SingleResponse(request.getEchoStr())); return; } if (!StringUtil.isBlank(request.getMsgSignature()) && MessageUtil.signature( aesToken.getToken(), request.getTimeStamp(), request.getNonce(), request.getEchoStr()) .equals(request.getMsgSignature())) { ctx.write( new SingleResponse( MessageUtil.aesDecrypt(null, aesToken.getAesKey(), request.getEchoStr()))); return; } ctx.writeAndFlush(HttpUtil.createHttpResponse(FORBIDDEN)) .addListener(ChannelFutureListener.CLOSE); return; } else if 
(request.getMethod().equals(HttpMethod.POST.name())) { if (!StringUtil.isBlank(request.getSignature()) && !MessageUtil.signature(aesToken.getToken(), request.getTimeStamp(), request.getNonce()) .equals(request.getSignature())) { ctx.writeAndFlush(HttpUtil.createHttpResponse(FORBIDDEN)) .addListener(ChannelFutureListener.CLOSE); return; } if (request.getEncryptType() == EncryptType.AES && !MessageUtil.signature( aesToken.getToken(), request.getTimeStamp(), request.getNonce(), request.getEncryptContent()) .equals(request.getMsgSignature())) { ctx.writeAndFlush(HttpUtil.createHttpResponse(FORBIDDEN)) .addListener(ChannelFutureListener.CLOSE); return; } } else { ctx.writeAndFlush(HttpUtil.createHttpResponse(METHOD_NOT_ALLOWED)) .addListener(ChannelFutureListener.CLOSE); return; } WeixinMessageTransfer messageTransfer = MessageTransferHandler.parser(request); ctx.channel().attr(Consts.MESSAGE_TRANSFER_KEY).set(messageTransfer); messageDispatcher.doDispatch(ctx, request, messageTransfer); } }
/** * Adds <a href="http://en.wikipedia.org/wiki/Transport_Layer_Security">SSL · TLS</a> and * StartTLS support to a {@link Channel}. Please refer to the <strong>"SecureChat"</strong> example * in the distribution or the web site for the detailed usage. * * <h3>Beginning the handshake</h3> * * <p>You must make sure not to write a message while the handshake is in progress unless you are * renegotiating. You will be notified by the {@link Future} which is returned by the {@link * #handshakeFuture()} method when the handshake process succeeds or fails. * * <p>Besides using the handshake {@link ChannelFuture} to get notified about the completion of the * handshake it's also possible to detect it by implementing the {@link * ChannelHandler#userEventTriggered(ChannelHandlerContext, Object)} method and checking for a {@link * SslHandshakeCompletionEvent}. * * <h3>Handshake</h3> * * <p>The handshake will be automatically issued for you once the {@link Channel} is active and {@link * SSLEngine#getUseClientMode()} returns {@code true}. So no need to bother with it yourself. * * <h3>Closing the session</h3> * * <p>To close the SSL session, the {@link #close()} method should be called to send the {@code * close_notify} message to the remote peer. One exception is when you close the {@link Channel} - * {@link SslHandler} intercepts the close request and sends the {@code close_notify} message before * the channel closure automatically. Once the SSL session is closed, it is not reusable, and * consequently you should create a new {@link SslHandler} with a new {@link SSLEngine} as explained * in the following section. * * <h3>Restarting the session</h3> * * <p>To restart the SSL session, you must remove the existing closed {@link SslHandler} from the * {@link ChannelPipeline}, insert a new {@link SslHandler} with a new {@link SSLEngine} into the * pipeline, and start the handshake process as described in the first section.
* * <h3>Implementing StartTLS</h3> * * <p><a href="http://en.wikipedia.org/wiki/STARTTLS">StartTLS</a> is the communication pattern that * secures the wire in the middle of the plaintext connection. Please note that it is different from * SSL · TLS, which secures the wire from the beginning of the connection. Typically, StartTLS * is composed of three steps: * * <ol> * <li>Client sends a StartTLS request to server. * <li>Server sends a StartTLS response to client. * <li>Client begins SSL handshake. * </ol> * * If you implement a server, you need to: * * <ol> * <li>create a new {@link SslHandler} instance with {@code startTls} flag set to {@code true}, * <li>insert the {@link SslHandler} to the {@link ChannelPipeline}, and * <li>write a StartTLS response. * </ol> * * Please note that you must insert {@link SslHandler} <em>before</em> sending the StartTLS * response. Otherwise the client can begin the SSL handshake before {@link SslHandler} is inserted * to the {@link ChannelPipeline}, causing data corruption. * * <p>The client-side implementation is much simpler. * * <ol> * <li>Write a StartTLS request, * <li>wait for the StartTLS response, * <li>create a new {@link SslHandler} instance with {@code startTls} flag set to {@code false}, * <li>insert the {@link SslHandler} to the {@link ChannelPipeline}, and * <li>Initiate SSL handshake. * </ol> * * <h3>Known issues</h3> * * <p>Because of a known issue with the current implementation of the SslEngine that comes with Java * it may be possible that you see blocked IO-Threads while a full GC is done.
* * <p>So if you are affected you can workaround this problem by adjust the cache settings like shown * below: * * <pre> * SslContext context = ...; * context.getServerSessionContext().setSessionCacheSize(someSaneSize); * context.getServerSessionContext().setSessionTime(someSameTimeout); * </pre> * * <p>What values to use here depends on the nature of your application and should be set based on * monitoring and debugging of it. For more details see <a * href="https://github.com/netty/netty/issues/832">#832</a> in our issue tracker. */ public class SslHandler extends ByteToMessageDecoder { private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslHandler.class); private static final Pattern IGNORABLE_CLASS_IN_STACK = Pattern.compile("^.*(?:Socket|Datagram|Sctp|Udt)Channel.*$"); private static final Pattern IGNORABLE_ERROR_MESSAGE = Pattern.compile( "^.*(?:connection.*(?:reset|closed|abort|broken)|broken.*pipe).*$", Pattern.CASE_INSENSITIVE); private static final SSLException SSLENGINE_CLOSED = new SSLException("SSLEngine closed already"); private static final SSLException HANDSHAKE_TIMED_OUT = new SSLException("handshake timed out"); private static final ClosedChannelException CHANNEL_CLOSED = new ClosedChannelException(); static { SSLENGINE_CLOSED.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); HANDSHAKE_TIMED_OUT.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); CHANNEL_CLOSED.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); } private volatile ChannelHandlerContext ctx; private final SSLEngine engine; private final int maxPacketBufferSize; private final boolean startTls; private boolean sentFirstMessage; private boolean flushedBeforeHandshakeDone; private final LazyChannelPromise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); private final Deque<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>(); /** * Set by wrap*() methods when something is produced. 
{@link * #channelReadComplete(ChannelHandlerContext)} will check this flag, clear it, and call * ctx.flush(). */ private boolean needsFlush; private int packetLength; private ByteBuf decodeOut; private volatile long handshakeTimeoutMillis = 10000; private volatile long closeNotifyTimeoutMillis = 3000; /** * Creates a new instance. * * @param engine the {@link SSLEngine} this handler will use */ public SslHandler(SSLEngine engine) { this(engine, false); } /** * Creates a new instance. * * @param engine the {@link SSLEngine} this handler will use * @param startTls {@code true} if the first write request shouldn't be encrypted by the {@link * SSLEngine} */ public SslHandler(SSLEngine engine, boolean startTls) { if (engine == null) { throw new NullPointerException("engine"); } this.engine = engine; this.startTls = startTls; maxPacketBufferSize = engine.getSession().getPacketBufferSize(); } public long getHandshakeTimeoutMillis() { return handshakeTimeoutMillis; } public void setHandshakeTimeout(long handshakeTimeout, TimeUnit unit) { if (unit == null) { throw new NullPointerException("unit"); } setHandshakeTimeoutMillis(unit.toMillis(handshakeTimeout)); } public void setHandshakeTimeoutMillis(long handshakeTimeoutMillis) { if (handshakeTimeoutMillis < 0) { throw new IllegalArgumentException( "handshakeTimeoutMillis: " + handshakeTimeoutMillis + " (expected: >= 0)"); } this.handshakeTimeoutMillis = handshakeTimeoutMillis; } public long getCloseNotifyTimeoutMillis() { return closeNotifyTimeoutMillis; } public void setCloseNotifyTimeout(long closeNotifyTimeout, TimeUnit unit) { if (unit == null) { throw new NullPointerException("unit"); } setCloseNotifyTimeoutMillis(unit.toMillis(closeNotifyTimeout)); } public void setCloseNotifyTimeoutMillis(long closeNotifyTimeoutMillis) { if (closeNotifyTimeoutMillis < 0) { throw new IllegalArgumentException( "closeNotifyTimeoutMillis: " + closeNotifyTimeoutMillis + " (expected: >= 0)"); } this.closeNotifyTimeoutMillis = 
closeNotifyTimeoutMillis; } /** Returns the {@link SSLEngine} which is used by this handler. */ public SSLEngine engine() { return engine; } /** Returns a {@link Future} that will get notified once the handshake completes. */ public Future<Channel> handshakeFuture() { return handshakePromise; } /** * Sends an SSL {@code close_notify} message to the specified channel and destroys the underlying * {@link SSLEngine}. */ public ChannelFuture close() { return close(ctx.newPromise()); } /** See {@link #close()} */ public ChannelFuture close(final ChannelPromise future) { final ChannelHandlerContext ctx = this.ctx; ctx.executor() .execute( new Runnable() { @Override public void run() { engine.closeOutbound(); try { write(ctx, Unpooled.EMPTY_BUFFER, future); flush(ctx); } catch (Exception e) { if (!future.tryFailure(e)) { logger.warn("flush() raised a masked exception.", e); } } } }); return future; } /** * Return the {@link ChannelFuture} that will get notified if the inbound of the {@link SSLEngine} * will get closed. * * <p>This method will return the same {@link ChannelFuture} all the time. 
* * <p>For more informations see the apidocs of {@link SSLEngine} */ public Future<Channel> sslCloseFuture() { return sslCloseFuture; } @Override public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { if (decodeOut != null) { decodeOut.release(); decodeOut = null; } for (; ; ) { PendingWrite write = pendingUnencryptedWrites.poll(); if (write == null) { break; } write.failAndRecycle(new ChannelException("Pending write on removal of SslHandler")); } } @Override public void disconnect(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { closeOutboundAndChannel(ctx, promise, true); } @Override public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { closeOutboundAndChannel(ctx, promise, false); } @Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { pendingUnencryptedWrites.add(PendingWrite.newInstance(msg, promise)); } @Override public void flush(ChannelHandlerContext ctx) throws Exception { // Do not encrypt the first write request if this handler is // created with startTLS flag turned on. 
if (startTls && !sentFirstMessage) { sentFirstMessage = true; for (; ; ) { PendingWrite pendingWrite = pendingUnencryptedWrites.poll(); if (pendingWrite == null) { break; } ctx.write(pendingWrite.msg(), (ChannelPromise) pendingWrite.recycleAndGet()); } ctx.flush(); return; } if (pendingUnencryptedWrites.isEmpty()) { pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null)); } if (!handshakePromise.isDone()) { flushedBeforeHandshakeDone = true; } wrap(ctx, false); ctx.flush(); } private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { ByteBuf out = null; ChannelPromise promise = null; try { for (; ; ) { PendingWrite pending = pendingUnencryptedWrites.peek(); if (pending == null) { break; } if (out == null) { out = ctx.alloc().buffer(maxPacketBufferSize); } if (!(pending.msg() instanceof ByteBuf)) { ctx.write(pending.msg(), (ChannelPromise) pending.recycleAndGet()); pendingUnencryptedWrites.remove(); continue; } ByteBuf buf = (ByteBuf) pending.msg(); SSLEngineResult result = wrap(engine, buf, out); if (!buf.isReadable()) { buf.release(); promise = (ChannelPromise) pending.recycleAndGet(); pendingUnencryptedWrites.remove(); } else { promise = null; } if (result.getStatus() == Status.CLOSED) { // SSLEngine has been closed already. // Any further write attempts should be denied. 
for (; ; ) { PendingWrite w = pendingUnencryptedWrites.poll(); if (w == null) { break; } w.failAndRecycle(SSLENGINE_CLOSED); } return; } else { switch (result.getHandshakeStatus()) { case NEED_TASK: runDelegatedTasks(); break; case FINISHED: setHandshakeSuccess(); // deliberate fall-through case NOT_HANDSHAKING: setHandshakeSuccessIfStillHandshaking(); // deliberate fall-through case NEED_WRAP: finishWrap(ctx, out, promise, inUnwrap); promise = null; out = null; break; case NEED_UNWRAP: return; default: throw new IllegalStateException( "Unknown handshake status: " + result.getHandshakeStatus()); } } } } catch (SSLException e) { setHandshakeFailure(e); throw e; } finally { finishWrap(ctx, out, promise, inUnwrap); } } private void finishWrap( ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap) { if (out == null) { out = Unpooled.EMPTY_BUFFER; } else if (!out.isReadable()) { out.release(); out = Unpooled.EMPTY_BUFFER; } if (promise != null) { ctx.write(out, promise); } else { ctx.write(out); } if (inUnwrap) { needsFlush = true; } } private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { ByteBuf out = null; try { for (; ; ) { if (out == null) { out = ctx.alloc().buffer(maxPacketBufferSize); } SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out); if (result.bytesProduced() > 0) { ctx.write(out); if (inUnwrap) { needsFlush = true; } out = null; } switch (result.getHandshakeStatus()) { case FINISHED: setHandshakeSuccess(); break; case NEED_TASK: runDelegatedTasks(); break; case NEED_UNWRAP: if (!inUnwrap) { unwrapNonApp(ctx); } break; case NEED_WRAP: break; case NOT_HANDSHAKING: setHandshakeSuccessIfStillHandshaking(); // Workaround for TLS False Start problem reported at: // https://github.com/netty/netty/issues/1108#issuecomment-14266970 if (!inUnwrap) { unwrapNonApp(ctx); } break; default: throw new IllegalStateException( "Unknown handshake status: " + result.getHandshakeStatus()); } if 
(result.bytesProduced() == 0) { break; } } } catch (SSLException e) { setHandshakeFailure(e); throw e; } finally { if (out != null) { out.release(); } } } private SSLEngineResult wrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException { ByteBuffer in0 = in.nioBuffer(); for (; ; ) { ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes()); SSLEngineResult result = engine.wrap(in0, out0); in.skipBytes(result.bytesConsumed()); out.writerIndex(out.writerIndex() + result.bytesProduced()); switch (result.getStatus()) { case BUFFER_OVERFLOW: out.ensureWritable(maxPacketBufferSize); break; default: return result; } } } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { // Make sure to release SSLEngine, // and notify the handshake future if the connection has been closed during handshake. setHandshakeFailure(CHANNEL_CLOSED); super.channelInactive(ctx); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { if (ignoreException(cause)) { // It is safe to ignore the 'connection reset by peer' or // 'broken pipe' error after sending close_notify. if (logger.isDebugEnabled()) { logger.debug( "Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred " + "while writing close_notify in response to the peer's close_notify", cause); } // Close the connection explicitly just in case the transport // did not close the connection automatically. if (ctx.channel().isActive()) { ctx.close(); } } else { ctx.fireExceptionCaught(cause); } } /** * Checks if the given {@link Throwable} can be ignore and just "swallowed" * * <p>When an ssl connection is closed a close_notify message is sent. After that the peer also * sends close_notify however, it's not mandatory to receive the close_notify. The party who sent * the initial close_notify can close the connection immediately then the peer will get connection * reset error. 
*/ private boolean ignoreException(Throwable t) { if (!(t instanceof SSLException) && t instanceof IOException && sslCloseFuture.isDone()) { String message = String.valueOf(t.getMessage()).toLowerCase(); // first try to match connection reset / broke peer based on the regex. This is the fastest // way // but may fail on different jdk impls or OS's if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) { return true; } // Inspect the StackTraceElements to see if it was a connection reset / broken pipe or not StackTraceElement[] elements = t.getStackTrace(); for (StackTraceElement element : elements) { String classname = element.getClassName(); String methodname = element.getMethodName(); // skip all classes that belong to the io.netty package if (classname.startsWith("io.netty.")) { continue; } // check if the method name is read if not skip it if (!"read".equals(methodname)) { continue; } // This will also match against SocketInputStream which is used by openjdk 7 and maybe // also others if (IGNORABLE_CLASS_IN_STACK.matcher(classname).matches()) { return true; } try { // No match by now.. Try to load the class via classloader and inspect it. // This is mainly done as other JDK implementations may differ in name of // the impl. Class<?> clazz = PlatformDependent.getClassLoader(getClass()).loadClass(classname); if (SocketChannel.class.isAssignableFrom(clazz) || DatagramChannel.class.isAssignableFrom(clazz)) { return true; } // also match against SctpChannel via String matching as it may not present. if (PlatformDependent.javaVersion() >= 7 && "com.sun.nio.sctp.SctpChannel".equals(clazz.getSuperclass().getName())) { return true; } } catch (ClassNotFoundException e) { // This should not happen just ignore } } } return false; } /** * Returns {@code true} if the given {@link ByteBuf} is encrypted. Be aware that this method will * not increase the readerIndex of the given {@link ByteBuf}. * * @param buffer The {@link ByteBuf} to read from. 
Be aware that it must have at least 5 bytes to * read, otherwise it will throw an {@link IllegalArgumentException}. * @return encrypted {@code true} if the {@link ByteBuf} is encrypted, {@code false} otherwise. * @throws IllegalArgumentException Is thrown if the given {@link ByteBuf} has not at least 5 * bytes to read. */ public static boolean isEncrypted(ByteBuf buffer) { if (buffer.readableBytes() < 5) { throw new IllegalArgumentException("buffer must have at least 5 readable bytes"); } return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1; } /** * Return how much bytes can be read out of the encrypted data. Be aware that this method will not * increase the readerIndex of the given {@link ByteBuf}. * * @param buffer The {@link ByteBuf} to read from. Be aware that it must have at least 5 bytes to * read, otherwise it will throw an {@link IllegalArgumentException}. * @return length The length of the encrypted packet that is included in the buffer. This will * return {@code -1} if the given {@link ByteBuf} is not encrypted at all. * @throws IllegalArgumentException Is thrown if the given {@link ByteBuf} has not at least 5 * bytes to read. */ private static int getEncryptedPacketLength(ByteBuf buffer, int offset) { int packetLength = 0; // SSLv3 or TLS - Check ContentType boolean tls; switch (buffer.getUnsignedByte(offset)) { case 20: // change_cipher_spec case 21: // alert case 22: // handshake case 23: // application_data tls = true; break; default: // SSLv2 or bad data tls = false; } if (tls) { // SSLv3 or TLS - Check ProtocolVersion int majorVersion = buffer.getUnsignedByte(offset + 1); if (majorVersion == 3) { // SSLv3 or TLS packetLength = buffer.getUnsignedShort(offset + 3) + 5; if (packetLength <= 5) { // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) tls = false; } } else { // Neither SSLv3 or TLSv1 (i.e. 
SSLv2 or bad data) tls = false; } } if (!tls) { // SSLv2 or bad data - Check the version boolean sslv2 = true; int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3; int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1); if (majorVersion == 2 || majorVersion == 3) { // SSLv2 if (headerLength == 2) { packetLength = (buffer.getShort(offset) & 0x7FFF) + 2; } else { packetLength = (buffer.getShort(offset) & 0x3FFF) + 3; } if (packetLength <= headerLength) { sslv2 = false; } } else { sslv2 = false; } if (!sslv2) { return -1; } } return packetLength; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException { // Keeps the list of the length of every SSL record in the input buffer. int[] recordLengths = null; int nRecords = 0; final int startOffset = in.readerIndex(); final int endOffset = in.writerIndex(); int offset = startOffset; // If we calculated the length of the current SSL record before, use that information. if (packetLength > 0) { if (endOffset - startOffset < packetLength) { return; } else { recordLengths = new int[4]; recordLengths[0] = packetLength; nRecords = 1; offset += packetLength; packetLength = 0; } } boolean nonSslRecord = false; for (; ; ) { final int readableBytes = endOffset - offset; if (readableBytes < 5) { break; } final int packetLength = getEncryptedPacketLength(in, offset); if (packetLength == -1) { nonSslRecord = true; break; } assert packetLength > 0; if (packetLength > readableBytes) { // wait until the whole packet can be read this.packetLength = packetLength; break; } // We have a whole packet. // Remember the length of the current packet. if (recordLengths == null) { recordLengths = new int[4]; } if (nRecords == recordLengths.length) { recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1); } recordLengths[nRecords++] = packetLength; // Increment the offset to handle the next packet. 
offset += packetLength; } final int totalLength = offset - startOffset; if (totalLength > 0) { // The buffer contains one or more full SSL records. // Slice out the whole packet so unwrap will only be called with complete packets. // Also directly reset the packetLength. This is needed as unwrap(..) may trigger // decode(...) again via: // 1) unwrap(..) is called // 2) wrap(...) is called from within unwrap(...) // 3) wrap(...) calls unwrapLater(...) // 4) unwrapLater(...) calls decode(...) // // See https://github.com/netty/netty/issues/1534 in.skipBytes(totalLength); ByteBuffer buffer = in.nioBuffer(startOffset, totalLength); unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out); } if (nonSslRecord) { // Not an SSL/TLS packet NotSslRecordException e = new NotSslRecordException("not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); in.skipBytes(in.readableBytes()); ctx.fireExceptionCaught(e); setHandshakeFailure(e); } } @Override public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { if (needsFlush) { needsFlush = false; ctx.flush(); } super.channelReadComplete(ctx); } /** * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle * handshakes, etc. */ private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException { try { unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0); } finally { ByteBuf decodeOut = this.decodeOut; if (decodeOut != null && decodeOut.isReadable()) { this.decodeOut = null; ctx.fireChannelRead(decodeOut); } } } /** Unwraps multiple inbound SSL records. 
*/ private void unwrapMultiple( ChannelHandlerContext ctx, ByteBuffer packet, int totalLength, int[] recordLengths, int nRecords, List<Object> out) throws SSLException { for (int i = 0; i < nRecords; i++) { packet.limit(packet.position() + recordLengths[i]); try { unwrapSingle(ctx, packet, totalLength); assert !packet.hasRemaining(); } finally { ByteBuf decodeOut = this.decodeOut; if (decodeOut != null && decodeOut.isReadable()) { this.decodeOut = null; out.add(decodeOut); } } } } /** Unwraps a single SSL record. */ private void unwrapSingle( ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException { boolean wrapLater = false; try { for (; ; ) { if (decodeOut == null) { decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity); } final SSLEngineResult result = unwrap(engine, packet, decodeOut); final Status status = result.getStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); final int produced = result.bytesProduced(); final int consumed = result.bytesConsumed(); if (status == Status.CLOSED) { // notify about the CLOSED state of the SSLEngine. See #137 sslCloseFuture.trySuccess(ctx.channel()); break; } switch (handshakeStatus) { case NEED_UNWRAP: break; case NEED_WRAP: wrapNonAppData(ctx, true); break; case NEED_TASK: runDelegatedTasks(); break; case FINISHED: setHandshakeSuccess(); wrapLater = true; continue; case NOT_HANDSHAKING: if (setHandshakeSuccessIfStillHandshaking()) { wrapLater = true; continue; } if (flushedBeforeHandshakeDone) { // We need to call wrap(...) in case there was a flush done before the handshake // completed. 
// // See https://github.com/netty/netty/pull/2437 flushedBeforeHandshakeDone = false; wrapLater = true; } break; default: throw new IllegalStateException("Unknown handshake status: " + handshakeStatus); } if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) { break; } } if (wrapLater) { wrap(ctx, true); } } catch (SSLException e) { setHandshakeFailure(e); throw e; } } private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException { int overflows = 0; for (; ; ) { ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes()); SSLEngineResult result = engine.unwrap(in, out0); out.writerIndex(out.writerIndex() + result.bytesProduced()); switch (result.getStatus()) { case BUFFER_OVERFLOW: int max = engine.getSession().getApplicationBufferSize(); switch (overflows++) { case 0: out.ensureWritable(Math.min(max, in.remaining())); break; default: out.ensureWritable(max); } break; default: return result; } } } /** * Fetches all delegated tasks from the {@link SSLEngine} and runs them by invoking them directly. */ private void runDelegatedTasks() { for (; ; ) { Runnable task = engine.getDelegatedTask(); if (task == null) { break; } task.run(); } } /** * Works around some Android {@link SSLEngine} implementations that skip {@link * HandshakeStatus#FINISHED} and go straight into {@link HandshakeStatus#NOT_HANDSHAKING} when * handshake is finished. 
* * @return {@code true} if and only if the workaround has been applied and thus {@link * #handshakeFuture} has been marked as success by this method */ private boolean setHandshakeSuccessIfStillHandshaking() { if (!handshakePromise.isDone()) { setHandshakeSuccess(); return true; } return false; } /** Notify all the handshake futures about the successfully handshake */ private void setHandshakeSuccess() { if (handshakePromise.trySuccess(ctx.channel())) { ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); } } /** Notify all the handshake futures about the failure during the handshake. */ private void setHandshakeFailure(Throwable cause) { // Release all resources such as internal buffers that SSLEngine // is managing. engine.closeOutbound(); try { engine.closeInbound(); } catch (SSLException e) { // only log in debug mode as it most likely harmless and latest chrome still trigger // this all the time. // // See https://github.com/netty/netty/issues/1340 String msg = e.getMessage(); if (msg == null || !msg.contains("possible truncation attack")) { logger.debug("SSLEngine.closeInbound() raised an exception.", e); } } notifyHandshakeFailure(cause); for (; ; ) { PendingWrite write = pendingUnencryptedWrites.poll(); if (write == null) { break; } write.failAndRecycle(cause); } } private void notifyHandshakeFailure(Throwable cause) { if (handshakePromise.tryFailure(cause)) { ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); ctx.close(); } } private void closeOutboundAndChannel( final ChannelHandlerContext ctx, final ChannelPromise promise, boolean disconnect) throws Exception { if (!ctx.channel().isActive()) { if (disconnect) { ctx.disconnect(promise); } else { ctx.close(promise); } return; } engine.closeOutbound(); ChannelPromise closeNotifyFuture = ctx.newPromise(); write(ctx, Unpooled.EMPTY_BUFFER, closeNotifyFuture); flush(ctx); safeClose(ctx, closeNotifyFuture, promise); } @Override public void handlerAdded(final ChannelHandlerContext 
ctx) throws Exception { this.ctx = ctx; if (ctx.channel().isActive()) { // channelActive() event has been fired already, which means this.channelActive() will // not be invoked. We have to initialize here instead. handshake(); } else { // channelActive() event has not been fired yet. this.channelOpen() will be invoked // and initialization will occur there. } } private Future<Channel> handshake() { final ScheduledFuture<?> timeoutFuture; if (handshakeTimeoutMillis > 0) { timeoutFuture = ctx.executor() .schedule( new Runnable() { @Override public void run() { if (handshakePromise.isDone()) { return; } notifyHandshakeFailure(HANDSHAKE_TIMED_OUT); } }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); } else { timeoutFuture = null; } handshakePromise.addListener( new GenericFutureListener<Future<Channel>>() { @Override public void operationComplete(Future<Channel> f) throws Exception { if (timeoutFuture != null) { timeoutFuture.cancel(false); } } }); try { engine.beginHandshake(); wrapNonAppData(ctx, false); ctx.flush(); } catch (Exception e) { notifyHandshakeFailure(e); } return handshakePromise; } /** Issues a SSL handshake once connected when used in client-mode */ @Override public void channelActive(final ChannelHandlerContext ctx) throws Exception { if (!startTls && engine.getUseClientMode()) { // issue and handshake and add a listener to it which will fire an exception event if // an exception was thrown while doing the handshake handshake() .addListener( new GenericFutureListener<Future<Channel>>() { @Override public void operationComplete(Future<Channel> future) throws Exception { if (!future.isSuccess()) { logger.debug("Failed to complete handshake", future.cause()); ctx.close(); } } }); } ctx.fireChannelActive(); } private void safeClose( final ChannelHandlerContext ctx, ChannelFuture flushFuture, final ChannelPromise promise) { if (!ctx.channel().isActive()) { ctx.close(promise); return; } final ScheduledFuture<?> timeoutFuture; if (closeNotifyTimeoutMillis > 
0) { // Force-close the connection if close_notify is not fully sent in time. timeoutFuture = ctx.executor() .schedule( new Runnable() { @Override public void run() { logger.warn( ctx.channel() + " last write attempt timed out." + " Force-closing the connection."); ctx.close(promise); } }, closeNotifyTimeoutMillis, TimeUnit.MILLISECONDS); } else { timeoutFuture = null; } // Close the connection if close_notify is sent in time. flushFuture.addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture f) throws Exception { if (timeoutFuture != null) { timeoutFuture.cancel(false); } // Trigger the close in all cases to make sure the promise is notified // See https://github.com/netty/netty/issues/2358 ctx.close(promise); } }); } private final class LazyChannelPromise extends DefaultPromise<Channel> { @Override protected EventExecutor executor() { if (ctx == null) { throw new IllegalStateException(); } return ctx.executor(); } } }
/** * A portable CDI extension which registers beans for lettuce. If there are no RedisURIs there are * also no registrations for RedisClients. * * @author <a href="mailto:[email protected]">Mark Paluch</a> */ public class LettuceCdiExtension implements Extension { private static final InternalLogger LOGGER = InternalLoggerFactory.getInstance(LettuceCdiExtension.class); private final Map<Set<Annotation>, Bean<RedisURI>> redisUris = new HashMap<Set<Annotation>, Bean<RedisURI>>(); public LettuceCdiExtension() { LOGGER.info("Activating CDI extension for lettuce."); } /** * Implementation of a an observer which checks for RedisURI beans and stores them in {@link * #redisUris} for later association with corresponding repository beans. * * @param <T> The type. * @param processBean The annotated type as defined by CDI. */ @SuppressWarnings("unchecked") <T> void processBean(@Observes ProcessBean<T> processBean) { Bean<T> bean = processBean.getBean(); for (Type type : bean.getTypes()) { // Check if the bean is an RedisURI. if (type instanceof Class<?> && RedisURI.class.isAssignableFrom((Class<?>) type)) { Set<Annotation> qualifiers = new HashSet<Annotation>(bean.getQualifiers()); if (bean.isAlternative() || !redisUris.containsKey(qualifiers)) { LOGGER.debug( String.format( "Discovered '%s' with qualifiers %s.", RedisURI.class.getName(), qualifiers)); redisUris.put(qualifiers, (Bean<RedisURI>) bean); } } } } /** * Implementation of a an observer which registers beans to the CDI container for the detected * RedisURIs. * * <p>The repository beans are associated to the EntityManagers using their qualifiers. * * @param beanManager The BeanManager instance. 
*/ void afterBeanDiscovery( @Observes AfterBeanDiscovery afterBeanDiscovery, BeanManager beanManager) { int counter = 0; for (Entry<Set<Annotation>, Bean<RedisURI>> entry : redisUris.entrySet()) { Bean<RedisURI> redisUri = entry.getValue(); Set<Annotation> qualifiers = entry.getKey(); String clientBeanName = RedisClient.class.getSimpleName(); String clusterClientBeanName = RedisClusterClient.class.getSimpleName(); if (!contains(qualifiers, Default.class)) { clientBeanName += counter; clusterClientBeanName += counter; counter++; } RedisClientCdiBean clientBean = new RedisClientCdiBean(beanManager, qualifiers, redisUri, clientBeanName); register(afterBeanDiscovery, qualifiers, clientBean); RedisClusterClientCdiBean clusterClientBean = new RedisClusterClientCdiBean(beanManager, qualifiers, redisUri, clusterClientBeanName); register(afterBeanDiscovery, qualifiers, clusterClientBean); } } private boolean contains(Set<Annotation> qualifiers, Class<Default> defaultClass) { Optional<Annotation> result = Iterables.tryFind( qualifiers, new Predicate<Annotation>() { @Override public boolean apply(Annotation input) { return input instanceof Default; } }); return result.isPresent(); } private void register( AfterBeanDiscovery afterBeanDiscovery, Set<Annotation> qualifiers, Bean<?> bean) { LOGGER.info( String.format( "Registering bean '%s' with qualifiers %s.", bean.getBeanClass().getName(), qualifiers)); afterBeanDiscovery.addBean(bean); } }
/** * Reads a PEM file and converts it into a list of DERs so that they are imported into a {@link * KeyStore} easily. */ final class PemReader { private static final InternalLogger logger = InternalLoggerFactory.getInstance(PemReader.class); private static final Pattern CERT_PATTERN = Pattern.compile( "-+BEGIN\\s+.*CERTIFICATE[^-]*-+(?:\\s|\\r|\\n)+" + // Header "([a-z0-9+/=\\r\\n]+)" + // Base64 text "-+END\\s+.*CERTIFICATE[^-]*-+", // Footer Pattern.CASE_INSENSITIVE); private static final Pattern KEY_PATTERN = Pattern.compile( "-+BEGIN\\s+.*PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+" + // Header "([a-z0-9+/=\\r\\n]+)" + // Base64 text "-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", // Footer Pattern.CASE_INSENSITIVE); static ByteBuf[] readCertificates(File file) throws CertificateException { String content; try { content = readContent(file); } catch (IOException e) { throw new CertificateException("failed to read a file: " + file, e); } List<ByteBuf> certs = new ArrayList<ByteBuf>(); Matcher m = CERT_PATTERN.matcher(content); int start = 0; for (; ; ) { if (!m.find(start)) { break; } ByteBuf base64 = Unpooled.copiedBuffer(m.group(1), CharsetUtil.US_ASCII); ByteBuf der = Base64.decode(base64); base64.release(); certs.add(der); start = m.end(); } if (certs.isEmpty()) { throw new CertificateException("found no certificates: " + file); } return certs.toArray(new ByteBuf[certs.size()]); } static ByteBuf readPrivateKey(File file) throws KeyException { String content; try { content = readContent(file); } catch (IOException e) { throw new KeyException("failed to read a file: " + file, e); } Matcher m = KEY_PATTERN.matcher(content); if (!m.find()) { throw new KeyException("found no private key: " + file); } ByteBuf base64 = Unpooled.copiedBuffer(m.group(1), CharsetUtil.US_ASCII); ByteBuf der = Base64.decode(base64); base64.release(); return der; } private static String readContent(File file) throws IOException { InputStream in = new FileInputStream(file); ByteArrayOutputStream out = new 
ByteArrayOutputStream(); try { byte[] buf = new byte[8192]; for (; ; ) { int ret = in.read(buf); if (ret < 0) { break; } out.write(buf, 0, ret); } return out.toString(CharsetUtil.US_ASCII.name()); } finally { safeClose(in); safeClose(out); } } private static void safeClose(InputStream in) { try { in.close(); } catch (IOException e) { logger.warn("Failed to close a stream.", e); } } private static void safeClose(OutputStream out) { try { out.close(); } catch (IOException e) { logger.warn("Failed to close a stream.", e); } } private PemReader() {} }
/** @author circlespainter */ public class AutoWebActorHandler extends WebActorHandler { private static final InternalLogger log = InternalLoggerFactory.getInstance(AutoWebActorHandler.class); private static final List<Class<?>> actorClasses = new ArrayList<>(32); private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0]; public AutoWebActorHandler() { this(null, null, null, null); } public AutoWebActorHandler(List<String> packagePrefixes) { this(null, null, packagePrefixes, null); } public AutoWebActorHandler(String httpResponseEncoderName, List<String> packagePrefixes) { this(httpResponseEncoderName, null, packagePrefixes, null); } public AutoWebActorHandler(String httpResponseEncoderName) { this(httpResponseEncoderName, null, null, null); } public AutoWebActorHandler(String httpResponseEncoderName, ClassLoader userClassLoader) { this(httpResponseEncoderName, userClassLoader, null, null); } public AutoWebActorHandler( String httpResponseEncoderName, ClassLoader userClassLoader, List<String> packagePrefixes) { this(httpResponseEncoderName, userClassLoader, packagePrefixes, null); } public AutoWebActorHandler(String httpResponseEncoderName, Map<Class<?>, Object[]> actorParams) { this(httpResponseEncoderName, null, null, actorParams); } public AutoWebActorHandler( String httpResponseEncoderName, List<String> packagePrefixes, Map<Class<?>, Object[]> actorParams) { this(httpResponseEncoderName, null, packagePrefixes, actorParams); } public AutoWebActorHandler( String httpResponseEncoderName, ClassLoader userClassLoader, List<String> packagePrefixes, Map<Class<?>, Object[]> actorParams) { super(null, httpResponseEncoderName); super.contextProvider = newContextProvider( userClassLoader != null ? 
userClassLoader : ClassLoader.getSystemClassLoader(), packagePrefixes, actorParams); } public AutoWebActorHandler(String httpResponseEncoderName, AutoContextProvider prov) { super(prov, httpResponseEncoderName); } protected AutoContextProvider newContextProvider( ClassLoader userClassLoader, List<String> packagePrefixes, Map<Class<?>, Object[]> actorParams) { return new AutoContextProvider(userClassLoader, packagePrefixes, actorParams); } public static class AutoContextProvider implements WebActorContextProvider { private final ClassLoader userClassLoader; private final List<String> packagePrefixes; private final Map<Class<?>, Object[]> actorParams; private final Long defaultContextValidityMS; public AutoContextProvider( ClassLoader userClassLoader, List<String> packagePrefixes, Map<Class<?>, Object[]> actorParams) { this(userClassLoader, packagePrefixes, actorParams, null); } public AutoContextProvider( ClassLoader userClassLoader, List<String> packagePrefixes, Map<Class<?>, Object[]> actorParams, Long defaultContextValidityMS) { this.userClassLoader = userClassLoader; this.packagePrefixes = packagePrefixes; this.actorParams = actorParams; this.defaultContextValidityMS = defaultContextValidityMS; } @Override public final Context get(final FullHttpRequest req) { final String sessionId = getSessionId(req); if (sessionId != null && sessionsEnabled()) { final Context actorContext = sessions.get(sessionId); if (actorContext != null) { if (actorContext.renew()) return actorContext; else sessions.remove(sessionId); // Evict session } } return newActorContext(req); } protected AutoContext newActorContext(FullHttpRequest req) { final AutoContext c = new AutoContext(req, packagePrefixes, actorParams, userClassLoader); if (defaultContextValidityMS != null) c.setValidityMS(defaultContextValidityMS); return c; } private String getSessionId(FullHttpRequest req) { final Set<Cookie> cookies = NettyHttpRequest.getNettyCookies(req); if (cookies != null) { for (final Cookie c : 
cookies) { if (c != null && SESSION_COOKIE_KEY.equals(c.name())) return c.value(); } } return null; } } private static class AutoContext extends DefaultContextImpl { private String id; private final List<String> packagePrefixes; private final Map<Class<?>, Object[]> actorParams; private final ClassLoader userClassLoader; private Class<? extends ActorImpl<? extends WebMessage>> actorClass; private ActorRef<? extends WebMessage> actorRef; public AutoContext( FullHttpRequest req, List<String> packagePrefixes, Map<Class<?>, Object[]> actorParams, ClassLoader userClassLoader) { this.packagePrefixes = packagePrefixes; this.actorParams = actorParams; this.userClassLoader = userClassLoader; fillActor(req); } private void fillActor(FullHttpRequest req) { final Pair<ActorRef<? extends WebMessage>, Class<? extends ActorImpl<? extends WebMessage>>> p = autoCreateActor(req); if (p != null) { actorRef = p.getFirst(); actorClass = p.getSecond(); } } @Override public final String getId() { return id != null ? id : (id = UUID.randomUUID().toString()); } @Override public final void restart(FullHttpRequest req) { renewed = new Date().getTime(); fillActor(req); } @Override public final ActorRef<? extends WebMessage> getWebActor() { return actorRef; } @Override public final boolean handlesWithHttp(String uri) { return WebActorHandler.handlesWithHttp(uri, actorClass); } @Override public final boolean handlesWithWebSocket(String uri) { return WebActorHandler.handlesWithWebSocket(uri, actorClass); } @Override public WatchPolicy watch() { return WatchPolicy.DIE_IF_EXCEPTION_ELSE_RESTART; } @SuppressWarnings("unchecked") private Pair<ActorRef<? extends WebMessage>, Class<? extends ActorImpl<? extends WebMessage>>> autoCreateActor(FullHttpRequest req) { registerActorClasses(); final String uri = req.getUri(); for (final Class<?> c : actorClasses) { if (WebActorHandler.handlesWithHttp(uri, c) || WebActorHandler.handlesWithWebSocket(uri, c)) return new Pair< ActorRef<? 
extends WebMessage>, Class<? extends ActorImpl<? extends WebMessage>>>( Actor.newActor( new ActorSpec( c, actorParams != null ? actorParams.get(c) : EMPTY_OBJECT_ARRAY)) .spawn(), (Class<? extends ActorImpl<? extends WebMessage>>) c); } return null; } private synchronized void registerActorClasses() { if (actorClasses.isEmpty()) { try { final ClassLoader classLoader = userClassLoader != null ? userClassLoader : this.getClass().getClassLoader(); ClassLoaderUtil.accept( (URLClassLoader) classLoader, new ClassLoaderUtil.Visitor() { @Override public final void visit(String resource, URL url, ClassLoader cl) { if (packagePrefixes != null) { boolean found = false; for (final String packagePrefix : packagePrefixes) { if (packagePrefix != null && resource.startsWith(packagePrefix.replace('.', '/'))) { found = true; break; } } if (!found) return; } if (!ClassLoaderUtil.isClassFile(resource)) return; final String className = ClassLoaderUtil.resourceToClass(resource); try (final InputStream is = cl.getResourceAsStream(resource)) { if (AnnotationUtil.hasClassAnnotation(WebActor.class, is)) registerWebActor(cl.loadClass(className)); } catch (final IOException | ClassNotFoundException e) { log.error( "Exception while scanning class " + className + " for WebActor annotation", e); throw new RuntimeException(e); } } private void registerWebActor(Class<?> c) { actorClasses.add(c); } }); } catch (final IOException e) { log.error("IOException while scanning classes for WebActor annotation", e); } } } } }
/** * {@link ServerSocketChannel} which accepts new connections and create the {@link * OioSocketChannel}'s for them. * * <p>This implementation use Old-Blocking-IO. */ public class OioServerSocketChannel extends AbstractOioMessageChannel implements ServerSocketChannel { private static final InternalLogger logger = InternalLoggerFactory.getInstance(OioServerSocketChannel.class); private static final ChannelMetadata METADATA = new ChannelMetadata(false); private static ServerSocket newServerSocket() { try { return new ServerSocket(); } catch (IOException e) { throw new ChannelException("failed to create a server socket", e); } } final ServerSocket socket; final Lock shutdownLock = new ReentrantLock(); private final OioServerSocketChannelConfig config; /** Create a new instance with an new {@link Socket} */ public OioServerSocketChannel() { this(newServerSocket()); } /** * Create a new instance from the given {@link ServerSocket} * * @param socket the {@link ServerSocket} which is used by this instance */ public OioServerSocketChannel(ServerSocket socket) { super(null); if (socket == null) { throw new NullPointerException("socket"); } boolean success = false; try { socket.setSoTimeout(SO_TIMEOUT); success = true; } catch (IOException e) { throw new ChannelException("Failed to set the server socket timeout.", e); } finally { if (!success) { try { socket.close(); } catch (IOException e) { if (logger.isWarnEnabled()) { logger.warn("Failed to close a partially initialized socket.", e); } } } } this.socket = socket; config = new DefaultOioServerSocketChannelConfig(this, socket); } @Override public InetSocketAddress localAddress() { return (InetSocketAddress) super.localAddress(); } @Override public ChannelMetadata metadata() { return METADATA; } @Override public OioServerSocketChannelConfig config() { return config; } @Override public InetSocketAddress remoteAddress() { return null; } @Override public boolean isOpen() { return !socket.isClosed(); } @Override public 
boolean isActive() { return isOpen() && socket.isBound(); } @Override protected SocketAddress localAddress0() { return socket.getLocalSocketAddress(); } @Override protected void doBind(SocketAddress localAddress) throws Exception { socket.bind(localAddress, config.getBacklog()); } @Override protected void doClose() throws Exception { socket.close(); } @Override protected int doReadMessages(MessageList<Object> buf) throws Exception { if (socket.isClosed()) { return -1; } try { Socket s = socket.accept(); try { if (s != null) { buf.add(new OioSocketChannel(this, s)); return 1; } } catch (Throwable t) { logger.warn("Failed to create a new channel from an accepted socket.", t); if (s != null) { try { s.close(); } catch (Throwable t2) { logger.warn("Failed to close a socket.", t2); } } } } catch (SocketTimeoutException e) { // Expected } return 0; } @Override protected int doWrite(MessageList<Object> msgs, int index) throws Exception { throw new UnsupportedOperationException(); } @Override protected void doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { throw new UnsupportedOperationException(); } @Override protected SocketAddress remoteAddress0() { return null; } @Override protected void doDisconnect() throws Exception { throw new UnsupportedOperationException(); } }
/** @author circlespainter */ public class WebActorHandler extends SimpleChannelInboundHandler<Object> { // @FunctionalInterface public interface WebActorContextProvider { Context get(ChannelHandlerContext ctx, FullHttpRequest req); } public interface Context { boolean isValid(); void invalidate(); ActorRef<? extends WebMessage> getRef(); ReentrantLock getLock(); Map<String, Object> getAttachments(); boolean handlesWithHttp(String uri); boolean handlesWithWebSocket(String uri); boolean watch(); } public abstract static class DefaultContextImpl implements Context { private static final String durationProp = System.getProperty(DefaultContextImpl.class.getName() + ".durationMillis"); private static final long DURATION = durationProp != null ? Long.parseLong(durationProp) : 60_000L; private final ReentrantLock lock = new ReentrantLock(); private final long created; private final Map<String, Object> attachments = new HashMap<>(); private boolean valid = true; public DefaultContextImpl() { created = new Date().getTime(); } @Override public void invalidate() { attachments.clear(); valid = false; } @Override public final boolean isValid() { final boolean ret = valid && (new Date().getTime() - created) <= DURATION; if (!ret) invalidate(); return ret; } @Override public final Map<String, Object> getAttachments() { return attachments; } @Override public final ReentrantLock getLock() { return lock; } @Override public boolean watch() { return true; } } public WebActorHandler(WebActorContextProvider selector) { this(selector, null); } public WebActorHandler(WebActorContextProvider selector, String httpResponseEncoderName) { this.selector = selector; this.httpResponseEncoderName = httpResponseEncoderName; } @Override public final void channelReadComplete(ChannelHandlerContext ctx) { ctx.flush(); } @Override public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { if (ctx.channel().isOpen()) ctx.close(); log.error("Exception caught", cause); } @Override 
protected final void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof FullHttpRequest) { handleHttpRequest(ctx, (FullHttpRequest) msg); } else if (msg instanceof WebSocketFrame) { handleWebSocketFrame(ctx, (WebSocketFrame) msg); } else { throw new AssertionError("Unexpected message " + msg); } } protected static boolean sessionsEnabled() { return "always".equals(trackSession) || "sse".equals(trackSession); } protected static final String SESSION_COOKIE_KEY = "JSESSIONID"; protected static final Map<String, Context> sessions = Collections.synchronizedMap(new WeakHashMap<String, Context>()); protected static final String TRACK_SESSION_PROP = HttpChannelAdapter.class.getName() + ".trackSession"; protected static final String trackSession = System.getProperty(TRACK_SESSION_PROP, "sse"); protected static final String OMIT_DATE_HEADER_PROP = HttpChannelAdapter.class.getName() + ".omitDateHeader"; protected static final Boolean omitDateHeader = SystemProperties.isEmptyOrTrue(OMIT_DATE_HEADER_PROP); private static final String ACTOR_KEY = "co.paralleluniverse.comsat.webactors.sessionActor"; private static final WeakHashMap<Class<?>, List<Pair<String, String>>> classToUrlPatterns = new WeakHashMap<>(); private static final InternalLogger log = InternalLoggerFactory.getInstance(AutoWebActorHandler.class); private final WebActorContextProvider selector; private final String httpResponseEncoderName; private WebSocketServerHandshaker handshaker; private WebSocketActorAdapter webSocketActor; private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { // Check for closing frame if (frame instanceof CloseWebSocketFrame) { handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain()); return; } if (frame instanceof PingWebSocketFrame) { ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain())); return; } if (frame instanceof ContinuationWebSocketFrame) return; if (frame instanceof 
TextWebSocketFrame) webSocketActor.onMessage(((TextWebSocketFrame) frame).text()); else webSocketActor.onMessage(frame.content().nioBuffer()); } private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) throws SuspendExecution { // Handle a bad request. if (!req.getDecoderResult().isSuccess()) { sendHttpResponse( ctx, req, new DefaultFullHttpResponse(req.getProtocolVersion(), BAD_REQUEST), false); return; } final String uri = req.getUri(); final Context actorCtx = selector.get(ctx, req); assert actorCtx != null; final ReentrantLock lock = actorCtx.getLock(); assert lock != null; lock.lock(); try { final ActorRef<? extends WebMessage> userActorRef = actorCtx.getRef(); ActorImpl internalActor = (ActorImpl) actorCtx.getAttachments().get(ACTOR_KEY); if (userActorRef != null) { if (actorCtx.handlesWithWebSocket(uri)) { if (internalActor == null || !(internalActor instanceof WebSocketActorAdapter)) { //noinspection unchecked webSocketActor = new WebSocketActorAdapter(ctx, (ActorRef<? 
super WebMessage>) userActorRef); addActorToContextAndUnlock(actorCtx, webSocketActor, lock); } // Handshake final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(uri, null, true); handshaker = wsFactory.newHandshaker(req); if (handshaker == null) { WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); } else { @SuppressWarnings("unchecked") final ActorRef<WebMessage> userActorRef0 = (ActorRef<WebMessage>) webSocketActor.userActor; handshaker .handshake(ctx.channel(), req) .addListener( new GenericFutureListener<ChannelFuture>() { @Override public void operationComplete(ChannelFuture future) throws Exception { FiberUtil.runInFiber( new SuspendableRunnable() { @Override public void run() throws SuspendExecution, InterruptedException { userActorRef0.send( new WebSocketOpened(WebActorHandler.this.webSocketActor.ref())); } }); } }); } return; } else if (actorCtx.handlesWithHttp(uri)) { if (internalActor == null || !(internalActor instanceof HttpActorAdapter)) { //noinspection unchecked internalActor = new HttpActorAdapter( (ActorRef<HttpRequest>) userActorRef, actorCtx, httpResponseEncoderName); addActorToContextAndUnlock(actorCtx, internalActor, lock); } //noinspection unchecked ((HttpActorAdapter) internalActor).service(ctx, req); return; } } } finally { if (lock.isHeldByCurrentStrand() && lock.isLocked()) lock.unlock(); } sendHttpResponse( ctx, req, new DefaultFullHttpResponse(req.getProtocolVersion(), NOT_FOUND), false); } private void addActorToContextAndUnlock( Context actorContext, ActorImpl actor, ReentrantLock lock) { actorContext.getAttachments().put(ACTOR_KEY, actor); lock.unlock(); } private static final class WebSocketActorAdapter extends FakeActor<WebDataMessage> { ActorRef<? super WebMessage> userActor; private ChannelHandlerContext ctx; public WebSocketActorAdapter( ChannelHandlerContext ctx, ActorRef<? 
super WebMessage> userActor) { super(userActor.getName(), new WebSocketChannelAdapter(ctx)); ((WebSocketChannelAdapter) (SendPort) getMailbox()).actor = this; this.ctx = ctx; this.userActor = userActor; watch(userActor); } @Override public final void interrupt() { die(new InterruptedException()); } @Override public final String toString() { return "WebSocketActorAdapter{" + "userActor=" + userActor + '}'; } private void onMessage(final ByteBuffer message) { try { userActor.send(new WebDataMessage(ref(), message)); } catch (SuspendExecution ex) { throw new AssertionError(ex); } } private void onMessage(final String message) { try { userActor.send(new WebDataMessage(ref(), message)); } catch (SuspendExecution ex) { throw new AssertionError(ex); } } @Override protected final WebDataMessage handleLifecycleMessage(LifecycleMessage m) { if (m instanceof ExitMessage) { ExitMessage em = (ExitMessage) m; if (em.getActor() != null && em.getActor().equals(userActor)) die(em.getCause()); } return null; } @Override protected final void throwIn(RuntimeException e) { die(e); } @Override protected final void die(Throwable cause) { super.die(cause); if (ctx.channel().isOpen()) ctx.close(); // Ensure to release server references userActor = null; ctx = null; } } private static final class WebSocketChannelAdapter implements SendPort<WebDataMessage> { private final ChannelHandlerContext ctx; WebSocketActorAdapter actor; public WebSocketChannelAdapter(ChannelHandlerContext ctx) { this.ctx = ctx; } @Override public final void send(WebDataMessage message) throws SuspendExecution, InterruptedException { trySend(message); } @Override public final boolean send(WebDataMessage message, long timeout, TimeUnit unit) throws SuspendExecution, InterruptedException { return trySend(message); } @Override public final boolean send(WebDataMessage message, Timeout timeout) throws SuspendExecution, InterruptedException { return send(message, timeout.nanosLeft(), TimeUnit.NANOSECONDS); } @Override public 
final boolean trySend(WebDataMessage message) { if (!message.isBinary()) ctx.writeAndFlush(new TextWebSocketFrame(message.getStringBody())); else ctx.writeAndFlush( new BinaryWebSocketFrame(Unpooled.wrappedBuffer(message.getByteBufferBody()))); return true; } @Override public final void close() { if (ctx.channel().isOpen()) ctx.close(); if (actor != null) actor.die(null); } @Override public final void close(Throwable t) { if (actor != null) actor.die(t); close(); } } private static final class HttpActorAdapter extends FakeActor<HttpResponse> { private ActorRef<? super HttpRequest> userActor; private Context context; private volatile boolean dead; private volatile ChannelHandlerContext ctx; private volatile FullHttpRequest req; private volatile Object watchToken; HttpActorAdapter( ActorRef<? super HttpRequest> userActor, Context actorContext, String httpResponseEncoderName) { super("HttpActorAdapter", new HttpChannelAdapter(actorContext, httpResponseEncoderName)); if (actorContext.watch()) ((HttpChannelAdapter) (SendPort) getMailbox()).actor = this; this.userActor = userActor; this.context = actorContext; } final void service(ChannelHandlerContext ctx, FullHttpRequest req) throws SuspendExecution { if (context.watch()) watchToken = watch(userActor); this.ctx = ctx; this.req = req; if (isDone()) { handleDeath(getDeathCause()); return; } userActor.send(new HttpRequestWrapper(ref(), ctx, req)); } final void unwatch() { if (watchToken != null && userActor != null) { unwatch(userActor, watchToken); watchToken = null; } } private void handleDeath(Throwable cause) { if (cause != null) sendHttpResponse( ctx, req, new DefaultFullHttpResponse( req.getProtocolVersion(), INTERNAL_SERVER_ERROR, Unpooled.wrappedBuffer( ("Actor is dead because of " + cause.getMessage()).getBytes())), false); else sendHttpResponse( ctx, req, new DefaultFullHttpResponse( req.getProtocolVersion(), INTERNAL_SERVER_ERROR, Unpooled.wrappedBuffer(("Actor has terminated.").getBytes())), false); } 
@Override protected final HttpResponse handleLifecycleMessage(LifecycleMessage m) { if (m instanceof ExitMessage) { final ExitMessage em = (ExitMessage) m; if (em.getActor() != null && em.getActor().equals(userActor)) { handleDeath(em.getCause()); die(em.getCause()); } } return null; } @Override protected final void die(Throwable cause) { if (dead) return; dead = true; super.die(cause); try { context.invalidate(); } catch (final Exception ignored) { } // Ensure to release references to server objects unwatch(); userActor = null; watchToken = null; context = null; ctx = null; req = null; } @Override protected final void throwIn(RuntimeException e) { die(e); } @Override protected final void interrupt() { die(new InterruptedException()); } @Override public final String toString() { return "HttpActorAdapter{" + userActor + "}"; } } private static final class HttpChannelAdapter implements SendPort<HttpResponse> { HttpActorAdapter actor; private final String httpResponseEncoderName; private Context actorContext; public HttpChannelAdapter(Context actorContext, String httpResponseEncoderName) { this.actorContext = actorContext; this.httpResponseEncoderName = httpResponseEncoderName; } @Override public final void send(HttpResponse message) throws SuspendExecution, InterruptedException { trySend(message); } @Override public final boolean send(HttpResponse message, long timeout, TimeUnit unit) throws SuspendExecution, InterruptedException { send(message); return true; } @Override public final boolean send(HttpResponse message, Timeout timeout) throws SuspendExecution, InterruptedException { return send(message, timeout.nanosLeft(), TimeUnit.NANOSECONDS); } @Override public final boolean trySend(HttpResponse message) { try { final HttpRequestWrapper nettyRequest = (HttpRequestWrapper) message.getRequest(); final FullHttpRequest req = nettyRequest.req; final ChannelHandlerContext ctx = nettyRequest.ctx; final HttpResponseStatus status = 
HttpResponseStatus.valueOf(message.getStatus()); if (message.getStatus() >= 400 && message.getStatus() < 600) { sendHttpResponse( ctx, req, new DefaultFullHttpResponse(req.getProtocolVersion(), status), false); close(); return true; } if (message.getRedirectPath() != null) { sendHttpRedirect(ctx, req, message.getRedirectPath()); close(); return true; } FullHttpResponse res; if (message.getStringBody() != null) res = new DefaultFullHttpResponse( req.getProtocolVersion(), status, Unpooled.wrappedBuffer(message.getStringBody().getBytes())); else if (message.getByteBufferBody() != null) res = new DefaultFullHttpResponse( req.getProtocolVersion(), status, Unpooled.wrappedBuffer(message.getByteBufferBody())); else res = new DefaultFullHttpResponse(req.getProtocolVersion(), status); if (message.getCookies() != null) { final ServerCookieEncoder enc = ServerCookieEncoder.STRICT; for (final Cookie c : message.getCookies()) HttpHeaders.setHeader(res, COOKIE, enc.encode(getNettyCookie(c))); } if (message.getHeaders() != null) { for (final Map.Entry<String, String> h : message.getHeaders().entries()) HttpHeaders.setHeader(res, h.getKey(), h.getValue()); } if (message.getContentType() != null) { String ct = message.getContentType(); if (message.getCharacterEncoding() != null) ct = ct + "; charset=" + message.getCharacterEncoding().name(); HttpHeaders.setHeader(res, CONTENT_TYPE, ct); } final boolean sseStarted = message.shouldStartActor(); if (trackSession(sseStarted)) { final String sessionId = UUID.randomUUID().toString(); res.headers() .add(SET_COOKIE, ServerCookieEncoder.STRICT.encode(SESSION_COOKIE_KEY, sessionId)); startSession(sessionId, actorContext); } if (!sseStarted) { final String stringBody = message.getStringBody(); long contentLength = 0L; if (stringBody != null) contentLength = stringBody.getBytes().length; else { final ByteBuffer byteBufferBody = message.getByteBufferBody(); if (byteBufferBody != null) contentLength = byteBufferBody.remaining(); } 
res.headers().add(CONTENT_LENGTH, contentLength); } final HttpStreamActorAdapter httpStreamActorAdapter; if (sseStarted) // This will copy the request content, which must still be referenceable, doing before the // request handler // unallocates it (unfortunately it is explicitly reference-counted in Netty) httpStreamActorAdapter = new HttpStreamActorAdapter(ctx, req); else httpStreamActorAdapter = null; sendHttpResponse(ctx, req, res, false); if (sseStarted) { if (httpResponseEncoderName != null) { ctx.pipeline().remove(httpResponseEncoderName); } else { final ChannelPipeline pl = ctx.pipeline(); final List<String> handlerKeysToBeRemoved = new ArrayList<>(); for (final Map.Entry<String, ChannelHandler> e : pl) { if (e.getValue() instanceof HttpResponseEncoder) handlerKeysToBeRemoved.add(e.getKey()); } for (final String k : handlerKeysToBeRemoved) pl.remove(k); } try { message.getFrom().send(new HttpStreamOpened(httpStreamActorAdapter.ref(), message)); } catch (SuspendExecution e) { throw new AssertionError(e); } } return true; } finally { if (actor != null) actor.unwatch(); } } private io.netty.handler.codec.http.cookie.Cookie getNettyCookie(Cookie c) { io.netty.handler.codec.http.cookie.Cookie ret = new io.netty.handler.codec.http.cookie.DefaultCookie(c.getName(), c.getValue()); ret.setDomain(c.getDomain()); ret.setHttpOnly(c.isHttpOnly()); ret.setMaxAge(c.getMaxAge()); ret.setPath(c.getPath()); ret.setSecure(c.isSecure()); return ret; } @Override public final void close() { if (actor != null) actor.die(null); actorContext = null; } @Override public final void close(Throwable t) { log.error("Exception while closing HTTP adapter", t); if (actor != null) actor.die(t); } } protected static boolean trackSession(boolean sseStarted) { return trackSession != null && ("always".equals(trackSession) || sseStarted && "sse".equals(trackSession)); } protected static boolean handlesWithHttp(String uri, Class<?> actorClass) { return match(uri, actorClass).equals("websocket"); } 
/**
 * Tells whether {@code uri} is routed to the WebSocket side of {@code actorClass}.
 *
 * @param uri the request URI; may be {@code null}
 * @param actorClass the actor class whose URL patterns are consulted; may be {@code null}
 * @return {@code true} when the first pattern matching {@code uri} was registered from the
 *     class's {@code webSocketUrlPatterns} (tagged {@code "ws"})
 */
protected static boolean handlesWithWebSocket(String uri, Class<?> actorClass) {
  return match(uri, actorClass).equals("ws");
}

/**
 * Actor facade for a streaming HTTP response: messages sent to it are handed to an {@link
 * HttpStreamChannelAdapter} that writes them to the channel.
 */
private static final class HttpStreamActorAdapter extends FakeActor<WebDataMessage> {
  /** Guards {@link #die(Throwable)} against re-entry (the mailbox's close calls back into it). */
  private volatile boolean dead;

  public HttpStreamActorAdapter(final ChannelHandlerContext ctx, final FullHttpRequest req) {
    super(req.toString(), new HttpStreamChannelAdapter(ctx, req));
    // Back-link so that closing the channel adapter also terminates this actor.
    ((HttpStreamChannelAdapter) (SendPort) getMailbox()).actor = this;
  }

  /** Terminates on {@code ShutdownMessage}; every lifecycle message is consumed (returns null). */
  @Override
  protected WebDataMessage handleLifecycleMessage(LifecycleMessage m) {
    if (m instanceof ShutdownMessage) {
      die(null);
    }
    return null;
  }

  @Override
  protected void throwIn(RuntimeException e) {
    die(e);
  }

  @Override
  public void interrupt() {
    die(new InterruptedException());
  }

  @Override
  protected void die(Throwable cause) {
    if (dead) return;
    // Set the flag before closing the mailbox: the adapter's close() calls die(null) again.
    this.dead = true;
    mailbox().close();
    super.die(cause);
  }

  @Override
  public String toString() {
    return "HttpStreamActorAdapter{request + " + getName() + "}";
  }
}

/**
 * Send port that writes each {@code WebDataMessage} body straight to the Netty channel as a raw
 * buffer. String bodies are encoded with the character encoding extracted from the request
 * headers.
 */
private static final class HttpStreamChannelAdapter implements SendPort<WebDataMessage> {
  private final Charset encoding;
  private final ChannelHandlerContext ctx;
  HttpStreamActorAdapter actor;

  public HttpStreamChannelAdapter(ChannelHandlerContext ctx, FullHttpRequest req) {
    this.ctx = ctx;
    this.encoding = HttpRequestWrapper.extractCharacterEncodingOrDefault(req.headers());
  }

  @Override
  public final void send(WebDataMessage message) throws SuspendExecution, InterruptedException {
    trySend(message);
  }

  /** The timeout is ignored: the write is handed to the channel without waiting. */
  @Override
  public final boolean send(WebDataMessage message, long timeout, TimeUnit unit)
      throws SuspendExecution, InterruptedException {
    send(message);
    return true;
  }

  @Override
  public final boolean send(WebDataMessage message, Timeout timeout)
      throws SuspendExecution, InterruptedException {
    return send(message, timeout.nanosLeft(), TimeUnit.NANOSECONDS);
  }

  /**
   * Writes and flushes the message body; the string body takes precedence over the byte-buffer
   * body.
   *
   * @return always {@code true}
   */
  @Override
  public final boolean trySend(WebDataMessage res) {
    final ByteBuf buf;
    final String stringBody = res.getStringBody();
    if (stringBody != null) {
      byte[] bs = stringBody.getBytes(encoding);
      buf = Unpooled.wrappedBuffer(bs);
    } else {
      buf = Unpooled.wrappedBuffer(res.getByteBufferBody());
    }
    ctx.writeAndFlush(buf);
    return true;
  }

  /** Closes the channel if it is still open and terminates the linked actor, if any. */
  @Override
  public final void close() {
    if (ctx.channel().isOpen()) ctx.close();
    if (actor != null) actor.die(null);
  }

  /** Terminates the linked actor with {@code t} as cause, then closes the channel. */
  @Override
  public final void close(Throwable t) {
    if (actor != null) actor.die(t);
    close();
  }
}

/** Delegates to {@link #writeHttpResponse}. */
private static void sendHttpResponse(
    ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res, Boolean close) {
  writeHttpResponse(ctx, req, res, close);
}

/** Answers with a {@code FOUND} response whose {@code LOCATION} header is {@code newUri}. */
private static void sendHttpRedirect(
    ChannelHandlerContext ctx, FullHttpRequest req, String newUri) {
  final FullHttpResponse res = new DefaultFullHttpResponse(req.getProtocolVersion(), FOUND);
  HttpHeaders.setHeader(res, LOCATION, newUri);
  writeHttpResponse(ctx, req, res, true);
}

/**
 * Writes {@code res} to the channel, adding a date header unless {@code omitDateHeader} is set
 * or the response already carries one.
 *
 * <p>The connection is closed after the write when the request is not keep-alive, the status
 * code is not 200, or {@code close} is {@code null} or {@code true}; otherwise it is kept alive.
 */
private static void writeHttpResponse(
    ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res, Boolean close) {
  if (!omitDateHeader && !res.headers().contains(DefaultHttpHeaders.Names.DATE))
    DefaultHttpHeaders.addDateHeader(res, DefaultHttpHeaders.Names.DATE, new Date());

  // Send the response and close the connection if necessary.
  if (!HttpHeaders.isKeepAlive(req) || res.getStatus().code() != 200 || close == null || close) {
    res.headers().set(CONNECTION, HttpHeaders.Values.CLOSE);
    ctx.writeAndFlush(res).addListener(ChannelFutureListener.CLOSE);
  } else {
    res.headers().set(CONNECTION, HttpHeaders.Values.KEEP_ALIVE);
    write(ctx, res);
  }
}

/** Writes and flushes {@code res}; the non-flushing {@code ctx.write(res)} variant is unused. */
private static ChannelFuture write(ChannelHandlerContext ctx, Object res) {
  return ctx.writeAndFlush(res);
}

/**
 * Finds the handler tag of the first registered pattern of {@code actorClass} matching {@code
 * uri}.
 *
 * @return {@code "websocket"} for a pattern taken from {@code httpUrlPatterns}, {@code "ws"} for
 *     one taken from {@code webSocketUrlPatterns}, or the empty string when nothing matches or
 *     an argument is {@code null}
 */
private static String match(String uri, Class<?> actorClass) {
  if (uri != null && actorClass != null) {
    for (final Pair<String, String> e : lookupOrInsert(actorClass)) {
      if (servletMatch(e.getFirst(), uri)) return e.getSecond();
    }
  }
  return "";
}

/** Returns the cached pattern list for {@code actorClass}, building it on first use. */
private static List<Pair<String, String>> lookupOrInsert(Class<?> actorClass) {
  if (actorClass != null) {
    final List<Pair<String, String>> lookup = classToUrlPatterns.get(actorClass);
    if (lookup != null) return lookup;
    return insert(actorClass);
  }
  return null;
}

/**
 * Builds the (pattern, tag) list from the {@code WebActor} annotation of {@code actorClass} and
 * stores it in {@code classToUrlPatterns}.
 *
 * @return the new list (empty when the class carries no {@code WebActor} annotation), or {@code
 *     null} when {@code actorClass} is {@code null}
 */
private static List<Pair<String, String>> insert(Class<?> actorClass) {
  if (actorClass != null) {
    final WebActor wa = actorClass.getAnnotation(WebActor.class);
    final List<Pair<String, String>> ret = new ArrayList<>(4);
    // getAnnotation returns null for an unannotated class: register no patterns instead of
    // failing with a NullPointerException.
    if (wa != null) {
      for (String httpP : wa.httpUrlPatterns()) addPattern(ret, httpP, "websocket");
      for (String wsP : wa.webSocketUrlPatterns()) addPattern(ret, wsP, "ws");
    }
    classToUrlPatterns.put(actorClass, ret);
    return ret;
  }
  return null;
}

/**
 * Adds pattern {@code p} tagged with {@code type} to {@code ret}. Exact patterns go to the front
 * and wildcard patterns to the back, so that exact matches are found first by {@link #match}.
 */
private static void addPattern(List<Pair<String, String>> ret, String p, String type) {
  if (p != null) {
    @SuppressWarnings("MismatchedQueryAndUpdateOfCollection")
    final Pair<String, String> entry = new Pair<>(p, type);
    if (p.endsWith("*") || p.startsWith("*.") || p.equals("/")) // Wildcard -> end
      ret.add(entry);
    else // Exact -> beginning
      ret.add(0, entry);
  }
}

/**
 * Matches {@code uri} against a servlet-style URL pattern.
 *
 * <ul>
 *   <li>{@code /prefix/*}: the URI must start with {@code /prefix/}
 *   <li>{@code *.ext}: the URI must end with {@code .ext}
 *   <li>empty pattern: the URI must be exactly {@code /}
 *   <li>{@code /}: matches every URI
 *   <li>anything else: exact comparison
 * </ul>
 *
 * @return {@code false} when either argument is {@code null}
 */
private static boolean servletMatch(String pattern, String uri) {
  // As per servlet spec
  if (pattern != null && uri != null) {
    if (pattern.startsWith("/") && pattern.endsWith("*"))
      return uri.startsWith(pattern.substring(0, pattern.length() - 1));
    // Strip only the '*' so the dot stays part of the suffix: "*.jsp" must match "/a.jsp" but
    // not "/ajsp".
    if (pattern.startsWith("*.")) return uri.endsWith(pattern.substring(1));
    if (pattern.isEmpty()) return uri.equals("/");
    return pattern.equals("/") || pattern.equals(uri);
  }
  return false;
}

/** Registers {@code actorContext} under {@code sessionId} in {@code sessions}. */
private static void startSession(String sessionId, Context actorContext) {
  sessions.put(sessionId, actorContext);
}
}
class EpollSocketTestPermutation extends SocketTestPermutation { static final EpollSocketTestPermutation INSTANCE = new EpollSocketTestPermutation(); static final EventLoopGroup EPOLL_BOSS_GROUP = new EpollEventLoopGroup(BOSSES, new DefaultThreadFactory("testsuite-epoll-boss", true)); static final EventLoopGroup EPOLL_WORKER_GROUP = new EpollEventLoopGroup(WORKERS, new DefaultThreadFactory("testsuite-epoll-worker", true)); private static final InternalLogger logger = InternalLoggerFactory.getInstance(EpollSocketTestPermutation.class); @Override public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> socket() { List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> list = combo(serverSocket(), clientSocket()); list.remove(list.size() - 1); // Exclude NIO x NIO test return list; } @SuppressWarnings("unchecked") @Override public List<BootstrapFactory<ServerBootstrap>> serverSocket() { List<BootstrapFactory<ServerBootstrap>> toReturn = new ArrayList<BootstrapFactory<ServerBootstrap>>(); toReturn.add( new BootstrapFactory<ServerBootstrap>() { @Override public ServerBootstrap newInstance() { return new ServerBootstrap() .group(EPOLL_BOSS_GROUP, EPOLL_WORKER_GROUP) .channel(EpollServerSocketChannel.class); } }); if (isServerFastOpen()) { toReturn.add( new BootstrapFactory<ServerBootstrap>() { @Override public ServerBootstrap newInstance() { ServerBootstrap serverBootstrap = new ServerBootstrap() .group(EPOLL_BOSS_GROUP, EPOLL_WORKER_GROUP) .channel(EpollServerSocketChannel.class); serverBootstrap.option(EpollChannelOption.TCP_FASTOPEN, 5); return serverBootstrap; } }); } toReturn.add( new BootstrapFactory<ServerBootstrap>() { @Override public ServerBootstrap newInstance() { return new ServerBootstrap() .group(nioBossGroup, nioWorkerGroup) .channel(NioServerSocketChannel.class); } }); return toReturn; } @SuppressWarnings("unchecked") @Override public List<BootstrapFactory<Bootstrap>> clientSocket() { return Arrays.asList( 
new BootstrapFactory<Bootstrap>() { @Override public Bootstrap newInstance() { return new Bootstrap().group(EPOLL_WORKER_GROUP).channel(EpollSocketChannel.class); } }, new BootstrapFactory<Bootstrap>() { @Override public Bootstrap newInstance() { return new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class); } }); } @Override public List<TestsuitePermutation.BootstrapComboFactory<Bootstrap, Bootstrap>> datagram() { // Make the list of Bootstrap factories. @SuppressWarnings("unchecked") List<BootstrapFactory<Bootstrap>> bfs = Arrays.asList( new BootstrapFactory<Bootstrap>() { @Override public Bootstrap newInstance() { return new Bootstrap() .group(nioWorkerGroup) .channelFactory( new ChannelFactory<Channel>() { @Override public Channel newChannel() { return new NioDatagramChannel(InternetProtocolFamily.IPv4); } @Override public String toString() { return NioDatagramChannel.class.getSimpleName() + ".class"; } }); } }, new BootstrapFactory<Bootstrap>() { @Override public Bootstrap newInstance() { return new Bootstrap() .group(EPOLL_WORKER_GROUP) .channel(EpollDatagramChannel.class); } }); return combo(bfs, bfs); } public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> domainSocket() { List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> list = combo(serverDomainSocket(), clientDomainSocket()); return list; } public List<BootstrapFactory<ServerBootstrap>> serverDomainSocket() { return Collections.<BootstrapFactory<ServerBootstrap>>singletonList( new BootstrapFactory<ServerBootstrap>() { @Override public ServerBootstrap newInstance() { return new ServerBootstrap() .group(EPOLL_BOSS_GROUP, EPOLL_WORKER_GROUP) .channel(EpollServerDomainSocketChannel.class); } }); } public List<BootstrapFactory<Bootstrap>> clientDomainSocket() { return Collections.<BootstrapFactory<Bootstrap>>singletonList( new BootstrapFactory<Bootstrap>() { @Override public Bootstrap newInstance() { return new Bootstrap() 
.group(EPOLL_WORKER_GROUP) .channel(EpollDomainSocketChannel.class); } }); } public boolean isServerFastOpen() { return AccessController.doPrivileged( new PrivilegedAction<Integer>() { @Override public Integer run() { int fastopen = 0; File file = new File("/proc/sys/net/ipv4/tcp_fastopen"); if (file.exists()) { BufferedReader in = null; try { in = new BufferedReader(new FileReader(file)); fastopen = Integer.parseInt(in.readLine()); if (logger.isDebugEnabled()) { logger.debug("{}: {}", file, fastopen); } } catch (Exception e) { logger.debug("Failed to get TCP_FASTOPEN from: {}", file, e); } finally { if (in != null) { try { in.close(); } catch (Exception e) { // Ignored. } } } } else { if (logger.isDebugEnabled()) { logger.debug("{}: {} (non-existent)", file, fastopen); } } return fastopen; } }) == 3; } public static DomainSocketAddress newSocketAddress() { try { File file = File.createTempFile("netty", "dsocket"); file.delete(); return new DomainSocketAddress(file); } catch (IOException e) { throw new IllegalStateException(e); } } }
/** * Provides the default implementation for processing inbound frame events and delegates to a {@link * Http2FrameListener} * * <p>This class will read HTTP/2 frames and delegate the events to a {@link Http2FrameListener} * * <p>This interface enforces inbound flow control functionality through {@link * Http2LocalFlowController} */ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http2LifecycleManager { private static final InternalLogger logger = InternalLoggerFactory.getInstance(Http2ConnectionHandler.class); private final Http2ConnectionDecoder decoder; private final Http2ConnectionEncoder encoder; private final Http2Settings initialSettings; private ChannelFutureListener closeListener; private BaseDecoder byteDecoder; public Http2ConnectionHandler(boolean server, Http2FrameListener listener) { this(new DefaultHttp2Connection(server), listener); } public Http2ConnectionHandler(Http2Connection connection, Http2FrameListener listener) { this(connection, new DefaultHttp2FrameReader(), new DefaultHttp2FrameWriter(), listener); } public Http2ConnectionHandler( Http2Connection connection, Http2FrameReader frameReader, Http2FrameWriter frameWriter, Http2FrameListener listener) { initialSettings = null; encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader, listener); } /** * Constructor for pre-configured encoder and decoder. Just sets the {@code this} as the {@link * Http2LifecycleManager} and builds them. 
*/ public Http2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder) { this.initialSettings = null; this.decoder = checkNotNull(decoder, "decoder"); this.encoder = checkNotNull(encoder, "encoder"); if (encoder.connection() != decoder.connection()) { throw new IllegalArgumentException( "Encoder and Decoder do not share the same connection object"); } } public Http2ConnectionHandler( Http2Connection connection, Http2FrameListener listener, Http2Settings initialSettings) { this( connection, new DefaultHttp2FrameReader(), new DefaultHttp2FrameWriter(), listener, initialSettings); } public Http2ConnectionHandler( Http2Connection connection, Http2FrameReader frameReader, Http2FrameWriter frameWriter, Http2FrameListener listener, Http2Settings initialSettings) { this.initialSettings = initialSettings; encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader, listener); } public Http2ConnectionHandler( Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings) { this.initialSettings = initialSettings; this.decoder = checkNotNull(decoder, "decoder"); this.encoder = checkNotNull(encoder, "encoder"); if (encoder.connection() != decoder.connection()) { throw new IllegalArgumentException( "Encoder and Decoder do not share the same connection object"); } } public Http2Connection connection() { return encoder.connection(); } public Http2ConnectionDecoder decoder() { return decoder; } public Http2ConnectionEncoder encoder() { return encoder; } private boolean prefaceSent() { return byteDecoder != null && byteDecoder.prefaceSent(); } /** * Handles the client-side (cleartext) upgrade from HTTP to HTTP/2. Reserves local stream 1 for * the HTTP/2 response. 
*/ public void onHttpClientUpgrade() throws Http2Exception { if (connection().isServer()) { throw connectionError(PROTOCOL_ERROR, "Client-side HTTP upgrade requested for a server"); } if (prefaceSent() || decoder.prefaceReceived()) { throw connectionError( PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received"); } // Create a local stream used for the HTTP cleartext upgrade. connection().local().createStream(HTTP_UPGRADE_STREAM_ID, true); } /** * Handles the server-side (cleartext) upgrade from HTTP to HTTP/2. * * @param settings the settings for the remote endpoint. */ public void onHttpServerUpgrade(Http2Settings settings) throws Http2Exception { if (!connection().isServer()) { throw connectionError(PROTOCOL_ERROR, "Server-side HTTP upgrade requested for a client"); } if (prefaceSent() || decoder.prefaceReceived()) { throw connectionError( PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received"); } // Apply the settings but no ACK is necessary. encoder.remoteSettings(settings); // Create a stream in the half-closed state. connection().remote().createStream(HTTP_UPGRADE_STREAM_ID, true); } @Override public void flush(ChannelHandlerContext ctx) throws Http2Exception { // Trigger pending writes in the remote flow controller. connection().remote().flowController().writePendingBytes(); try { super.flush(ctx); } catch (Throwable t) { throw new Http2Exception(INTERNAL_ERROR, "Error flushing", t); } } private abstract class BaseDecoder { public abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception; public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {} public void channelActive(ChannelHandlerContext ctx) throws Exception {} public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { final Http2Connection connection = connection(); // Check if there are streams to avoid the overhead of creating the ChannelFuture. 
if (connection.numActiveStreams() > 0) { final ChannelFuture future = ctx.newSucceededFuture(); connection.forEachActiveStream( new Http2StreamVisitor() { @Override public boolean visit(Http2Stream stream) throws Http2Exception { closeStream(stream, future); return true; } }); } } finally { try { encoder().close(); } finally { decoder().close(); } } } /** Determine if the HTTP/2 connection preface been sent. */ public boolean prefaceSent() { return true; } } private final class PrefaceDecoder extends BaseDecoder { private ByteBuf clientPrefaceString; private boolean prefaceSent; public PrefaceDecoder(ChannelHandlerContext ctx) { clientPrefaceString = clientPrefaceString(encoder.connection()); // This handler was just added to the context. In case it was handled after // the connection became active, send the connection preface now. sendPreface(ctx); } @Override public boolean prefaceSent() { return prefaceSent; } @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { try { if (readClientPrefaceString(in) && verifyFirstFrameIsSettings(in)) { // After the preface is read, it is time to hand over control to the post initialized // decoder. byteDecoder = new FrameDecoder(); byteDecoder.decode(ctx, in, out); } } catch (Throwable e) { onException(ctx, e); } } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { // The channel just became active - send the connection preface to the remote endpoint. sendPreface(ctx); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { cleanup(); super.channelInactive(ctx); } /** Releases the {@code clientPrefaceString}. Any active streams will be left in the open. */ @Override public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { cleanup(); } /** Releases the {@code clientPrefaceString}. Any active streams will be left in the open. 
*/ private void cleanup() { if (clientPrefaceString != null) { clientPrefaceString.release(); clientPrefaceString = null; } } /** * Decodes the client connection preface string from the input buffer. * * @return {@code true} if processing of the client preface string is complete. Since client * preface strings can only be received by servers, returns true immediately for client * endpoints. */ private boolean readClientPrefaceString(ByteBuf in) throws Http2Exception { if (clientPrefaceString == null) { return true; } int prefaceRemaining = clientPrefaceString.readableBytes(); int bytesRead = min(in.readableBytes(), prefaceRemaining); // If the input so far doesn't match the preface, break the connection. if (bytesRead == 0 || !ByteBufUtil.equals( in, in.readerIndex(), clientPrefaceString, clientPrefaceString.readerIndex(), bytesRead)) { String receivedBytes = hexDump( in, in.readerIndex(), min(in.readableBytes(), clientPrefaceString.readableBytes())); throw connectionError( PROTOCOL_ERROR, "HTTP/2 client preface string missing or corrupt. " + "Hex dump for received bytes: %s", receivedBytes); } in.skipBytes(bytesRead); clientPrefaceString.skipBytes(bytesRead); if (!clientPrefaceString.isReadable()) { // Entire preface has been read. clientPrefaceString.release(); clientPrefaceString = null; return true; } return false; } /** * Peeks at that the next frame in the buffer and verifies that it is a {@code SETTINGS} frame. * * @param in the inbound buffer. * @return {@code} true if the next frame is a {@code SETTINGS} frame, {@code false} if more * data is required before we can determine the next frame type. * @throws Http2Exception thrown if the next frame is NOT a {@code SETTINGS} frame. */ private boolean verifyFirstFrameIsSettings(ByteBuf in) throws Http2Exception { if (in.readableBytes() < 4) { // Need more data before we can see the frame type for the first frame. 
return false; } byte frameType = in.getByte(in.readerIndex() + 3); if (frameType != SETTINGS) { throw connectionError( PROTOCOL_ERROR, "First received frame was not SETTINGS. " + "Hex dump for first 4 bytes: %s", hexDump(in, in.readerIndex(), 4)); } return true; } /** * Sends the HTTP/2 connection preface upon establishment of the connection, if not already * sent. */ private void sendPreface(ChannelHandlerContext ctx) { if (prefaceSent || !ctx.channel().isActive()) { return; } prefaceSent = true; if (!connection().isServer()) { // Clients must send the preface string as the first bytes on the connection. ctx.write(connectionPrefaceBuf()).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); } // Both client and server must send their initial settings. encoder .writeSettings(ctx, initialSettings(), ctx.newPromise()) .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); } } private final class FrameDecoder extends BaseDecoder { @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { try { decoder.decodeFrame(ctx, in, out); } catch (Throwable e) { onException(ctx, e); } } } @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { // Initialize the encoder and decoder. 
encoder.lifecycleManager(this); decoder.lifecycleManager(this); byteDecoder = new PrefaceDecoder(ctx); } @Override protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { if (byteDecoder != null) { byteDecoder.handlerRemoved(ctx); byteDecoder = null; } } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { if (byteDecoder == null) { byteDecoder = new PrefaceDecoder(ctx); } byteDecoder.channelActive(ctx); super.channelActive(ctx); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { if (byteDecoder != null) { byteDecoder.channelInactive(ctx); super.channelInactive(ctx); byteDecoder = null; } } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { byteDecoder.decode(ctx, in, out); } @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { // Avoid NotYetConnectedException if (!ctx.channel().isActive()) { ctx.close(promise); return; } ChannelFuture future = goAway(ctx, null); ctx.flush(); // If there are no active streams, close immediately after the send is complete. // Otherwise wait until all streams are inactive. if (isGracefulShutdownComplete()) { future.addListener(new ClosingChannelFutureListener(ctx, promise)); } else { closeListener = new ClosingChannelFutureListener(ctx, promise); } } @Override public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { // Trigger flush after read on the assumption that flush is cheap if there is nothing to write // and that // for flow-control the read may release window that causes data to be written that can now be // flushed. flush(ctx); } /** * Handles {@link Http2Exception} objects that were thrown from other handlers. Ignores all other * exceptions. 
*/
  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    if (getEmbeddedHttp2Exception(cause) == null) {
      // Nothing HTTP/2 specific in the causality chain; let the superclass deal with it.
      super.exceptionCaught(ctx, cause);
      return;
    }
    // Some exception in the causality chain is an Http2Exception - handle it.
    onException(ctx, cause);
  }

  /**
   * Closes the local side of the given stream.
   *
   * <p>A stream in state {@code OPEN} or {@code HALF_CLOSED_LOCAL} only has {@code
   * closeLocalSide()} invoked on it. In every other state the stream is handed to {@link
   * #closeStream(Http2Stream, ChannelFuture)}.
   *
   * @param stream the stream to be half closed.
   * @param future forwarded to {@link #closeStream(Http2Stream, ChannelFuture)} when the stream is
   *     fully closed.
   */
  @Override
  public void closeStreamLocal(Http2Stream stream, ChannelFuture future) {
    switch (stream.state()) {
      case OPEN:
      case HALF_CLOSED_LOCAL:
        stream.closeLocalSide();
        return;
      default:
        closeStream(stream, future);
    }
  }

  /**
   * Closes the remote side of the given stream.
   *
   * <p>A stream in state {@code OPEN} or {@code HALF_CLOSED_REMOTE} only has {@code
   * closeRemoteSide()} invoked on it. In every other state the stream is handed to {@link
   * #closeStream(Http2Stream, ChannelFuture)}.
   *
   * @param stream the stream to be half closed.
   * @param future forwarded to {@link #closeStream(Http2Stream, ChannelFuture)} when the stream is
   *     fully closed.
   */
  @Override
  public void closeStreamRemote(Http2Stream stream, ChannelFuture future) {
    switch (stream.state()) {
      case OPEN:
      case HALF_CLOSED_REMOTE:
        stream.closeRemoteSide();
        return;
      default:
        closeStream(stream, future);
    }
  }

  /**
   * Closes the stream and then runs the connection-close check, either right away (the future is
   * already done) or from a listener once the future completes.
   *
   * @param stream the stream to close.
   * @param future the operation after which the connection-close check is performed.
   */
  @Override
  public void closeStream(final Http2Stream stream, ChannelFuture future) {
    stream.close();

    if (!future.isDone()) {
      future.addListener(
          new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture completed) throws Exception {
              checkCloseConnection(completed);
            }
          });
      return;
    }
    checkCloseConnection(future);
  }

  /**
   * Central handler for all exceptions caught during HTTP/2 processing.
   *
   * <p>Dispatches on the {@link Http2Exception} embedded in {@code cause}: a single stream error
   * and each member of a {@link CompositeStreamException} go to {@link #onStreamError}; anything
   * else (including no embedded exception at all) goes to {@link #onConnectionError}. The context
   * is flushed afterwards in every case.
   */
  @Override
  public void onException(ChannelHandlerContext ctx, Throwable cause) {
    final Http2Exception http2Cause = getEmbeddedHttp2Exception(cause);

    if (isStreamError(http2Cause)) {
      onStreamError(ctx, cause, (StreamException) http2Cause);
    } else if (http2Cause instanceof CompositeStreamException) {
      final CompositeStreamException composite = (CompositeStreamException) http2Cause;
      for (StreamException perStream : composite) {
        onStreamError(ctx, cause, perStream);
      }
    } else {
      onConnectionError(ctx, cause, http2Cause);
    }

    ctx.flush();
  }

  /**
   * Tells whether the connection may be closed as far as graceful shutdown is concerned. It is
   * consulted by the close check before the pending close listener is notified.
   *
   * <p>This implementation only requires that the connection has no active streams. Subclasses may
   * override to provide additional checks.
   *
   * @return {@code true} if the connection currently has zero active streams.
   */
  protected boolean isGracefulShutdownComplete() {
    return 0 == connection().numActiveStreams();
  }

  /**
   * Handler for a connection error. Issues a GO_AWAY for the error and registers a listener that
   * closes the channel once the GO_AWAY future completes.
   *
   * @param ctx the channel context
   * @param cause the exception that was caught
   * @param http2Ex the {@link Http2Exception} that is embedded in the causality chain. When {@code
   *     null}, an {@code INTERNAL_ERROR} exception wrapping {@code cause} is used instead.
   */
  protected void onConnectionError(
      ChannelHandlerContext ctx, Throwable cause, Http2Exception http2Ex) {
    final Http2Exception reported =
        http2Ex != null ? http2Ex : new Http2Exception(INTERNAL_ERROR, cause.getMessage(), cause);

    final ChannelFuture goAwayFuture = goAway(ctx, reported);
    goAwayFuture.addListener(new ClosingChannelFutureListener(ctx, ctx.newPromise()));
  }

  /**
   * Handler for a stream error. Resets the affected stream using the stream id and error code
   * carried by the exception.
   *
   * @param ctx the channel context
   * @param cause the exception that was caught
   * @param http2Ex the {@link StreamException} that is embedded in the causality chain.
   */
  protected void onStreamError(
      ChannelHandlerContext ctx, Throwable cause, StreamException http2Ex) {
    resetStream(ctx, http2Ex.streamId(), http2Ex.error().code(), ctx.newPromise());
  }

  /** Returns the frame writer owned by the encoder. */
  protected Http2FrameWriter frameWriter() {
    return encoder().frameWriter();
  }

  /**
   * Writes a {@code RST_STREAM} frame for the given stream unless the stream is unknown to the
   * connection or has already had a reset sent, in which case the promise is simply completed
   * successfully.
   *
   * <p>When the write succeeds the stream is closed; when it fails the failure is escalated to
   * {@link #onConnectionError}.
   *
   * @param ctx the channel context
   * @param streamId the identifier of the stream to reset
   * @param errorCode the error code to put into the frame
   * @param promise the promise handed to the frame writer
   * @return the future of the frame write, or the succeeded promise if nothing was written
   */
  @Override
  public ChannelFuture resetStream(
      final ChannelHandlerContext ctx,
      int streamId,
      long errorCode,
      final ChannelPromise promise) {
    final Http2Stream stream = connection().stream(streamId);
    if (stream == null || stream.isResetSent()) {
      // Don't write a RST_STREAM frame if we are not aware of the stream, or if we have already
      // written one.
      return promise.setSuccess();
    }

    final ChannelFuture writeFuture =
        frameWriter().writeRstStream(ctx, streamId, errorCode, promise);

    // Synchronously set the resetSent flag to prevent any subsequent calls
    // from resulting in multiple reset frames being sent.
    stream.resetSent();

    writeFuture.addListener(
        new ChannelFutureListener() {
          @Override
          public void operationComplete(ChannelFuture completed) throws Exception {
            if (!completed.isSuccess()) {
              // The connection will be closed and so no need to change the resetSent flag to
              // false.
              onConnectionError(ctx, completed.cause(), null);
              return;
            }
            closeStream(stream, promise);
          }
        });
    return writeFuture;
  }

  /**
   * Records on the connection that a GO_AWAY is being sent and writes the frame.
   *
   * <p>If a GO_AWAY was already sent and {@code lastStreamId} is larger than the last stream known
   * by the peer, the call fails with a {@code PROTOCOL_ERROR} connection error. Any {@link
   * Throwable} raised in here releases {@code debugData} once and is reported through the promise
   * rather than thrown.
   *
   * @param ctx the channel context
   * @param lastStreamId the last stream identifier to advertise
   * @param errorCode the error code to put into the frame
   * @param debugData the debug payload; retained here before the write and released again when the
   *     write result is processed
   * @param promise the promise handed to the frame writer
   * @return the future of the frame write, or the failed promise
   */
  @Override
  public ChannelFuture goAway(
      final ChannelHandlerContext ctx,
      final int lastStreamId,
      final long errorCode,
      final ByteBuf debugData,
      ChannelPromise promise) {
    try {
      final Http2Connection conn = connection();
      if (conn.goAwaySent() && lastStreamId > conn.remote().lastStreamKnownByPeer()) {
        throw connectionError(
            PROTOCOL_ERROR,
            "Last stream identifier must not increase between "
                + "sending multiple GOAWAY frames (was '%d', is '%d').",
            conn.remote().lastStreamKnownByPeer(),
            lastStreamId);
      }
      conn.goAwaySent(lastStreamId, errorCode, debugData);

      // Need to retain before we write the buffer because if we do it after the refCnt could
      // already be 0 and result in an IllegalRefCountException.
      debugData.retain();
      final ChannelFuture writeFuture =
          frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);

      if (!writeFuture.isDone()) {
        writeFuture.addListener(
            new ChannelFutureListener() {
              @Override
              public void operationComplete(ChannelFuture completed) throws Exception {
                processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, completed);
              }
            });
      } else {
        processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, writeFuture);
      }
      return writeFuture;
    } catch (Throwable cause) {
      // Make sure to catch Throwable because we are doing a retain() in this method.
      debugData.release();
      return promise.setFailure(cause);
    }
  }

  /**
   * Notifies the pending close listener, at most once, if one is registered and the graceful
   * shutdown has completed.
   *
   * @param future the future handed to the close listener's {@code operationComplete}.
   * @throws IllegalStateException if the close listener throws an exception.
   */
  private void checkCloseConnection(ChannelFuture future) {
    // Only act if this connection is closing and the graceful shutdown has completed.
    if (closeListener == null || !isGracefulShutdownComplete()) {
      return;
    }

    final ChannelFutureListener pending = closeListener;
    // This method could be called multiple times
    // and we don't want to notify the closeListener multiple times.
    closeListener = null;
    try {
      pending.operationComplete(future);
    } catch (Exception e) {
      throw new IllegalStateException("Close listener threw an unexpected exception", e);
    }
  }

  /**
   * Gets the initial settings to be sent to the remote endpoint: the explicitly configured
   * settings if there are any, otherwise the decoder's local settings.
   */
  private Http2Settings initialSettings() {
    if (initialSettings != null) {
      return initialSettings;
    }
    return decoder.localSettings();
  }

  /**
   * Sends a {@code GO_AWAY} frame derived from the given exception to the remote endpoint, using
   * the last stream created by the remote side as the last known stream. A {@code null} cause is
   * reported with the {@code NO_ERROR} code.
   *
   * <p>This method does <strong>not</strong> call flush itself; that is the responsibility of the
   * caller.
   */
  private ChannelFuture goAway(ChannelHandlerContext ctx, Http2Exception cause) {
    final long code;
    if (cause != null) {
      code = cause.error().code();
    } else {
      code = NO_ERROR.code();
    }
    final ByteBuf debugPayload = Http2CodecUtil.toByteBuf(ctx, cause);
    final int lastCreated = connection().remote().lastStreamCreated();
    return goAway(ctx, lastCreated, code, debugPayload, ctx.newPromise());
  }

  /**
   * Returns the connection preface buffer when the given connection is a server endpoint,
   * otherwise returns {@code null}.
   *
   * <p>NOTE(review): presumably the server keeps this buffer to match it against the preface sent
   * by the client - confirm against the code that reads the result.
   */
  private static ByteBuf clientPrefaceString(Http2Connection connection) {
    if (connection.isServer()) {
      return connectionPrefaceBuf();
    }
    return null;
  }

  /**
   * Reacts to the outcome of a GO_AWAY write and releases the debug data afterwards.
   *
   * <p>A failed write is logged at error level and the channel is closed. A successful write
   * closes the channel (after a debug log) only if the error code is not {@code NO_ERROR}.
   */
  private static void processGoAwayWriteResult(
      final ChannelHandlerContext ctx,
      final int lastStreamId,
      final long errorCode,
      final ByteBuf debugData,
      ChannelFuture future) {
    try {
      if (!future.isSuccess()) {
        if (logger.isErrorEnabled()) {
          logger.error(
              format(
                  "Sending GOAWAY failed: lastStreamId '%d', errorCode '%d', "
                      + "debugData '%s'. Forcing shutdown of the connection.",
                  lastStreamId, errorCode, debugData.toString(UTF_8)),
              future.cause());
        }
        ctx.close();
        return;
      }

      if (errorCode == NO_ERROR.code()) {
        // A GO_AWAY without an error leaves the channel open.
        return;
      }

      if (logger.isDebugEnabled()) {
        logger.debug(
            format(
                "Sent GOAWAY: lastStreamId '%d', errorCode '%d', "
                    + "debugData '%s'. Forcing shutdown of the connection.",
                lastStreamId, errorCode, debugData.toString(UTF_8)),
            future.cause());
      }
      ctx.close();
    } finally {
      // We're done with the debug data now.
      debugData.release();
    }
  }

  /** Closes the channel, using the stored promise, when the observed future completes. */
  private static final class ClosingChannelFutureListener implements ChannelFutureListener {
    private final ChannelHandlerContext ctx;
    private final ChannelPromise promise;

    ClosingChannelFutureListener(ChannelHandlerContext ctx, ChannelPromise promise) {
      this.ctx = ctx;
      this.promise = promise;
    }

    @Override
    public void operationComplete(ChannelFuture goAwayFuture) throws Exception {
      ctx.close(promise);
    }
  }
}
/**
 * Checks that a server built from an {@link OpenSslServerContext} configured with {@code
 * setRejectRemoteInitiatedRenegotiation(true)} fails a renegotiation started by a JDK client.
 *
 * <p>Each server/client context pair is registered as a parameter set several times over. When
 * OpenSSL is unavailable no server context is created and the parameter list is empty.
 */
@RunWith(Parameterized.class)
public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(SocketSslClientRenegotiateTest.class);

  /** How many times every server/client context pair is added to the parameter list. */
  private static final int RUNS_PER_COMBINATION = 32;

  private static final File CERT_FILE;
  private static final File KEY_FILE;

  static {
    // Generate one self-signed certificate that is shared by all parameter sets.
    final SelfSignedCertificate selfSigned;
    try {
      selfSigned = new SelfSignedCertificate();
    } catch (CertificateException e) {
      throw new Error(e);
    }
    CERT_FILE = selfSigned.certificate();
    KEY_FILE = selfSigned.privateKey();
  }

  private final SslContext serverCtx;
  private final SslContext clientCtx;

  private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
  private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();

  private volatile Channel clientChannel;
  private volatile Channel serverChannel;

  private volatile SslHandler clientSslHandler;
  private volatile SslHandler serverSslHandler;

  private final TestHandler clientHandler = new TestHandler(clientException);
  private final TestHandler serverHandler = new TestHandler(serverException);

  public SocketSslClientRenegotiateTest(SslContext serverCtx, SslContext clientCtx) {
    this.serverCtx = serverCtx;
    this.clientCtx = clientCtx;
  }

  /**
   * Builds the parameter sets: every available server context combined with every client context,
   * each pair repeated {@link #RUNS_PER_COMBINATION} times.
   */
  @Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}")
  public static Collection<Object[]> data() throws Exception {
    List<SslContext> servers = new ArrayList<SslContext>();
    List<SslContext> clients = new ArrayList<SslContext>();
    clients.add(new JdkSslClientContext(CERT_FILE));

    if (!OpenSsl.isAvailable()) {
      logger.warn(
          "OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
    } else {
      OpenSslServerContext rejectingServer = new OpenSslServerContext(CERT_FILE, KEY_FILE);
      rejectingServer.setRejectRemoteInitiatedRenegotiation(true);
      servers.add(rejectingServer);
    }

    List<Object[]> combinations = new ArrayList<Object[]>();
    for (SslContext server : servers) {
      for (SslContext client : clients) {
        for (int run = 0; run < RUNS_PER_COMBINATION; run++) {
          combinations.add(new Object[] {server, client});
        }
      }
    }
    return combinations;
  }

  /** Test entry point; skipped when OpenSSL is not available. */
  @Test(timeout = 30000)
  public void testSslRenegotiationRejected() throws Throwable {
    Assume.assumeTrue(OpenSsl.isAvailable());
    run();
  }

  /**
   * Performs the initial handshake, restricts the client engine to a single cipher suite, starts a
   * renegotiation from the client and then closes all channels.
   *
   * <p>The test passes only if the server side recorded a {@link DecoderException} whose cause is
   * an {@link SSLHandshakeException} and the client side recorded no exception.
   */
  public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb) throws Throwable {
    reset();

    sb.childHandler(
        new ChannelInitializer<Channel>() {
          @Override
          @SuppressWarnings("deprecation")
          public void initChannel(Channel accepted) throws Exception {
            serverChannel = accepted;
            serverSslHandler = serverCtx.newHandler(accepted.alloc());

            accepted.pipeline().addLast("ssl", serverSslHandler).addLast("handler", serverHandler);
          }
        });

    cb.handler(
        new ChannelInitializer<Channel>() {
          @Override
          @SuppressWarnings("deprecation")
          public void initChannel(Channel connected) throws Exception {
            clientChannel = connected;
            clientSslHandler = clientCtx.newHandler(connected.alloc());

            connected
                .pipeline()
                .addLast("ssl", clientSslHandler)
                .addLast("handler", clientHandler);
          }
        });

    Channel listener = sb.bind().sync().channel();
    cb.connect().sync();

    // Wait for the first handshake to finish before touching the engine.
    Future<Channel> initialHandshake = clientSslHandler.handshakeFuture();
    initialHandshake.sync();

    String renegotiation = "SSL_RSA_WITH_RC4_128_SHA";
    clientSslHandler.engine().setEnabledCipherSuites(new String[] {renegotiation});
    clientSslHandler.renegotiate().await();

    serverChannel.close().awaitUninterruptibly();
    clientChannel.close().awaitUninterruptibly();
    listener.close().awaitUninterruptibly();

    try {
      Throwable serverFailure = serverException.get();
      if (serverFailure != null) {
        throw serverFailure;
      }
      // The server must have seen the renegotiation fail.
      fail();
    } catch (DecoderException e) {
      assertTrue(e.getCause() instanceof SSLHandshakeException);
    }

    Throwable clientFailure = clientException.get();
    if (clientFailure != null) {
      throw clientFailure;
    }
  }

  /** Clears everything a previous run may have left behind. */
  private void reset() {
    clientSslHandler = null;
    serverSslHandler = null;

    clientChannel = null;
    serverChannel = null;

    clientHandler.handshakeCounter = 0;
    serverHandler.handshakeCounter = 0;

    clientException.set(null);
    serverException.set(null);
  }

  /**
   * Records the first exception seen on a channel and checks the handshake completion events that
   * reach it.
   */
  @Sharable
  private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {

    protected final AtomicReference<Throwable> exception;
    private int handshakeCounter;

    TestHandler(AtomicReference<Throwable> exception) {
      this.exception = exception;
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
      ctx.flush();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
      // Only the first failure is kept.
      exception.compareAndSet(null, cause);
      ctx.close();
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
      if (!(evt instanceof SslHandshakeCompletionEvent)) {
        return;
      }
      SslHandshakeCompletionEvent completion = (SslHandshakeCompletionEvent) evt;

      if (handshakeCounter != 0) {
        // A later completion event: on a channel without a parent it must carry a
        // ClosedChannelException.
        if (ctx.channel().parent() == null) {
          assertTrue(completion.cause() instanceof ClosedChannelException);
        }
        return;
      }

      // The first handshake has to succeed.
      handshakeCounter++;
      if (completion.cause() != null) {
        logger.warn("Handshake failed:", completion.cause());
      }
      assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
    }

    @Override
    public void messageReceived(ChannelHandlerContext ctx, ByteBuf in) throws Exception {}
  }
}
/** * {@link io.netty.channel.sctp.SctpChannel} implementation which use non-blocking mode and allows * to read / write {@link SctpMessage}s to the underlying {@link SctpChannel}. * * <p>Be aware that not all operations systems support SCTP. Please refer to the documentation of * your operation system, to understand what you need to do to use it. Also this feature is only * supported on Java 7+. */ public class NioSctpChannel extends AbstractNioMessageChannel implements io.netty.channel.sctp.SctpChannel { private static final ChannelMetadata METADATA = new ChannelMetadata(false); private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSctpChannel.class); private final SctpChannelConfig config; private final NotificationHandler<?> notificationHandler; private static SctpChannel newSctpChannel() { try { return SctpChannel.open(); } catch (IOException e) { throw new ChannelException("Failed to open a sctp channel.", e); } } /** Create a new instance */ public NioSctpChannel() { this(newSctpChannel()); } /** Create a new instance using {@link SctpChannel} */ public NioSctpChannel(SctpChannel sctpChannel) { this(null, sctpChannel); } /** * Create a new instance * * @param parent the {@link Channel} which is the parent of this {@link NioSctpChannel} or {@code * null}. 
* @param sctpChannel the underlying {@link SctpChannel} */ public NioSctpChannel(Channel parent, SctpChannel sctpChannel) { super(parent, sctpChannel, SelectionKey.OP_READ); try { sctpChannel.configureBlocking(false); config = new NioSctpChannelConfig(this, sctpChannel); notificationHandler = new SctpNotificationHandler(this); } catch (IOException e) { try { sctpChannel.close(); } catch (IOException e2) { if (logger.isWarnEnabled()) { logger.warn("Failed to close a partially initialized sctp channel.", e2); } } throw new ChannelException("Failed to enter non-blocking mode.", e); } } @Override public InetSocketAddress localAddress() { return (InetSocketAddress) super.localAddress(); } @Override public InetSocketAddress remoteAddress() { return (InetSocketAddress) super.remoteAddress(); } @Override public SctpServerChannel parent() { return (SctpServerChannel) super.parent(); } @Override public ChannelMetadata metadata() { return METADATA; } @Override public Association association() { try { return javaChannel().association(); } catch (IOException ignored) { return null; } } @Override public Set<InetSocketAddress> allLocalAddresses() { try { final Set<SocketAddress> allLocalAddresses = javaChannel().getAllLocalAddresses(); final Set<InetSocketAddress> addresses = new LinkedHashSet<InetSocketAddress>(allLocalAddresses.size()); for (SocketAddress socketAddress : allLocalAddresses) { addresses.add((InetSocketAddress) socketAddress); } return addresses; } catch (Throwable ignored) { return Collections.emptySet(); } } @Override public SctpChannelConfig config() { return config; } @Override public Set<InetSocketAddress> allRemoteAddresses() { try { final Set<SocketAddress> allLocalAddresses = javaChannel().getRemoteAddresses(); final Set<InetSocketAddress> addresses = new HashSet<InetSocketAddress>(allLocalAddresses.size()); for (SocketAddress socketAddress : allLocalAddresses) { addresses.add((InetSocketAddress) socketAddress); } return addresses; } catch (Throwable 
ignored) { return Collections.emptySet(); } } @Override protected SctpChannel javaChannel() { return (SctpChannel) super.javaChannel(); } @Override public boolean isActive() { SctpChannel ch = javaChannel(); return ch.isOpen() && association() != null; } @Override protected SocketAddress localAddress0() { try { Iterator<SocketAddress> i = javaChannel().getAllLocalAddresses().iterator(); if (i.hasNext()) { return i.next(); } } catch (IOException e) { // ignore } return null; } @Override protected SocketAddress remoteAddress0() { try { Iterator<SocketAddress> i = javaChannel().getRemoteAddresses().iterator(); if (i.hasNext()) { return i.next(); } } catch (IOException e) { // ignore } return null; } @Override protected void doBind(SocketAddress localAddress) throws Exception { javaChannel().bind(localAddress); } @Override protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { if (localAddress != null) { javaChannel().bind(localAddress); } boolean success = false; try { boolean connected = javaChannel().connect(remoteAddress); if (!connected) { selectionKey().interestOps(SelectionKey.OP_CONNECT); } success = true; return connected; } finally { if (!success) { doClose(); } } } @Override protected void doFinishConnect() throws Exception { if (!javaChannel().finishConnect()) { throw new Error(); } } @Override protected void doDisconnect() throws Exception { doClose(); } @Override protected void doClose() throws Exception { javaChannel().close(); } @Override protected int doReadMessages(List<Object> buf) throws Exception { SctpChannel ch = javaChannel(); RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); ByteBuf buffer = allocHandle.allocate(config().getAllocator()); boolean free = true; try { ByteBuffer data = buffer.internalNioBuffer(buffer.writerIndex(), buffer.writableBytes()); int pos = data.position(); MessageInfo messageInfo = ch.receive(data, null, notificationHandler); if (messageInfo == null) 
{ return 0; } buf.add( new SctpMessage( messageInfo, buffer.writerIndex(buffer.writerIndex() + data.position() - pos))); free = false; return 1; } catch (Throwable cause) { PlatformDependent.throwException(cause); return -1; } finally { int bytesRead = buffer.readableBytes(); allocHandle.record(bytesRead); if (free) { buffer.release(); } } } @Override protected boolean doWriteMessage(Object msg, ChannelOutboundBuffer in) throws Exception { SctpMessage packet = (SctpMessage) msg; ByteBuf data = packet.content(); int dataLen = data.readableBytes(); if (dataLen == 0) { return true; } ByteBufAllocator alloc = alloc(); boolean needsCopy = data.nioBufferCount() != 1; if (!needsCopy) { if (!data.isDirect() && alloc.isDirectBufferPooled()) { needsCopy = true; } } ByteBuffer nioData; if (!needsCopy) { nioData = data.nioBuffer(); } else { data = alloc.directBuffer(dataLen).writeBytes(data); nioData = data.nioBuffer(); } final MessageInfo mi = MessageInfo.createOutgoing(association(), null, packet.streamIdentifier()); mi.payloadProtocolID(packet.protocolIdentifier()); mi.streamNumber(packet.streamIdentifier()); mi.unordered(packet.isUnordered()); final int writtenBytes = javaChannel().send(nioData, mi); return writtenBytes > 0; } @Override protected final Object filterOutboundMessage(Object msg) throws Exception { if (msg instanceof SctpMessage) { SctpMessage m = (SctpMessage) msg; ByteBuf buf = m.content(); if (buf.isDirect() && buf.nioBufferCount() == 1) { return m; } return new SctpMessage( m.protocolIdentifier(), m.streamIdentifier(), m.isUnordered(), newDirectBuffer(m, buf)); } throw new UnsupportedOperationException( "unsupported message type: " + StringUtil.simpleClassName(msg) + " (expected: " + StringUtil.simpleClassName(SctpMessage.class)); } @Override public ChannelFuture bindAddress(InetAddress localAddress) { return bindAddress(localAddress, newPromise()); } @Override public ChannelFuture bindAddress(final InetAddress localAddress, final ChannelPromise promise) { 
if (eventLoop().inEventLoop()) { try { javaChannel().bindAddress(localAddress); promise.setSuccess(); } catch (Throwable t) { promise.setFailure(t); } } else { eventLoop() .execute( new Runnable() { @Override public void run() { bindAddress(localAddress, promise); } }); } return promise; } @Override public ChannelFuture unbindAddress(InetAddress localAddress) { return unbindAddress(localAddress, newPromise()); } @Override public ChannelFuture unbindAddress(final InetAddress localAddress, final ChannelPromise promise) { if (eventLoop().inEventLoop()) { try { javaChannel().unbindAddress(localAddress); promise.setSuccess(); } catch (Throwable t) { promise.setFailure(t); } } else { eventLoop() .execute( new Runnable() { @Override public void run() { unbindAddress(localAddress, promise); } }); } return promise; } private final class NioSctpChannelConfig extends DefaultSctpChannelConfig { private NioSctpChannelConfig(NioSctpChannel channel, SctpChannel javaChannel) { super(channel, javaChannel); } @Override protected void autoReadCleared() { setReadPending(false); } } }