Пример #1
0
/**
 * Handles an incoming PUBLISH message by forwarding it to every channel entity registered for the
 * message's topic.
 */
public class PublishProcesser implements Processer {
  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(PublishProcesser.class);

  /** Shared reply returned when the sending channel has no registered client id. */
  private static final DisconnectMessage DISCONNECT = new DisconnectMessage();

  /**
   * Forwards a PUBLISH message to all channel entities registered for its topic.
   *
   * @param msg the incoming message; it is cast to {@link PublishMessage}
   * @param ctx the context of the channel the message arrived on
   * @return {@code DISCONNECT} when the sending channel has no client id registered in {@code
   *     MemPool}; otherwise {@code null} (no direct reply)
   */
  public Message proc(Message msg, ChannelHandlerContext ctx) {
    String clientId = MemPool.getClientId(ctx.channel());
    if (clientId == null) {
      return DISCONNECT;
    }

    PublishMessage pm = (PublishMessage) msg;
    Set<ChannelEntity> subscribers = MemPool.getChannelByTopics(pm.getTopic());
    if (subscribers == null) {
      return null;
    }

    // Check the level once: without the guard the message (including the decoded payload) was
    // built for every subscriber even when debug logging is disabled.
    boolean debug = logger.isDebugEnabled();
    for (ChannelEntity subscriber : subscribers) {
      if (debug) {
        logger.debug(
            "PUBLISH to ChannelEntity topic = {} payload = {}",
            pm.getTopic(),
            pm.getDataAsString());
      }
      subscriber.write(pm);
    }

    return null;
  }
}
Пример #2
0
 /**
  * System parameter configuration: configures log4j, routes Netty's internal logging through
  * SLF4J, sets Netty system properties and initializes the database helper.
  *
  * @throws Exception if any of the initialization steps fails
  */
 public void initSystem() throws Exception {
   PropertyConfigurator.configure("Log4j.properties");
   // Netty loggers that were already created keep the factory that was the default at that time,
   // so this should run before other Netty classes are initialized.
   InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
   log.info(System.getProperty("file.encoding"));
   String recyclerCapacity = PropertyUtil.getProperty("io.netty.recycler.maxCapacity.default");
   // System.setProperty throws NullPointerException for a null value; only override Netty's
   // default when the property is actually configured.
   if (recyclerCapacity != null) {
     System.setProperty("io.netty.recycler.maxCapacity.default", recyclerCapacity);
   }
   // NOTE(review): "paranoid" leak detection tracks every buffer; confirm this is intended for
   // production use rather than only for testing.
   System.setProperty("io.netty.leakDetectionLevel", "paranoid");
   DbHelper.init();
 }
/**
 * Base class for {@link EventLoopGroup} implementations whose tasks are handled by several threads
 * concurrently.
 */
public abstract class MultithreadEventLoopGroup extends MultithreadEventExecutorGroup
    implements EventLoopGroup {

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(MultithreadEventLoopGroup.class);

  /**
   * Event-loop count used when a constructor receives {@code 0}: the {@code
   * io.netty.eventLoopThreads} system property, defaulting to twice the available processors, and
   * never less than one.
   */
  private static final int DEFAULT_EVENT_LOOP_THREADS;

  static {
    int processorBasedDefault = Runtime.getRuntime().availableProcessors() * 2;
    int configured = SystemPropertyUtil.getInt("io.netty.eventLoopThreads", processorBasedDefault);
    DEFAULT_EVENT_LOOP_THREADS = Math.max(1, configured);

    if (logger.isDebugEnabled()) {
      logger.debug("-Dio.netty.eventLoopThreads: {}", DEFAULT_EVENT_LOOP_THREADS);
    }
  }

  /** Translates the "use the default" sentinel {@code 0} into {@link #DEFAULT_EVENT_LOOP_THREADS}. */
  private static int eventLoopCount(int requested) {
    if (requested == 0) {
      return DEFAULT_EVENT_LOOP_THREADS;
    }
    return requested;
  }

  /**
   * Creates a group backed by the given {@link Executor}.
   *
   * @param nEventLoops number of event loops, or {@code 0} for the default
   * @see MultithreadEventExecutorGroup#MultithreadEventExecutorGroup(int, Executor, Object...)
   */
  protected MultithreadEventLoopGroup(int nEventLoops, Executor executor, Object... args) {
    super(eventLoopCount(nEventLoops), executor, args);
  }

  /**
   * Creates a group whose executor is obtained from the given {@link ExecutorFactory}.
   *
   * @param nEventLoops number of event loops, or {@code 0} for the default
   * @see MultithreadEventExecutorGroup#MultithreadEventExecutorGroup(int, ExecutorFactory,
   *     Object...)
   */
  protected MultithreadEventLoopGroup(
      int nEventLoops, ExecutorFactory executorFactory, Object... args) {
    super(eventLoopCount(nEventLoops), executorFactory, args);
  }

  /** Returns the next child, narrowed to {@link EventLoop}. */
  @Override
  public EventLoop next() {
    EventLoop chosen = (EventLoop) super.next();
    return chosen;
  }

  @Override
  protected abstract EventLoop newChild(Executor executor, Object... args) throws Exception;

  /** Registers {@code channel} with the event loop returned by {@link #next()}. */
  @Override
  public ChannelFuture register(Channel channel) {
    EventLoop chosen = next();
    return chosen.register(channel);
  }

  /**
   * Registers {@code channel} with the event loop returned by {@link #next()}, completing {@code
   * promise}.
   */
  @Override
  public ChannelFuture register(Channel channel, ChannelPromise promise) {
    EventLoop chosen = next();
    return chosen.register(channel, promise);
  }
}
/**
 * Encoder for single (plain-text) responses.
 *
 * @className SingleResponseEncoder
 * @author jy
 * @date 2015-08-02
 * @since JDK 1.6
 * @see com.foxinmy.weixin4j.server.response.SingleResponse
 */
public class SingleResponseEncoder extends MessageToMessageEncoder<SingleResponse> {

  private final InternalLogger logger = InternalLoggerFactory.getInstance(getClass());

  /**
   * Converts a {@link SingleResponse} into a plain-text HTTP response.
   *
   * <p>The HTTP response is added to {@code out} instead of being written directly.
   * {@link MessageToMessageEncoder} writes the contents of {@code out} using the caller's promise,
   * and it raises an {@code EncoderException} (failing that promise) when {@code out} is left
   * empty. Flushing is left to the caller, e.g. via {@code writeAndFlush}.
   *
   * @throws WeixinException if the response content cannot be produced
   */
  @Override
  protected void encode(ChannelHandlerContext ctx, SingleResponse response, List<Object> out)
      throws WeixinException {
    String content = response.toContent();
    out.add(HttpUtil.createHttpResponse(content, ServerToolkits.CONTENTTYPE$TEXT_PLAIN));
    logger.debug("encode single response:{}", content);
  }
}
Пример #5
0
/**
 * Encoder for WeChat (Weixin) responses.
 *
 * <p>Wraps the response body in the WeChat XML reply envelope. When the channel's message transfer
 * reports AES encryption, it encrypts and signs that envelope before sending.
 *
 * @className WeixinResponseEncoder
 * @author jy
 * @date 2014-11-13
 * @since JDK 1.6
 * @see <a href="http://mp.weixin.qq.com/wiki/0/61c3a8b9d50ac74f18bdf2e54ddfc4e0.html">Encrypted
 *     access guide</a>
 * @see com.foxinmy.weixin4j.response.WeixinResponse
 */
public class WeixinResponseEncoder extends MessageToMessageEncoder<WeixinResponse> {

  private final InternalLogger logger = InternalLoggerFactory.getInstance(getClass());

  /**
   * Encodes a {@link WeixinResponse} into an XML HTTP response.
   *
   * <p>The HTTP response is added to {@code out} instead of being written directly.
   * {@link MessageToMessageEncoder} writes the contents of {@code out} using the caller's promise,
   * and it raises an {@code EncoderException} (failing that promise) when {@code out} is left
   * empty. Flushing is left to the caller, e.g. via {@code writeAndFlush}.
   *
   * @throws WeixinException if building or encrypting the reply fails
   */
  @Override
  protected void encode(ChannelHandlerContext ctx, WeixinResponse response, List<Object> out)
      throws WeixinException {
    WeixinMessageTransfer messageTransfer =
        ctx.channel().attr(ServerToolkits.MESSAGE_TRANSFER_KEY).get();
    EncryptType encryptType = messageTransfer.getEncryptType();
    // Read the clock once so CreateTime and the encrypted envelope's TimeStamp agree.
    long nowSeconds = System.currentTimeMillis() / 1000L;
    String content = buildReplyXml(messageTransfer, response, nowSeconds);
    if (encryptType == EncryptType.AES) {
      content =
          encryptReplyXml(messageTransfer.getAesToken(), content, Long.toString(nowSeconds));
    }
    out.add(
        HttpUtil.createHttpResponse(content, OK, ServerToolkits.CONTENTTYPE$APPLICATION_XML));
    logger.info("{} encode weixin response:{}", encryptType, content);
  }

  /** Builds the plain reply envelope; sender and recipient are swapped relative to the request. */
  private static String buildReplyXml(
      WeixinMessageTransfer messageTransfer, WeixinResponse response, long createTime)
      throws WeixinException {
    StringBuilder xml = new StringBuilder();
    xml.append("<xml>");
    xml.append(
        String.format(
            "<ToUserName><![CDATA[%s]]></ToUserName>", messageTransfer.getFromUserName()));
    xml.append(
        String.format(
            "<FromUserName><![CDATA[%s]]></FromUserName>", messageTransfer.getToUserName()));
    xml.append(String.format("<CreateTime><![CDATA[%d]]></CreateTime>", createTime));
    xml.append(String.format("<MsgType><![CDATA[%s]]></MsgType>", response.getMsgType()));
    xml.append(response.toContent());
    xml.append("</xml>");
    return xml.toString();
  }

  /** Encrypts {@code plainXml} and wraps it, with nonce, timestamp and signature, in a new envelope. */
  private static String encryptReplyXml(AesToken aesToken, String plainXml, String timestamp)
      throws WeixinException {
    String nonce = ServerToolkits.generateRandomString(32);
    String encrypted =
        MessageUtil.aesEncrypt(aesToken.getWeixinId(), aesToken.getAesKey(), plainXml);
    String msgSignature = MessageUtil.signature(aesToken.getToken(), nonce, timestamp, encrypted);
    StringBuilder xml = new StringBuilder();
    xml.append("<xml>");
    xml.append(String.format("<Nonce><![CDATA[%s]]></Nonce>", nonce));
    xml.append(String.format("<TimeStamp><![CDATA[%s]]></TimeStamp>", timestamp));
    xml.append(String.format("<MsgSignature><![CDATA[%s]]></MsgSignature>", msgSignature));
    xml.append(String.format("<Encrypt><![CDATA[%s]]></Encrypt>", encrypted));
    xml.append("</xml>");
    return xml.toString();
  }
}
Пример #6
0
public final class ResourceLeakDetector<T> {
  private static final String PROP_LEVEL = "io.netty.leakDetectionLevel";
  private static final Level DEFAULT_LEVEL = Level.SIMPLE;
  private static Level level;

  public static enum Level {
    DISABLED,
    SIMPLE,
    ADVANCED,
    PARANOID;

    private Level() {}
  }

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(ResourceLeakDetector.class);
  private static final int DEFAULT_SAMPLING_INTERVAL = 113;

  static {
    boolean disabled;
    if (SystemPropertyUtil.get("io.netty.noResourceLeakDetection") != null) {
      boolean disabled = SystemPropertyUtil.getBoolean("io.netty.noResourceLeakDetection", false);
      logger.debug("-Dio.netty.noResourceLeakDetection: {}", Boolean.valueOf(disabled));
      logger.warn(
          "-Dio.netty.noResourceLeakDetection is deprecated. Use '-D{}={}' instead.",
          "io.netty.leakDetectionLevel",
          DEFAULT_LEVEL.name().toLowerCase());
    } else {
      disabled = false;
    }
    Level defaultLevel = disabled ? Level.DISABLED : DEFAULT_LEVEL;
    String levelStr =
        SystemPropertyUtil.get("io.netty.leakDetectionLevel", defaultLevel.name())
            .trim()
            .toUpperCase();
    Level level = DEFAULT_LEVEL;
    for (Level l : EnumSet.allOf(Level.class)) {
      if ((levelStr.equals(l.name())) || (levelStr.equals(String.valueOf(l.ordinal())))) {
        level = l;
      }
    }
    level = level;
    if (logger.isDebugEnabled()) {
      logger.debug("-D{}: {}", "io.netty.leakDetectionLevel", level.name().toLowerCase());
    }
  }

  @Deprecated
  public static void setEnabled(boolean enabled) {
    setLevel(enabled ? Level.SIMPLE : Level.DISABLED);
  }

  public static boolean isEnabled() {
    return getLevel().ordinal() > Level.DISABLED.ordinal();
  }

  public static void setLevel(Level level) {
    if (level == null) {
      throw new NullPointerException("level");
    }
    level = level;
  }

  public static Level getLevel() {
    return level;
  }

  private final ResourceLeakDetector<T>.DefaultResourceLeak head = new DefaultResourceLeak(null);
  private final ResourceLeakDetector<T>.DefaultResourceLeak tail = new DefaultResourceLeak(null);
  private final ReferenceQueue<Object> refQueue = new ReferenceQueue();
  private final ConcurrentMap<String, Boolean> reportedLeaks =
      PlatformDependent.newConcurrentHashMap();
  private final String resourceType;
  private final int samplingInterval;
  private final long maxActive;
  private long active;
  private final AtomicBoolean loggedTooManyActive = new AtomicBoolean();
  private long leakCheckCnt;

  public ResourceLeakDetector(Class<?> resourceType) {
    this(StringUtil.simpleClassName(resourceType));
  }

  public ResourceLeakDetector(String resourceType) {
    this(resourceType, 113, Long.MAX_VALUE);
  }

  public ResourceLeakDetector(Class<?> resourceType, int samplingInterval, long maxActive) {
    this(StringUtil.simpleClassName(resourceType), samplingInterval, maxActive);
  }

  public ResourceLeakDetector(String resourceType, int samplingInterval, long maxActive) {
    if (resourceType == null) {
      throw new NullPointerException("resourceType");
    }
    if (samplingInterval <= 0) {
      throw new IllegalArgumentException(
          "samplingInterval: " + samplingInterval + " (expected: 1+)");
    }
    if (maxActive <= 0L) {
      throw new IllegalArgumentException("maxActive: " + maxActive + " (expected: 1+)");
    }
    this.resourceType = resourceType;
    this.samplingInterval = samplingInterval;
    this.maxActive = maxActive;

    this.head.next = this.tail;
    this.tail.prev = this.head;
  }

  public ResourceLeak open(T obj) {
    Level level = level;
    if (level == Level.DISABLED) {
      return null;
    }
    if (level.ordinal() < Level.PARANOID.ordinal()) {
      if (this.leakCheckCnt++ % this.samplingInterval == 0L) {
        reportLeak(level);
        return new DefaultResourceLeak(obj);
      }
      return null;
    }
    reportLeak(level);
    return new DefaultResourceLeak(obj);
  }

  private void reportLeak(Level level) {
    if (!logger.isErrorEnabled()) {
      for (; ; ) {
        ResourceLeakDetector<T>.DefaultResourceLeak ref =
            (DefaultResourceLeak) this.refQueue.poll();
        if (ref == null) {
          break;
        }
        ref.close();
      }
      return;
    }
    int samplingInterval = level == Level.PARANOID ? 1 : this.samplingInterval;
    if ((this.active * samplingInterval > this.maxActive)
        && (this.loggedTooManyActive.compareAndSet(false, true))) {
      logger.error(
          "LEAK: You are creating too many "
              + this.resourceType
              + " instances.  "
              + this.resourceType
              + " is a shared resource that must be reused across the JVM,"
              + "so that only a few instances are created.");
    }
    for (; ; ) {
      ResourceLeakDetector<T>.DefaultResourceLeak ref = (DefaultResourceLeak) this.refQueue.poll();
      if (ref == null) {
        break;
      }
      ref.clear();
      if (ref.close()) {
        String records = ref.toString();
        if (this.reportedLeaks.putIfAbsent(records, Boolean.TRUE) == null) {
          if (records.isEmpty()) {
            logger.error(
                "LEAK: {}.release() was not called before it's garbage-collected. Enable advanced leak reporting to find out where the leak occurred. To enable advanced leak reporting, specify the JVM option '-D{}={}' or call {}.setLevel()",
                new Object[] {
                  this.resourceType,
                  "io.netty.leakDetectionLevel",
                  Level.ADVANCED.name().toLowerCase(),
                  StringUtil.simpleClassName(this)
                });
          } else {
            logger.error(
                "LEAK: {}.release() was not called before it's garbage-collected.{}",
                this.resourceType,
                records);
          }
        }
      }
    }
  }

  private final class DefaultResourceLeak extends PhantomReference<Object> implements ResourceLeak {
    private static final int MAX_RECORDS = 4;
    private final String creationRecord;
    private final Deque<String> lastRecords = new ArrayDeque();
    private final AtomicBoolean freed;
    private ResourceLeakDetector<T>.DefaultResourceLeak prev;
    private ResourceLeakDetector<T>.DefaultResourceLeak next;

    DefaultResourceLeak(Object referent) {
      super(referent != null ? ResourceLeakDetector.this.refQueue : null);
      ResourceLeakDetector.Level level;
      if (referent != null) {
        level = ResourceLeakDetector.getLevel();
        if (level.ordinal() >= ResourceLeakDetector.Level.ADVANCED.ordinal()) {
          this.creationRecord = ResourceLeakDetector.newRecord(3);
        } else {
          this.creationRecord = null;
        }
        synchronized (ResourceLeakDetector.this.head) {
          this.prev = ResourceLeakDetector.this.head;
          this.next = ResourceLeakDetector.this.head.next;
          ResourceLeakDetector.this.head.next.prev = this;
          ResourceLeakDetector.this.head.next = this;
          ResourceLeakDetector.access$408(ResourceLeakDetector.this);
        }
        this.freed = new AtomicBoolean();
      } else {
        this.creationRecord = null;
        this.freed = new AtomicBoolean(true);
      }
    }

    public void record() {
      if (this.creationRecord != null) {
        String value = ResourceLeakDetector.newRecord(2);
        synchronized (this.lastRecords) {
          int size = this.lastRecords.size();
          if ((size == 0) || (!((String) this.lastRecords.getLast()).equals(value))) {
            this.lastRecords.add(value);
          }
          if (size > 4) {
            this.lastRecords.removeFirst();
          }
        }
      }
    }

    public boolean close() {
      if (this.freed.compareAndSet(false, true)) {
        synchronized (ResourceLeakDetector.this.head) {
          ResourceLeakDetector.access$410(ResourceLeakDetector.this);
          this.prev.next = this.next;
          this.next.prev = this.prev;
          this.prev = null;
          this.next = null;
        }
        return true;
      }
      return false;
    }

    public String toString() {
      if (this.creationRecord == null) {
        return "";
      }
      Object[] array;
      synchronized (this.lastRecords) {
        array = this.lastRecords.toArray();
      }
      StringBuilder buf = new StringBuilder(16384);
      buf.append(StringUtil.NEWLINE);
      buf.append("Recent access records: ");
      buf.append(array.length);
      buf.append(StringUtil.NEWLINE);
      if (array.length > 0) {
        for (int i = array.length - 1; i >= 0; i--) {
          buf.append('#');
          buf.append(i + 1);
          buf.append(':');
          buf.append(StringUtil.NEWLINE);
          buf.append(array[i]);
        }
      }
      buf.append("Created at:");
      buf.append(StringUtil.NEWLINE);
      buf.append(this.creationRecord);
      buf.setLength(buf.length() - StringUtil.NEWLINE.length());

      return buf.toString();
    }
  }

  private static final String[] STACK_TRACE_ELEMENT_EXCLUSIONS = {
    "io.netty.buffer.AbstractByteBufAllocator.toLeakAwareBuffer("
  };

  static String newRecord(int recordsToSkip) {
    StringBuilder buf = new StringBuilder(4096);
    StackTraceElement[] array = new Throwable().getStackTrace();
    for (StackTraceElement e : array) {
      if (recordsToSkip > 0) {
        recordsToSkip--;
      } else {
        String estr = e.toString();

        boolean excluded = false;
        for (String exclusion : STACK_TRACE_ELEMENT_EXCLUSIONS) {
          if (estr.startsWith(exclusion)) {
            excluded = true;
            break;
          }
        }
        if (!excluded) {
          buf.append('\t');
          buf.append(estr);
          buf.append(StringUtil.NEWLINE);
        }
      }
    }
    return buf.toString();
  }
}
 /** Logs any exception reaching this handler as a warning; it is not passed further along. */
 @Override
 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
   InternalLogger testLogger =
       InternalLoggerFactory.getInstance(SocketConnectionAttemptTest.class);
   testLogger.warn("Unexpected exception:", cause);
 }
 static {
   // Echo the BAD_HOST / BAD_PORT values at debug level, labelled with the -D property names
   // (presumably the properties they are read from — confirm where they are defined).
   InternalLogger testLogger =
       InternalLoggerFactory.getInstance(SocketConnectionAttemptTest.class);
   testLogger.debug("-Dio.netty.testsuite.badHost: {}", BAD_HOST);
   testLogger.debug("-Dio.netty.testsuite.badPort: {}", BAD_PORT);
 }
/** A {@link SocketChannel} which is using Old-Blocking-IO */
public class OioSocketChannel extends OioByteStreamChannel implements SocketChannel {

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(OioSocketChannel.class);

  private final Socket socket;
  private final OioSocketChannelConfig config;

  /** Create a new instance with an new {@link Socket} */
  public OioSocketChannel(EventLoop eventLoop) {
    this(eventLoop, new Socket());
  }

  /**
   * Create a new instance from the given {@link Socket}
   *
   * @param socket the {@link Socket} which is used by this instance
   */
  public OioSocketChannel(EventLoop eventLoop, Socket socket) {
    this(null, eventLoop, socket);
  }

  /**
   * Create a new instance from the given {@link Socket}
   *
   * @param parent the parent {@link Channel} which was used to create this instance. This can be
   *     null if the {@link} has no parent as it was created by your self.
   * @param socket the {@link Socket} which is used by this instance
   */
  public OioSocketChannel(Channel parent, EventLoop eventLoop, Socket socket) {
    super(parent, eventLoop);
    this.socket = socket;
    config = new DefaultOioSocketChannelConfig(this, socket);

    boolean success = false;
    try {
      if (socket.isConnected()) {
        activate(socket.getInputStream(), socket.getOutputStream());
      }
      socket.setSoTimeout(SO_TIMEOUT);
      success = true;
    } catch (Exception e) {
      throw new ChannelException("failed to initialize a socket", e);
    } finally {
      if (!success) {
        try {
          socket.close();
        } catch (IOException e) {
          logger.warn("Failed to close a socket.", e);
        }
      }
    }
  }

  @Override
  public ServerSocketChannel parent() {
    return (ServerSocketChannel) super.parent();
  }

  @Override
  public OioSocketChannelConfig config() {
    return config;
  }

  @Override
  public boolean isOpen() {
    return !socket.isClosed();
  }

  @Override
  public boolean isActive() {
    return !socket.isClosed() && socket.isConnected();
  }

  @Override
  public boolean isInputShutdown() {
    return super.isInputShutdown();
  }

  @Override
  public boolean isOutputShutdown() {
    return socket.isOutputShutdown() || !isActive();
  }

  @Override
  public ChannelFuture shutdownOutput() {
    return shutdownOutput(newPromise());
  }

  @Override
  protected int doReadBytes(ByteBuf buf) throws Exception {
    if (socket.isClosed()) {
      return -1;
    }
    try {
      return super.doReadBytes(buf);
    } catch (SocketTimeoutException e) {
      return 0;
    }
  }

  @Override
  public ChannelFuture shutdownOutput(final ChannelPromise future) {
    EventLoop loop = eventLoop();
    if (loop.inEventLoop()) {
      try {
        socket.shutdownOutput();
        future.setSuccess();
      } catch (Throwable t) {
        future.setFailure(t);
      }
    } else {
      loop.execute(
          new Runnable() {
            @Override
            public void run() {
              shutdownOutput(future);
            }
          });
    }
    return future;
  }

  @Override
  public InetSocketAddress localAddress() {
    return (InetSocketAddress) super.localAddress();
  }

  @Override
  public InetSocketAddress remoteAddress() {
    return (InetSocketAddress) super.remoteAddress();
  }

  @Override
  protected SocketAddress localAddress0() {
    return socket.getLocalSocketAddress();
  }

  @Override
  protected SocketAddress remoteAddress0() {
    return socket.getRemoteSocketAddress();
  }

  @Override
  protected void doBind(SocketAddress localAddress) throws Exception {
    socket.bind(localAddress);
  }

  @Override
  protected void doConnect(SocketAddress remoteAddress, SocketAddress localAddress)
      throws Exception {
    if (localAddress != null) {
      socket.bind(localAddress);
    }

    boolean success = false;
    try {
      socket.connect(remoteAddress, config().getConnectTimeoutMillis());
      activate(socket.getInputStream(), socket.getOutputStream());
      success = true;
    } catch (SocketTimeoutException e) {
      ConnectTimeoutException cause =
          new ConnectTimeoutException("connection timed out: " + remoteAddress);
      cause.setStackTrace(e.getStackTrace());
      throw cause;
    } finally {
      if (!success) {
        doClose();
      }
    }
  }

  @Override
  protected void doDisconnect() throws Exception {
    doClose();
  }

  @Override
  protected void doClose() throws Exception {
    socket.close();
  }

  @Override
  protected boolean checkInputShutdown() {
    if (isInputShutdown()) {
      try {
        Thread.sleep(config().getSoTimeout());
      } catch (Throwable e) {
        // ignore
      }
      return true;
    }
    return false;
  }
}
 static {
   // Install the SLF4J factory first. ResourceLeakDetector creates its static logger from the
   // default factory when the class is initialized, and the setLevel call below triggers that
   // initialization. Setting the factory afterwards would leave leak reports on the previous
   // default factory.
   InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
   ResourceLeakDetector.setLevel(Level.PARANOID);
   // TODO: need to create a test logger that searches the log output for leaks
   // and fail the test and kill the world
 }
Пример #11
0
/**
 * Adds <a href="http://en.wikipedia.org/wiki/Transport_Layer_Security">SSL &middot; TLS</a> and
 * StartTLS support to a {@link Channel}. Please refer to the <strong>"SecureChat"</strong> example
 * in the distribution or the web site for the detailed usage.
 *
 * <h3>Beginning the handshake</h3>
 *
 * <p>You must make sure not to write a message while the handshake is in progress unless you are
 * renegotiating. You will be notified by the {@link Future} which is returned by the {@link
 * #handshakeFuture()} method when the handshake process succeeds or fails.
 *
 * <p>Besides using the handshake {@link ChannelFuture} to get notified about the completion of the
 * handshake, it is also possible to detect it by implementing the {@link
 * ChannelHandler#userEventTriggered(ChannelHandlerContext, Object)} method and checking for a {@link
 * SslHandshakeCompletionEvent}.
 *
 * <h3>Handshake</h3>
 *
 * <p>The handshake will be issued automatically once the {@link Channel} is active and {@link
 * SSLEngine#getUseClientMode()} returns {@code true}, so there is no need to bother with it yourself.
 *
 * <h3>Closing the session</h3>
 *
 * <p>To close the SSL session, the {@link #close()} method should be called to send the {@code
 * close_notify} message to the remote peer. One exception is when you close the {@link Channel} -
 * {@link SslHandler} intercepts the close request and sends the {@code close_notify} message before
 * the channel closure automatically. Once the SSL session is closed, it is not reusable, and
 * consequently you should create a new {@link SslHandler} with a new {@link SSLEngine} as explained
 * in the following section.
 *
 * <h3>Restarting the session</h3>
 *
 * <p>To restart the SSL session, you must remove the existing closed {@link SslHandler} from the
 * {@link ChannelPipeline}, insert a new {@link SslHandler} with a new {@link SSLEngine} into the
 * pipeline, and start the handshake process as described in the first section.
 *
 * <h3>Implementing StartTLS</h3>
 *
 * <p><a href="http://en.wikipedia.org/wiki/STARTTLS">StartTLS</a> is the communication pattern that
 * secures the wire in the middle of the plaintext connection. Please note that it is different from
 * SSL &middot; TLS, that secures the wire from the beginning of the connection. Typically, StartTLS
 * is composed of three steps:
 *
 * <ol>
 *   <li>Client sends a StartTLS request to server.
 *   <li>Server sends a StartTLS response to client.
 *   <li>Client begins SSL handshake.
 * </ol>
 *
 * If you implement a server, you need to:
 *
 * <ol>
 *   <li>create a new {@link SslHandler} instance with {@code startTls} flag set to {@code true},
 *   <li>insert the {@link SslHandler} to the {@link ChannelPipeline}, and
 *   <li>write a StartTLS response.
 * </ol>
 *
 * Please note that you must insert {@link SslHandler} <em>before</em> sending the StartTLS
 * response. Otherwise the client can begin the SSL handshake before {@link SslHandler} is inserted
 * into the {@link ChannelPipeline}, causing data corruption.
 *
 * <p>The client-side implementation is much simpler.
 *
 * <ol>
 *   <li>Write a StartTLS request,
 *   <li>wait for the StartTLS response,
 *   <li>create a new {@link SslHandler} instance with {@code startTls} flag set to {@code false},
 *   <li>insert the {@link SslHandler} to the {@link ChannelPipeline}, and
 *   <li>Initiate SSL handshake.
 * </ol>
 *
 * <h3>Known issues</h3>
 *
 * <p>Because of a known issue with the current implementation of the {@link SSLEngine} that comes
 * with Java, you may see blocked I/O threads while a full GC is in progress.
 *
 * <p>If you are affected, you can work around this problem by adjusting the cache settings as shown
 * below:
 *
 * <pre>
 *     SslContext context = ...;
 *     context.getServerSessionContext().setSessionCacheSize(someSaneSize);
 *     context.getServerSessionContext().setSessionTimeout(someSaneTimeout);
 * </pre>
 *
 * <p>What values to use here depends on the nature of your application and should be set based on
 * monitoring and debugging of it. For more details see <a
 * href="https://github.com/netty/netty/issues/832">#832</a> in our issue tracker.
 */
public class SslHandler extends ByteToMessageDecoder {

  private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslHandler.class);

  // Matches class names of stack-trace frames that identify a transport-level channel read; used
  // by ignoreException(...) to recognise harmless connection-reset errors.
  private static final Pattern IGNORABLE_CLASS_IN_STACK =
      Pattern.compile("^.*(?:Socket|Datagram|Sctp|Udt)Channel.*$");
  // Case-insensitive match for I/O error messages reporting a reset/closed/aborted/broken
  // connection or a broken pipe; the fast path of ignoreException(...).
  private static final Pattern IGNORABLE_ERROR_MESSAGE =
      Pattern.compile(
          "^.*(?:connection.*(?:reset|closed|abort|broken)|broken.*pipe).*$",
          Pattern.CASE_INSENSITIVE);

  /**
   * Shared failure used to fail all pending writes once the {@link SSLEngine} has reported {@code
   * CLOSED} from a wrap operation.
   */
  private static final SSLException SSLENGINE_CLOSED = new SSLException("SSLEngine closed already");

  // Shared, pre-allocated failure instances (see the static initializer below).
  private static final SSLException HANDSHAKE_TIMED_OUT = new SSLException("handshake timed out");
  // Used by channelInactive(...) to fail the handshake when the channel closes.
  private static final ClosedChannelException CHANNEL_CLOSED = new ClosedChannelException();

  static {
    // These exceptions are created once here and reused, so the stack trace captured at creation
    // would not reflect where they are actually raised; clear it.
    SSLENGINE_CLOSED.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
    HANDSHAKE_TIMED_OUT.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
    CHANNEL_CLOSED.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
  }

  // Context of the pipeline this handler belongs to; read by close() and setHandshakeSuccess().
  // NOTE(review): its assignment is not in this section — presumably handlerAdded(); confirm.
  private volatile ChannelHandlerContext ctx;
  // The engine that performs the actual encryption/decryption; never null (checked in the ctor).
  private final SSLEngine engine;
  // engine.getSession().getPacketBufferSize() captured at construction; wrap(...) grows its output
  // buffer by this amount on BUFFER_OVERFLOW.
  private final int maxPacketBufferSize;

  /**
   * Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} and {@link
   * SSLEngine#unwrap(ByteBuffer, ByteBuffer[])} should be called with a {@link ByteBuf} that is
   * only backed by one {@link ByteBuffer} to reduce the object creation.
   */
  private final ByteBuffer[] singleBuffer = new ByteBuffer[1];

  // BEGIN Platform-dependent flags

  /** {@code true} if and only if {@link SSLEngine} expects a direct buffer. */
  private final boolean wantsDirectBuffer;
  /**
   * {@code true} if and only if {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} requires the output
   * buffer to be always as large as {@link #maxPacketBufferSize} even if the input buffer contains
   * small amount of data.
   *
   * <p>If this flag is {@code false}, we allocate a smaller output buffer.
   */
  private final boolean wantsLargeOutboundNetworkBuffer;
  /**
   * {@code true} if and only if {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} expects a heap
   * buffer rather than a direct buffer. For an unknown reason, JDK8 SSLEngine causes JVM to crash
   * when its cipher suite uses Galois Counter Mode (GCM). Not final: it is switched on once the
   * negotiated cipher suite is known.
   */
  private boolean wantsInboundHeapBuffer;

  // END Platform-dependent flags

  // When true, everything queued before the first flush is sent unencrypted (StartTLS).
  private final boolean startTls;
  // Set once that first, unencrypted StartTLS flush has happened.
  private boolean sentFirstMessage;
  // Set when flush() is called before the handshake completed, so unwrap(...) knows to wrap the
  // queued application data once the engine reports NOT_HANDSHAKING.
  private boolean flushedBeforeHandshake;
  // Set when read() is requested while the handshake is still in progress.
  private boolean readDuringHandshake;
  // Plaintext writes queued by write(...) until flush() encrypts them.
  // NOTE(review): created outside this section — presumably in handlerAdded(); confirm.
  private PendingWriteQueue pendingUnencryptedWrites;

  // Completed when the current handshake finishes. Not final: handshakeFuture() documents that a
  // renegotiation yields a new future.
  private Promise<Channel> handshakePromise = new LazyChannelPromise();
  // Completed when the inbound side of the engine is closed (e.g. unwrap reports CLOSED).
  private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();

  /**
   * Set by wrap*() methods when something is produced. {@link
   * #channelReadComplete(ChannelHandlerContext)} will check this flag, clear it, and call
   * ctx.flush().
   */
  private boolean needsFlush;

  // Length of an SSL record whose bytes have not fully arrived yet (0 if none); lets decode(...)
  // return early until the whole record is buffered.
  private int packetLength;

  /**
   * This flag is used to determine if we need to call {@link ChannelHandlerContext#read()} to
   * consume more data when {@link ChannelConfig#isAutoRead()} is {@code false}.
   */
  private boolean firedChannelRead;

  // Timeouts in milliseconds; volatile so values written by the setters are visible to other
  // threads.
  private volatile long handshakeTimeoutMillis = 10000;
  private volatile long closeNotifyTimeoutMillis = 3000;

  /**
   * Creates a new handler that encrypts every outbound message (StartTLS disabled).
   *
   * @param engine the {@link SSLEngine} used for encryption and decryption; must not be {@code
   *     null}
   */
  public SslHandler(SSLEngine engine) {
    this(engine, /* startTls= */ false);
  }

  /**
   * Creates a new handler.
   *
   * @param engine the {@link SSLEngine} used for encryption and decryption; must not be {@code
   *     null}
   * @param startTls {@code true} if whatever is queued before the first flush should be written
   *     as-is, without being encrypted by the {@link SSLEngine}
   * @throws NullPointerException if {@code engine} is {@code null}
   */
  public SslHandler(SSLEngine engine, boolean startTls) {
    if (engine == null) {
      throw new NullPointerException("engine");
    }
    this.engine = engine;
    this.startTls = startTls;
    this.maxPacketBufferSize = engine.getSession().getPacketBufferSize();

    // OpenSslEngine works with direct buffers and does not need full-size outbound buffers; every
    // other engine gets heap-friendly input and maximum-size outbound buffers.
    final boolean usingOpenSsl = engine instanceof OpenSslEngine;
    this.wantsDirectBuffer = usingOpenSsl;
    this.wantsLargeOutboundNetworkBuffer = !usingOpenSsl;

    // The JDK SSLEngine unwraps from a single source ByteBuffer, so inbound bytes are merged into
    // one contiguous buffer (MERGE_CUMULATOR). OpenSslEngine offers
    // unwrap(ByteBuffer[], ByteBuffer[]), which consumes the components of a composite buffer
    // directly without extra memory copies (COMPOSITE_CUMULATOR).
    setCumulator(usingOpenSsl ? COMPOSITE_CUMULATOR : MERGE_CUMULATOR);
  }

  /** Returns the handshake timeout, in milliseconds. */
  public long getHandshakeTimeoutMillis() {
    return this.handshakeTimeoutMillis;
  }

  /**
   * Sets the handshake timeout.
   *
   * @param handshakeTimeout the timeout, expressed in {@code unit}; must not be negative
   * @param unit the unit of {@code handshakeTimeout}; must not be {@code null}
   * @throws NullPointerException if {@code unit} is {@code null}
   * @throws IllegalArgumentException if the timeout converted to milliseconds is negative
   */
  public void setHandshakeTimeout(long handshakeTimeout, TimeUnit unit) {
    if (unit == null) {
      throw new NullPointerException("unit");
    }
    final long timeoutMillis = unit.toMillis(handshakeTimeout);
    setHandshakeTimeoutMillis(timeoutMillis);
  }

  /**
   * Sets the handshake timeout in milliseconds.
   *
   * @param handshakeTimeoutMillis the new timeout; must not be negative
   * @throws IllegalArgumentException if {@code handshakeTimeoutMillis} is negative
   */
  public void setHandshakeTimeoutMillis(long handshakeTimeoutMillis) {
    if (handshakeTimeoutMillis >= 0) {
      this.handshakeTimeoutMillis = handshakeTimeoutMillis;
      return;
    }
    throw new IllegalArgumentException(
        "handshakeTimeoutMillis: " + handshakeTimeoutMillis + " (expected: >= 0)");
  }

  /** Returns the {@code close_notify} timeout, in milliseconds. */
  public long getCloseNotifyTimeoutMillis() {
    return this.closeNotifyTimeoutMillis;
  }

  /**
   * Sets the {@code close_notify} timeout.
   *
   * @param closeNotifyTimeout the timeout, expressed in {@code unit}; must not be negative
   * @param unit the unit of {@code closeNotifyTimeout}; must not be {@code null}
   * @throws NullPointerException if {@code unit} is {@code null}
   * @throws IllegalArgumentException if the timeout converted to milliseconds is negative
   */
  public void setCloseNotifyTimeout(long closeNotifyTimeout, TimeUnit unit) {
    if (unit == null) {
      throw new NullPointerException("unit");
    }
    final long timeoutMillis = unit.toMillis(closeNotifyTimeout);
    setCloseNotifyTimeoutMillis(timeoutMillis);
  }

  /**
   * Sets the {@code close_notify} timeout in milliseconds.
   *
   * @param closeNotifyTimeoutMillis the new timeout; must not be negative
   * @throws IllegalArgumentException if {@code closeNotifyTimeoutMillis} is negative
   */
  public void setCloseNotifyTimeoutMillis(long closeNotifyTimeoutMillis) {
    if (closeNotifyTimeoutMillis >= 0) {
      this.closeNotifyTimeoutMillis = closeNotifyTimeoutMillis;
      return;
    }
    throw new IllegalArgumentException(
        "closeNotifyTimeoutMillis: " + closeNotifyTimeoutMillis + " (expected: >= 0)");
  }

  /**
   * Returns the {@link SSLEngine} this handler wraps outbound and unwraps inbound data with.
   *
   * @return the engine passed to the constructor; never {@code null}
   */
  public SSLEngine engine() {
    return this.engine;
  }

  /**
   * Returns the name of the application-level protocol negotiated for the current session.
   *
   * @return the protocol name, or {@code null} if the engine's {@link SSLSession} does not expose
   *     one (it is not an {@code ApplicationProtocolAccessor}) or no protocol has been negotiated
   */
  public String applicationProtocol() {
    final SSLSession session = engine().getSession();
    return session instanceof ApplicationProtocolAccessor
        ? ((ApplicationProtocolAccessor) session).getApplicationProtocol()
        : null;
  }

  /**
   * Returns a {@link Future} that is notified once the current TLS handshake completes.
   *
   * @return the {@link Future} for the initial TLS handshake if {@link #renegotiate()} was never
   *     invoked; otherwise the {@link Future} for the most recent {@linkplain #renegotiate() TLS
   *     renegotiation}
   */
  public Future<Channel> handshakeFuture() {
    return this.handshakePromise;
  }

  /**
   * Sends an SSL {@code close_notify} message to the channel and closes the outbound side of the
   * underlying {@link SSLEngine}, using a new promise obtained from the bound context.
   *
   * @return a future notified when the {@code close_notify} write completes or fails
   */
  public ChannelFuture close() {
    final ChannelPromise promise = ctx.newPromise();
    return close(promise);
  }

  /**
   * Same as {@link #close()}, but notifies the given promise.
   *
   * <p>The work runs on the context's executor: the engine's outbound side is closed, an empty
   * buffer is written with {@code future} and the handler is flushed so the engine's {@code
   * close_notify} gets wrapped and sent. If that throws and {@code future} is already done, the
   * exception is logged instead of being lost.
   *
   * @param future the promise to notify
   * @return {@code future}
   */
  public ChannelFuture close(final ChannelPromise future) {
    final ChannelHandlerContext ctx = this.ctx;
    final Runnable closeTask =
        new Runnable() {
          @Override
          public void run() {
            engine.closeOutbound();
            try {
              write(ctx, Unpooled.EMPTY_BUFFER, future);
              flush(ctx);
            } catch (Exception e) {
              if (!future.tryFailure(e)) {
                logger.warn("{} flush() raised a masked exception.", ctx.channel(), e);
              }
            }
          }
        };
    ctx.executor().execute(closeTask);
    return future;
  }

  /**
   * Returns the {@link Future} that is notified when the inbound side of the {@link SSLEngine} is
   * closed, for example when the engine reports {@code CLOSED} while unwrapping.
   *
   * <p>Every call returns the same {@link Future} instance.
   *
   * @see SSLEngine
   */
  public Future<Channel> sslCloseFuture() {
    return this.sslCloseFuture;
  }

  /** Fails every write still queued for encryption when this handler is removed. */
  @Override
  public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
    // Bail out early: building a ChannelException is comparatively expensive.
    if (pendingUnencryptedWrites.isEmpty()) {
      return;
    }
    final ChannelException cause = new ChannelException("Pending write on removal of SslHandler");
    pendingUnencryptedWrites.removeAndFailAll(cause);
  }

  /**
   * Delegates to {@code closeOutboundAndChannel(...)} with its flag set to {@code true}
   * (presumably selecting a disconnect rather than a close of the channel).
   */
  @Override
  public void disconnect(final ChannelHandlerContext ctx, final ChannelPromise promise)
      throws Exception {
    final boolean disconnectRequested = true;
    closeOutboundAndChannel(ctx, promise, disconnectRequested);
  }

  /**
   * Delegates to {@code closeOutboundAndChannel(...)} with its flag set to {@code false}
   * (presumably selecting a close rather than a disconnect of the channel).
   */
  @Override
  public void close(final ChannelHandlerContext ctx, final ChannelPromise promise)
      throws Exception {
    final boolean disconnectRequested = false;
    closeOutboundAndChannel(ctx, promise, disconnectRequested);
  }

  /**
   * Records whether a read was requested while the handshake was still in progress, then forwards
   * the read request.
   */
  @Override
  public void read(ChannelHandlerContext ctx) throws Exception {
    final boolean stillHandshaking = !handshakePromise.isDone();
    if (stillHandshaking) {
      readDuringHandshake = true;
    }
    ctx.read();
  }

  /**
   * Queues a plaintext {@link ByteBuf} for encryption on the next flush; any other message type
   * fails its promise with an {@link UnsupportedMessageTypeException}.
   */
  @Override
  public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
      throws Exception {
    if (msg instanceof ByteBuf) {
      pendingUnencryptedWrites.add(msg, promise);
    } else {
      promise.setFailure(new UnsupportedMessageTypeException(msg, ByteBuf.class));
    }
  }

  /**
   * Encrypts and flushes all queued writes.
   *
   * <p>In StartTLS mode, everything queued before the first flush is written unencrypted instead.
   * If nothing is queued, an empty buffer is queued so that wrapping still takes place (for
   * example to drive the handshake).
   */
  @Override
  public void flush(ChannelHandlerContext ctx) throws Exception {
    if (startTls && !sentFirstMessage) {
      // StartTLS: pass the first flushed batch through without encrypting it.
      sentFirstMessage = true;
      pendingUnencryptedWrites.removeAndWriteAll();
      ctx.flush();
    } else {
      if (pendingUnencryptedWrites.isEmpty()) {
        // A real promise (not a voidPromise) is required here because the user may add a
        // ChannelFutureListener to it later. See https://github.com/netty/netty/issues/3364
        pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER, ctx.newPromise());
      }
      if (!handshakePromise.isDone()) {
        // Remember to wrap the queued data once the handshake completes.
        flushedBeforeHandshake = true;
      }
      wrap(ctx, false);
      ctx.flush();
    }
  }

  /**
   * Encrypts the queued plaintext writes and writes the resulting network data to {@code ctx}
   * without flushing it.
   *
   * <p>Each queued buffer is wrapped into {@code out}. A write is removed from the queue, and its
   * promise handed to {@link #finishWrap}, only once its buffer has been fully consumed. The loop
   * stops in three cases:
   *
   * <ul>
   *   <li>the queue is empty;
   *   <li>the engine needs to unwrap first;
   *   <li>the engine reports {@code CLOSED}; all remaining queued writes are then failed with
   *       {@code SSLENGINE_CLOSED}.
   * </ul>
   *
   * Whatever is left in {@code out} (plus a pending promise, if any) is written in the {@code
   * finally} block.
   *
   * @param ctx the handler context
   * @param inUnwrap {@code true} when called from within unwrap; {@code needsFlush} is then set so
   *     the flush happens in {@code channelReadComplete}
   * @throws SSLException if wrapping fails; {@code setHandshakeFailure} is invoked before the
   *     exception is rethrown
   */
  private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
    ByteBuf out = null;
    ChannelPromise promise = null;
    ByteBufAllocator alloc = ctx.alloc();
    try {
      for (; ; ) {
        Object msg = pendingUnencryptedWrites.current();
        if (msg == null) {
          break;
        }

        ByteBuf buf = (ByteBuf) msg;
        if (out == null) {
          out = allocateOutNetBuf(ctx, buf.readableBytes());
        }

        SSLEngineResult result = wrap(alloc, engine, buf, out);
        // Only dequeue the write (and take ownership of its promise) once its whole buffer has
        // been encrypted.
        if (!buf.isReadable()) {
          promise = pendingUnencryptedWrites.remove();
        } else {
          promise = null;
        }

        if (result.getStatus() == Status.CLOSED) {
          // SSLEngine has been closed already.
          // Any further write attempts should be denied. A promise dequeued above is still
          // written (with whatever was produced) by the finally block.
          pendingUnencryptedWrites.removeAndFailAll(SSLENGINE_CLOSED);
          return;
        } else {
          switch (result.getHandshakeStatus()) {
            case NEED_TASK:
              runDelegatedTasks();
              if (promise != null) {
                // The write was fully consumed and already removed from the queue. Hand its
                // promise over now: the next iteration would otherwise overwrite 'promise' and
                // the write's future would never be notified.
                finishWrap(ctx, out, promise, inUnwrap);
                promise = null;
                out = null;
              }
              break;
            case FINISHED:
              setHandshakeSuccess();
              // deliberate fall-through
            case NOT_HANDSHAKING:
              setHandshakeSuccessIfStillHandshaking();
              // deliberate fall-through
            case NEED_WRAP:
              // Write what has been produced so far; a fresh 'out' is allocated next iteration.
              finishWrap(ctx, out, promise, inUnwrap);
              promise = null;
              out = null;
              break;
            case NEED_UNWRAP:
              // The engine needs data from the peer before it can wrap more.
              return;
            default:
              throw new IllegalStateException(
                  "Unknown handshake status: " + result.getHandshakeStatus());
          }
        }
      }
    } catch (SSLException e) {
      setHandshakeFailure(ctx, e);
      throw e;
    } finally {
      // Always write the remaining output (or an empty buffer) together with any pending promise.
      finishWrap(ctx, out, promise, inUnwrap);
    }
  }

  /**
   * Writes {@code out} to {@code ctx}, attaching {@code promise} when one is given.
   *
   * <p>A {@code null} or empty {@code out} is replaced by {@link Unpooled#EMPTY_BUFFER} (an empty
   * {@code out} is released first), so a write is always issued. Nothing is flushed here; when
   * {@code inUnwrap} is {@code true}, {@code needsFlush} is set so the flush happens later.
   */
  private void finishWrap(
      ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap) {
    ByteBuf payload = out;
    if (payload != null && !payload.isReadable()) {
      // Nothing was produced: return the buffer and send the shared empty buffer instead.
      payload.release();
      payload = null;
    }
    if (payload == null) {
      payload = Unpooled.EMPTY_BUFFER;
    }

    if (promise == null) {
      ctx.write(payload);
    } else {
      ctx.write(payload, promise);
    }

    if (inUnwrap) {
      needsFlush = true;
    }
  }

  /**
   * Wraps with empty input, i.e. produces only non-application records such as handshake
   * messages, and writes every produced chunk to {@code ctx} without flushing.
   *
   * <p>The loop ends in two cases:
   *
   * <ul>
   *   <li>a wrap produces no bytes;
   *   <li>a wrap consumes nothing while the engine reports {@code NOT_HANDSHAKING}.
   * </ul>
   *
   * When the engine needs to unwrap (or reports {@code NOT_HANDSHAKING}) and this call does not
   * already come from unwrap, {@link #unwrapNonAppData} is invoked, i.e. an unwrap with empty
   * input.
   *
   * @param ctx the handler context
   * @param inUnwrap {@code true} when called from within unwrap; this suppresses the nested
   *     unwrap call and makes produced data set {@code needsFlush} instead
   * @throws SSLException if wrapping fails; {@code setHandshakeFailure} is invoked before the
   *     exception is rethrown
   */
  private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
    ByteBuf out = null;
    ByteBufAllocator alloc = ctx.alloc();
    try {
      for (; ; ) {
        if (out == null) {
          out = allocateOutNetBuf(ctx, 0);
        }
        SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out);

        if (result.bytesProduced() > 0) {
          // Ownership of 'out' passes to the pipeline; allocate a new one next iteration.
          ctx.write(out);
          if (inUnwrap) {
            needsFlush = true;
          }
          out = null;
        }

        switch (result.getHandshakeStatus()) {
          case FINISHED:
            setHandshakeSuccess();
            break;
          case NEED_TASK:
            runDelegatedTasks();
            break;
          case NEED_UNWRAP:
            // Avoid recursing back into unwrap when we were called from it.
            if (!inUnwrap) {
              unwrapNonAppData(ctx);
            }
            break;
          case NEED_WRAP:
            break;
          case NOT_HANDSHAKING:
            setHandshakeSuccessIfStillHandshaking();
            // Workaround for TLS False Start problem reported at:
            // https://github.com/netty/netty/issues/1108#issuecomment-14266970
            if (!inUnwrap) {
              unwrapNonAppData(ctx);
            }
            break;
          default:
            throw new IllegalStateException(
                "Unknown handshake status: " + result.getHandshakeStatus());
        }

        if (result.bytesProduced() == 0) {
          break;
        }

        // It should not consume empty buffers when it is not handshaking
        // Fix for Android, where it was encrypting empty buffers even when not handshaking
        if (result.bytesConsumed() == 0
            && result.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING) {
          break;
        }
      }
    } catch (SSLException e) {
      setHandshakeFailure(ctx, e);
      throw e;
    } finally {
      // Release a buffer that was allocated but never handed to ctx.write(...).
      if (out != null) {
        out.release();
      }
    }
  }

  /**
   * Wraps the readable bytes of {@code in} into {@code out}, calling {@link
   * SSLEngine#wrap(ByteBuffer[], ByteBuffer)} until it returns a status other than {@code
   * BUFFER_OVERFLOW}. After each overflow, {@code out} grows by {@code maxPacketBufferSize}.
   *
   * <p>The reader index of {@code in} advances by the bytes consumed, and the writer index of
   * {@code out} by the bytes produced. If the engine wants direct buffers and {@code in} is a heap
   * buffer, the readable bytes of {@code in} are first copied into a temporary direct buffer,
   * which is released before returning.
   *
   * @return the first engine result whose status is not {@code BUFFER_OVERFLOW}
   * @throws SSLException if the engine fails
   */
  private SSLEngineResult wrap(ByteBufAllocator alloc, SSLEngine engine, ByteBuf in, ByteBuf out)
      throws SSLException {
    // Temporary direct copy of 'in'; only used when the engine wants direct buffers.
    ByteBuf newDirectIn = null;
    try {
      int readerIndex = in.readerIndex();
      int readableBytes = in.readableBytes();

      // We call SSLEngine.wrap(ByteBuffer[], ByteBuffer) so that a CompositeByteBuf can be handled
      // efficiently, without the extra memory copy that CompositeByteBuf.nioBuffer() would force.
      final ByteBuffer[] in0;
      if (in.isDirect() || !wantsDirectBuffer) {
        // CompositeByteBuf.nioBufferCount() can be expensive (it has to check every component to
        // compute the count), so a CompositeByteBuf is simply assumed to contain more than one
        // ByteBuffer. The worst case is an extra ByteBuffer[] allocated in
        // CompositeByteBuf.nioBuffers(), which is cheaper than walking the components in most
        // cases.
        if (!(in instanceof CompositeByteBuf) && in.nioBufferCount() == 1) {
          in0 = singleBuffer;
          // Backed by exactly one ByteBuffer: use internalNioBuffer to keep object allocation to
          // a minimum.
          in0[0] = in.internalNioBuffer(readerIndex, readableBytes);
        } else {
          in0 = in.nioBuffers();
        }
      } else {
        // The engine wants a direct buffer but 'in' is a heap buffer, so copy it. A
        // CompositeByteBuf could be decomposed so that only its non-direct components are
        // replaced, but for simplicity the whole buffer is copied.
        newDirectIn = alloc.directBuffer(readableBytes);
        newDirectIn.writeBytes(in, readerIndex, readableBytes);
        in0 = singleBuffer;
        in0[0] = newDirectIn.internalNioBuffer(0, readableBytes);
      }

      for (; ; ) {
        ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
        SSLEngineResult result = engine.wrap(in0, out0);
        // Mirror the engine's progress onto the ByteBuf indexes.
        in.skipBytes(result.bytesConsumed());
        out.writerIndex(out.writerIndex() + result.bytesProduced());

        switch (result.getStatus()) {
          case BUFFER_OVERFLOW:
            // Not enough room in 'out': grow it by one maximum-size packet and try again.
            out.ensureWritable(maxPacketBufferSize);
            break;
          default:
            return result;
        }
      }
    } finally {
      // Null out to allow GC of ByteBuffer
      singleBuffer[0] = null;

      if (newDirectIn != null) {
        newDirectIn.release();
      }
    }
  }

  /**
   * Fails the handshake with {@code CHANNEL_CLOSED} once the channel becomes inactive, then lets
   * the superclass process the event.
   */
  @Override
  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
    // The connection may close mid-handshake; setHandshakeFailure is relied upon to release the
    // SSLEngine and to notify a still-pending handshake future.
    setHandshakeFailure(ctx, CHANNEL_CLOSED);
    super.channelInactive(ctx);
  }

  /**
   * Swallows harmless connection-reset / broken-pipe errors that follow a {@code close_notify}
   * (closing the channel if it is still active); every other exception is propagated down the
   * pipeline.
   */
  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    if (!ignoreException(cause)) {
      ctx.fireExceptionCaught(cause);
      return;
    }

    if (logger.isDebugEnabled()) {
      logger.debug(
          "{} Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred "
              + "while writing close_notify in response to the peer's close_notify",
          ctx.channel(),
          cause);
    }

    // The transport might not have closed the connection by itself, so do it explicitly.
    if (ctx.channel().isActive()) {
      ctx.close();
    }
  }

  /**
   * Checks whether the given {@link Throwable} can be ignored and just "swallowed".
   *
   * <p>When an SSL connection is closed, a {@code close_notify} message is sent. The peer then
   * also sends {@code close_notify}, but receiving it is not mandatory. The party that sent the
   * initial {@code close_notify} may close the connection immediately, in which case the peer sees
   * a connection-reset / broken-pipe error.
   *
   * @param t the exception to inspect
   * @return {@code true} only if {@code t} is an {@link IOException} (but not an {@link
   *     SSLException}) raised after {@code sslCloseFuture} completed, whose message or stack
   *     trace identifies a connection reset / broken pipe during a channel read
   */
  private boolean ignoreException(Throwable t) {
    if (t instanceof SSLException || !(t instanceof IOException) || !sslCloseFuture.isDone()) {
      return false;
    }

    // Fast path: match the message. IGNORABLE_ERROR_MESSAGE is compiled with CASE_INSENSITIVE, so
    // no lower-casing is needed. A default-locale toLowerCase() would even break the match under
    // e.g. the Turkish locale, where 'I' lower-cases to a dotless 'ı'.
    String message = String.valueOf(t.getMessage());
    if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) {
      return true;
    }

    // Slow path: the message wording differs between JDK implementations and operating systems,
    // so inspect the stack trace for a read() on a socket-like channel class instead.
    for (StackTraceElement element : t.getStackTrace()) {
      String classname = element.getClassName();

      // Skip Netty's own frames and every frame that is not a read() call.
      if (classname.startsWith("io.netty.") || !"read".equals(element.getMethodName())) {
        continue;
      }

      // This also matches SocketInputStream, which is used by OpenJDK 7 and maybe others.
      if (IGNORABLE_CLASS_IN_STACK.matcher(classname).matches()) {
        return true;
      }

      try {
        // No match by name: load the class and inspect its type, since other JDK
        // implementations may name their implementation classes differently.
        Class<?> clazz = PlatformDependent.getClassLoader(getClass()).loadClass(classname);

        if (SocketChannel.class.isAssignableFrom(clazz)
            || DatagramChannel.class.isAssignableFrom(clazz)) {
          return true;
        }

        // SctpChannel may not be present, so compare by name instead of by Class reference.
        // getSuperclass() returns null for interfaces (which may declare a default read()) and
        // for Object, so guard against it instead of throwing an NPE from exceptionCaught.
        Class<?> superclass = clazz.getSuperclass();
        if (PlatformDependent.javaVersion() >= 7
            && superclass != null
            && "com.sun.nio.sctp.SctpChannel".equals(superclass.getName())) {
          return true;
        }
      } catch (ClassNotFoundException ignored) {
        // The frame's class is not visible to our class loader: treat it as "no match".
      } catch (LinkageError ignored) {
        // The class exists but cannot be linked (e.g. a missing dependency): treat it as
        // "no match" rather than letting this best-effort check fail the exception handling.
      }
    }

    return false;
  }

  /**
   * Returns {@code true} if the given {@link ByteBuf} starts with an encrypted SSL/TLS record. The
   * reader index of {@code buffer} is not modified.
   *
   * @param buffer the {@link ByteBuf} to inspect; it must have at least 5 readable bytes
   * @return {@code true} if {@code buffer} is encrypted, {@code false} otherwise
   * @throws IllegalArgumentException if {@code buffer} has fewer than 5 readable bytes
   */
  public static boolean isEncrypted(ByteBuf buffer) {
    final int readable = buffer.readableBytes();
    if (readable >= 5) {
      final int recordLength = getEncryptedPacketLength(buffer, buffer.readerIndex());
      return recordLength != -1;
    }
    throw new IllegalArgumentException("buffer must have at least 5 readable bytes");
  }

  /**
   * Returns the total length (header included) of the encrypted SSL/TLS record that starts at
   * {@code offset}. The reader index of {@code buffer} is not modified.
   *
   * <p>An SSLv3/TLS header is checked first: content type 20-23 and major version 3, with a
   * non-zero length field. Otherwise the data is checked as an SSLv2 record with a 2- or 3-byte
   * header and major version 2 or 3.
   *
   * @param buffer the buffer to inspect; the caller must ensure at least 5 bytes are readable from
   *     {@code offset} (this method performs no such check itself)
   * @param offset absolute index of the first byte of the record header
   * @return the record length, or {@code -1} if the data is not an SSL/TLS record
   */
  private static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
    final int firstByte = buffer.getUnsignedByte(offset);

    // SSLv3 / TLS: change_cipher_spec(20), alert(21), handshake(22), application_data(23).
    final boolean tlsContentType = firstByte >= 20 && firstByte <= 23;
    if (tlsContentType && buffer.getUnsignedByte(offset + 1) == 3) {
      final int tlsLength = buffer.getUnsignedShort(offset + 3) + 5;
      if (tlsLength > 5) {
        return tlsLength;
      }
      // A zero-length record is not valid SSLv3/TLS: fall back to the SSLv2 check below.
    }

    // SSLv2 (or bad data): the high bit of the first byte selects a 2- or 3-byte header.
    final int headerLength = (firstByte & 0x80) != 0 ? 2 : 3;
    final int sslv2Major = buffer.getUnsignedByte(offset + headerLength + 1);
    if (sslv2Major != 2 && sslv2Major != 3) {
      return -1;
    }

    final int lengthMask = headerLength == 2 ? 0x7FFF : 0x3FFF;
    final int sslv2Length = (buffer.getShort(offset) & lengthMask) + headerLength;
    return sslv2Length > headerLength ? sslv2Length : -1;
  }

  /**
   * Collects as many complete SSL/TLS records from {@code in} as fit into {@code
   * OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH} bytes and unwraps them in a single call. Decrypted
   * data is fired through the pipeline by {@code unwrap(...)}; nothing is added to {@code out}.
   *
   * <p>If the next record is incomplete, its length is stored in the {@code packetLength} field,
   * so the next call can return early until enough bytes have arrived. Data that is not an SSL/TLS
   * record is handled in three steps:
   *
   * <ol>
   *   <li>a {@link NotSslRecordException} is fired;
   *   <li>the remaining input is discarded;
   *   <li>{@code setHandshakeFailure} is invoked.
   * </ol>
   *
   * Complete records that precede the bad data are still unwrapped first.
   *
   * @throws SSLException if unwrapping fails
   */
  @Override
  protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
      throws SSLException {
    final int startOffset = in.readerIndex();
    final int endOffset = in.writerIndex();
    // 'offset' is where the next record header is inspected; 'totalLength' counts the bytes of
    // complete records collected so far.
    int offset = startOffset;
    int totalLength = 0;

    // If we calculated the length of the current SSL record before, use that information.
    if (packetLength > 0) {
      if (endOffset - startOffset < packetLength) {
        // The record is still incomplete: wait for more data.
        return;
      } else {
        offset += packetLength;
        totalLength = packetLength;
        packetLength = 0;
      }
    }

    boolean nonSslRecord = false;

    // NOTE(review): a record longer than MAX_ENCRYPTED_PACKET_LENGTH that is already fully
    // buffered when first inspected is never consumed here (totalLength stays 0). One that arrives
    // in pieces is handled through the packetLength field above. Confirm this cannot stall the
    // connection.
    while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
      final int readableBytes = endOffset - offset;
      if (readableBytes < 5) {
        // Not enough bytes for a record header.
        break;
      }

      // This local shadows the packetLength field; the field is assigned explicitly below.
      final int packetLength = getEncryptedPacketLength(in, offset);
      if (packetLength == -1) {
        nonSslRecord = true;
        break;
      }

      assert packetLength > 0;

      if (packetLength > readableBytes) {
        // wait until the whole packet can be read
        this.packetLength = packetLength;
        break;
      }

      int newTotalLength = totalLength + packetLength;
      if (newTotalLength > OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
        // Don't read too much.
        break;
      }

      // We have a whole packet.
      // Increment the offset to handle the next packet.
      offset += packetLength;
      totalLength = newTotalLength;
    }

    if (totalLength > 0) {
      boolean decoded = false;

      // The buffer contains one or more full SSL records; only complete records are passed to
      // unwrap(...). The bytes are skipped *before* unwrapping, because unwrap(...) may indirectly
      // re-enter decode(...) and must not see the same records again.
      //
      // See https://github.com/netty/netty/issues/1534

      in.skipBytes(totalLength);

      // If SSLEngine expects a heap buffer for unwrapping, do the conversion.
      if (in.isDirect() && wantsInboundHeapBuffer) {
        ByteBuf copy = ctx.alloc().heapBuffer(totalLength);
        try {
          copy.writeBytes(in, startOffset, totalLength);
          decoded = unwrap(ctx, copy, 0, totalLength);
        } finally {
          copy.release();
        }
      } else {
        // Absolute indexes are used, so the skipBytes(...) above does not affect this read.
        decoded = unwrap(ctx, in, startOffset, totalLength);
      }

      if (!firedChannelRead) {
        // Check first if firedChannelRead is not set yet as it may have been set in a
        // previous decode(...) call.
        firedChannelRead = decoded;
      }
    }

    if (nonSslRecord) {
      // Not an SSL/TLS packet
      NotSslRecordException e =
          new NotSslRecordException("not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
      in.skipBytes(in.readableBytes());
      ctx.fireExceptionCaught(e);
      setHandshakeFailure(ctx, e);
    }
  }

  /**
   * Flushes output produced while unwrapping, requests more data when auto-read is off and the
   * read cycle did not get far enough, then forwards the event.
   */
  @Override
  public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
    // Discard bytes of the cumulation buffer if needed.
    discardSomeReadBytes();

    if (needsFlush) {
      needsFlush = false;
      ctx.flush();
    }

    // Without auto-read nobody else asks for more data. If no decrypted message went through the
    // pipeline, or the handshake is still running, trigger a read to avoid stalling.
    if (!ctx.channel().config().isAutoRead()) {
      if (!(firedChannelRead && handshakePromise.isDone())) {
        ctx.read();
      }
    }

    firedChannelRead = false;
    ctx.fireChannelReadComplete();
  }

  /**
   * Lets the engine make handshake progress without new input by unwrapping an empty buffer; the
   * "data was decoded" result is intentionally ignored.
   */
  private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException {
    final ByteBuf noInput = Unpooled.EMPTY_BUFFER;
    unwrap(ctx, noInput, 0, 0);
  }

  /**
   * Unwraps the SSL records in {@code packet[offset, offset + length)} and fires the decrypted
   * data through the pipeline via {@code ctx.fireChannelRead(...)}.
   *
   * <p>Handshake progress reported by the engine is handled inside the loop:
   *
   * <ul>
   *   <li>delegated tasks are run;
   *   <li>handshake data is wrapped via {@link #wrapNonAppData};
   *   <li>handshake success is recorded.
   * </ul>
   *
   * Queued application writes are wrapped after the loop in two cases: the handshake just
   * completed, or a flush was requested before it completed. If the engine reports {@code CLOSED},
   * {@code sslCloseFuture} is completed.
   *
   * @return {@code true} if any decrypted data was fired through the pipeline
   * @throws SSLException if unwrapping fails; {@code setHandshakeFailure} is invoked before the
   *     exception is rethrown
   */
  private boolean unwrap(ChannelHandlerContext ctx, ByteBuf packet, int offset, int length)
      throws SSLException {

    boolean decoded = false;
    boolean wrapLater = false;
    boolean notifyClosure = false;
    ByteBuf decodeOut = allocate(ctx, length);
    try {
      for (; ; ) {
        final SSLEngineResult result = unwrap(engine, packet, offset, length, decodeOut);
        final Status status = result.getStatus();
        final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
        final int produced = result.bytesProduced();
        final int consumed = result.bytesConsumed();

        // Update indexes for the next iteration
        offset += consumed;
        length -= consumed;

        switch (status) {
          case BUFFER_OVERFLOW:
            // decodeOut is full: pass on what was decoded so far (ownership moves to the
            // pipeline), or release it if empty, then retry with a new buffer.
            // NOTE(review): if allocate(...) below throws, the finally block touches a decodeOut
            // that was already fired or released; also confirm allocate(...) copes with a
            // non-positive size when readableBytes >= the application buffer size.
            int readableBytes = decodeOut.readableBytes();
            if (readableBytes > 0) {
              decoded = true;
              ctx.fireChannelRead(decodeOut);
            } else {
              decodeOut.release();
            }
            // Allocate a new buffer which can hold all the rest data and loop again.
            // TODO: We may want to reconsider how we calculate the length here as we may
            // have more then one ssl message to decode.
            decodeOut =
                allocate(ctx, engine.getSession().getApplicationBufferSize() - readableBytes);
            continue;
          case CLOSED:
            // notify about the CLOSED state of the SSLEngine. See #137
            notifyClosure = true;
            break;
          default:
            break;
        }

        switch (handshakeStatus) {
          case NEED_UNWRAP:
            break;
          case NEED_WRAP:
            // The engine has handshake data to send.
            wrapNonAppData(ctx, true);
            break;
          case NEED_TASK:
            runDelegatedTasks();
            break;
          case FINISHED:
            setHandshakeSuccess();
            wrapLater = true;
            // Skip the termination check below and unwrap again.
            continue;
          case NOT_HANDSHAKING:
            if (setHandshakeSuccessIfStillHandshaking()) {
              wrapLater = true;
              continue;
            }
            if (flushedBeforeHandshake) {
              // We need to call wrap(...) in case there was a flush done before the handshake
              // completed.
              //
              // See https://github.com/netty/netty/pull/2437
              flushedBeforeHandshake = false;
              wrapLater = true;
            }

            break;
          default:
            throw new IllegalStateException("unknown handshake status: " + handshakeStatus);
        }

        // Stop when the engine needs more input, or when this iteration made no progress at all.
        if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) {
          break;
        }
      }

      if (wrapLater) {
        wrap(ctx, true);
      }

      if (notifyClosure) {
        sslCloseFuture.trySuccess(ctx.channel());
      }
    } catch (SSLException e) {
      setHandshakeFailure(ctx, e);
      throw e;
    } finally {
      // Hand off the last decoded chunk, or release the buffer if nothing was decoded into it.
      if (decodeOut.isReadable()) {
        decoded = true;

        ctx.fireChannelRead(decodeOut);
      } else {
        decodeOut.release();
      }
    }
    return decoded;
  }

  /**
   * Unwraps {@code len} bytes of {@code in}, starting at the absolute index {@code readerIndex},
   * into the writable region of {@code out}. The writer index of {@code out} is advanced by the
   * number of bytes produced.
   *
   * <p>The reader index of {@code in} is not modified; callers track consumption themselves using
   * {@link SSLEngineResult#bytesConsumed()}.
   *
   * @return the engine's result
   * @throws SSLException if the engine fails
   */
  private SSLEngineResult unwrap(
      SSLEngine engine, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException {
    final int writerIndex = out.writerIndex();
    final SSLEngineResult result;
    if (engine instanceof OpenSslEngine && in.nioBufferCount() > 1) {
      // OpenSslEngine offers unwrap(ByteBuffer[], ByteBuffer[]), which accepts several input
      // ByteBuffers (e.g. the components of a CompositeByteBuf) without extra memory copies.
      OpenSslEngine opensslEngine = (OpenSslEngine) engine;
      try {
        singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes());
        result = opensslEngine.unwrap(in.nioBuffers(readerIndex, len), singleBuffer);
      } finally {
        // Drop the reference so the ByteBuffer can be garbage collected.
        singleBuffer[0] = null;
      }
    } else {
      result =
          engine.unwrap(
              toByteBuffer(in, readerIndex, len),
              toByteBuffer(out, writerIndex, out.writableBytes()));
    }
    // Commit the produced bytes exactly once, for both branches.
    out.writerIndex(writerIndex + result.bytesProduced());
    return result;
  }

  /**
   * Returns an NIO {@link ByteBuffer} covering {@code len} bytes of {@code out} starting at
   * {@code index}. When {@code out} consists of exactly one NIO buffer, its internal NIO buffer is
   * used; otherwise {@link ByteBuf#nioBuffer(int, int)} is used.
   */
  private static ByteBuffer toByteBuffer(ByteBuf out, int index, int len) {
    if (out.nioBufferCount() != 1) {
      return out.nioBuffer(index, len);
    }
    return out.internalNioBuffer(index, len);
  }

  /**
   * Drains the {@link SSLEngine}'s delegated tasks, running each one directly on the calling
   * thread until {@link SSLEngine#getDelegatedTask()} returns {@code null}.
   */
  private void runDelegatedTasks() {
    Runnable pending;
    while ((pending = engine.getDelegatedTask()) != null) {
      pending.run();
    }
  }

  /**
   * Works around some Android {@link SSLEngine} implementations that skip {@link
   * HandshakeStatus#FINISHED} and go straight into {@link HandshakeStatus#NOT_HANDSHAKING} when the
   * handshake is finished.
   *
   * @return {@code true} if and only if the handshake promise was still pending and has therefore
   *     been marked as successful by this call
   */
  private boolean setHandshakeSuccessIfStillHandshaking() {
    final boolean stillPending = !handshakePromise.isDone();
    if (stillPending) {
      setHandshakeSuccess();
    }
    return stillPending;
  }

  /**
   * Marks the handshake as successful: completes the handshake promise, fires {@link
   * SslHandshakeCompletionEvent#SUCCESS} and, if a read was deferred during the handshake while
   * auto-read is off, issues that read now.
   */
  private void setHandshakeSuccess() {
    // Work around a JVM crash seen with GCM cipher suites: unless direct buffers are wanted,
    // request heap buffers for inbound data once such a suite has been negotiated.
    final String cipherSuite = String.valueOf(engine.getSession().getCipherSuite());
    final boolean gcmSuite = cipherSuite.contains("_GCM_") || cipherSuite.contains("-GCM-");
    if (!wantsDirectBuffer && gcmSuite) {
      wantsInboundHeapBuffer = true;
    }

    handshakePromise.trySuccess(ctx.channel());

    if (logger.isDebugEnabled()) {
      logger.debug("{} HANDSHAKEN: {}", ctx.channel(), engine.getSession().getCipherSuite());
    }
    ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);

    // Only a flagged read with auto-read disabled needs to be replayed here.
    if (!readDuringHandshake || ctx.channel().config().isAutoRead()) {
      return;
    }
    readDuringHandshake = false;
    ctx.read();
  }

  /**
   * Fails the handshake with {@code cause}: closes both directions of the {@link SSLEngine} so it
   * can release internal resources, notifies the handshake promise, and fails every pending
   * unencrypted write with the same cause.
   */
  private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) {
    engine.closeOutbound();

    try {
      engine.closeInbound();
    } catch (SSLException e) {
      // The "possible truncation attack" complaint is most likely harmless (recent Chrome
      // triggers it all the time), so it is not logged; any other failure is logged at debug
      // level only. See https://github.com/netty/netty/issues/1340
      final String reason = e.getMessage();
      final boolean truncationNoise =
          reason != null && reason.contains("possible truncation attack");
      if (!truncationNoise) {
        logger.debug("{} SSLEngine.closeInbound() raised an exception.", ctx.channel(), e);
      }
    }

    notifyHandshakeFailure(cause);
    pendingUnencryptedWrites.removeAndFailAll(cause);
  }

  /**
   * Fails the handshake promise with {@code cause}. Only when this call is the one that completes
   * the promise is a failed {@link SslHandshakeCompletionEvent} fired and the channel closed.
   */
  private void notifyHandshakeFailure(Throwable cause) {
    if (!handshakePromise.tryFailure(cause)) {
      return;
    }
    ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause));
    ctx.close();
  }

  /**
   * Closes (or, if {@code disconnect} is set and the channel is inactive, disconnects) the
   * channel. For an active channel the outbound side of the {@link SSLEngine} is closed first, an
   * empty write is flushed so the closure is sent to the peer, and the channel is then closed via
   * {@link #safeClose(ChannelHandlerContext, ChannelFuture, ChannelPromise)}.
   */
  private void closeOutboundAndChannel(
      final ChannelHandlerContext ctx, final ChannelPromise promise, boolean disconnect)
      throws Exception {
    if (ctx.channel().isActive()) {
      engine.closeOutbound();

      final ChannelPromise closeNotifyWritten = ctx.newPromise();
      write(ctx, Unpooled.EMPTY_BUFFER, closeNotifyWritten);
      flush(ctx);
      safeClose(ctx, closeNotifyWritten, promise);
    } else if (disconnect) {
      ctx.disconnect(promise);
    } else {
      ctx.close(promise);
    }
  }

  /**
   * Captures the context, creates the queue of pending unencrypted writes and, for a client added
   * to an already-active channel, starts the initial handshake.
   */
  @Override
  public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
    this.ctx = ctx;
    pendingUnencryptedWrites = new PendingWriteQueue(ctx);

    // If channelActive() has already fired, this.channelActive() will never be invoked, so an
    // active client channel has to begin its handshake here. Otherwise channelActive() takes care
    // of starting the client handshake later.
    final boolean activeClient = ctx.channel().isActive() && engine.getUseClientMode();
    if (activeClient) {
      handshake(null);
    }
  }

  /**
   * Performs TLS renegotiation, using a fresh promise created by the context's executor.
   *
   * @throws IllegalStateException if the handler has not been added to a pipeline yet
   */
  public Future<Channel> renegotiate() {
    final ChannelHandlerContext context = this.ctx;
    if (context == null) {
      throw new IllegalStateException();
    }

    final Promise<Channel> freshPromise = context.executor().<Channel>newPromise();
    return renegotiate(freshPromise);
  }

  /**
   * Performs TLS renegotiation, completing {@code promise} with the outcome. When called off the
   * event loop, the handshake is scheduled on the context's executor instead of running inline.
   *
   * @throws NullPointerException if {@code promise} is {@code null}
   * @throws IllegalStateException if the handler has not been added to a pipeline yet
   */
  public Future<Channel> renegotiate(final Promise<Channel> promise) {
    if (promise == null) {
      throw new NullPointerException("promise");
    }

    final ChannelHandlerContext context = this.ctx;
    if (context == null) {
      throw new IllegalStateException();
    }

    final EventExecutor executor = context.executor();
    if (executor.inEventLoop()) {
      handshake(promise);
    } else {
      executor.execute(
          new OneTimeTask() {
            @Override
            public void run() {
              handshake(promise);
            }
          });
    }
    return promise;
  }

  /**
   * Performs TLS (re)negotiation.
   *
   * <p>If a handshake is already in progress and a new promise is supplied, no new handshake is
   * started; the new promise is instead completed with the outcome of the running one. Otherwise
   * the engine's handshake is begun, the resulting handshake data is wrapped and flushed, and, when
   * {@code handshakeTimeoutMillis > 0}, a timer is armed that fails the handshake with {@code
   * HANDSHAKE_TIMED_OUT} if it has not completed in time.
   *
   * @param newHandshakePromise if {@code null}, use the existing {@link #handshakePromise},
   *     assuming that the current negotiation has not been finished. Currently, {@code null} is
   *     expected only for the initial handshake.
   */
  private void handshake(final Promise<Channel> newHandshakePromise) {
    final Promise<Channel> p;
    if (newHandshakePromise != null) {
      final Promise<Channel> oldHandshakePromise = handshakePromise;
      if (!oldHandshakePromise.isDone()) {
        // There's no need to handshake because handshake is in progress already.
        // Merge the new promise into the old one. trySuccess()/tryFailure() are used because the
        // caller owns newHandshakePromise and may already have completed or cancelled it, in
        // which case setSuccess()/setFailure() would throw IllegalStateException here.
        oldHandshakePromise.addListener(
            new FutureListener<Channel>() {
              @Override
              public void operationComplete(Future<Channel> future) throws Exception {
                if (future.isSuccess()) {
                  newHandshakePromise.trySuccess(future.getNow());
                } else {
                  newHandshakePromise.tryFailure(future.cause());
                }
              }
            });
        return;
      }

      handshakePromise = p = newHandshakePromise;
    } else {
      // Forced to reuse the old handshake.
      p = handshakePromise;
      assert !p.isDone();
    }

    // Begin handshake.
    final ChannelHandlerContext ctx = this.ctx;
    try {
      engine.beginHandshake();
      wrapNonAppData(ctx, false);
      ctx.flush();
    } catch (Exception e) {
      notifyHandshakeFailure(e);
    }

    // Set timeout if necessary.
    final long handshakeTimeoutMillis = this.handshakeTimeoutMillis;
    if (handshakeTimeoutMillis <= 0 || p.isDone()) {
      return;
    }

    final ScheduledFuture<?> timeoutFuture =
        ctx.executor()
            .schedule(
                new Runnable() {
                  @Override
                  public void run() {
                    if (p.isDone()) {
                      return;
                    }
                    notifyHandshakeFailure(HANDSHAKE_TIMED_OUT);
                  }
                },
                handshakeTimeoutMillis,
                TimeUnit.MILLISECONDS);

    // Cancel the handshake timeout when handshake is finished.
    p.addListener(
        new FutureListener<Channel>() {
          @Override
          public void operationComplete(Future<Channel> f) throws Exception {
            timeoutFuture.cancel(false);
          }
        });
  }

  /**
   * Issues the initial TLS handshake once connected when used in client mode (unless {@code
   * startTls} is set), then propagates the event.
   */
  @Override
  public void channelActive(final ChannelHandlerContext ctx) throws Exception {
    final boolean beginClientHandshake = !startTls && engine.getUseClientMode();
    if (beginClientHandshake) {
      handshake(null);
    }
    ctx.fireChannelActive();
  }

  /**
   * Closes the channel once {@code flushFuture} completes and relays the close result to {@code
   * promise}. If {@code closeNotifyTimeoutMillis > 0}, the channel is force-closed when {@code
   * flushFuture} has not completed within that time. An inactive channel is closed immediately.
   */
  private void safeClose(
      final ChannelHandlerContext ctx, ChannelFuture flushFuture, final ChannelPromise promise) {
    if (!ctx.channel().isActive()) {
      ctx.close(promise);
      return;
    }

    // Both the timeout task and the flush listener may close the channel, so there is a "race"
    // between them. Each closes with a fresh promise and relays the result through
    // ChannelPromiseNotifier, which uses trySuccess()/tryFailure(), so notifying the caller's
    // promise twice cannot cause an IllegalStateException.
    final ScheduledFuture<?> timeoutFuture;
    if (closeNotifyTimeoutMillis <= 0) {
      timeoutFuture = null;
    } else {
      // Force-close the connection if close_notify is not fully sent in time.
      final Runnable forceClose =
          new Runnable() {
            @Override
            public void run() {
              logger.warn(
                  "{} Last write attempt timed out; force-closing the connection.",
                  ctx.channel());
              ctx.close(ctx.newPromise()).addListener(new ChannelPromiseNotifier(promise));
            }
          };
      timeoutFuture =
          ctx.executor().schedule(forceClose, closeNotifyTimeoutMillis, TimeUnit.MILLISECONDS);
    }

    // Close the connection once the flush completes, whether it succeeded or not.
    flushFuture.addListener(
        new ChannelFutureListener() {
          @Override
          public void operationComplete(ChannelFuture f) throws Exception {
            if (timeoutFuture != null) {
              timeoutFuture.cancel(false);
            }
            // Trigger the close in all cases to make sure the promise is notified.
            // See https://github.com/netty/netty/issues/2358
            ctx.close(ctx.newPromise()).addListener(new ChannelPromiseNotifier(promise));
          }
        });
  }

  /**
   * Allocates a buffer of {@code capacity} bytes from the context's allocator: a direct buffer
   * whenever {@code wantsDirectBuffer} is set (to reduce memory copies in {@link OpenSslEngine}),
   * otherwise whatever kind the allocator's {@code buffer(int)} returns.
   */
  private ByteBuf allocate(ChannelHandlerContext ctx, int capacity) {
    final ByteBufAllocator allocator = ctx.alloc();
    return wantsDirectBuffer ? allocator.directBuffer(capacity) : allocator.buffer(capacity);
  }

  /**
   * Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)}. Its
   * size is {@code maxPacketBufferSize} when {@code wantsLargeOutboundNetworkBuffer} is set;
   * otherwise it is {@code pendingBytes} plus the maximum encryption overhead, capped at {@code
   * maxPacketBufferSize}.
   */
  private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) {
    final int capacity;
    if (!wantsLargeOutboundNetworkBuffer) {
      final int needed = pendingBytes + OpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH;
      capacity = Math.min(needed, maxPacketBufferSize);
    } else {
      capacity = maxPacketBufferSize;
    }
    return allocate(ctx, capacity);
  }

  /**
   * A promise whose executor is resolved lazily from the handler's context, which is only known
   * once {@code handlerAdded(...)} has run.
   */
  private final class LazyChannelPromise extends DefaultPromise<Channel> {

    /**
     * Returns the executor of the handler's context.
     *
     * @throws IllegalStateException if the handler has not been added to a pipeline yet
     */
    @Override
    protected EventExecutor executor() {
      if (ctx != null) {
        return ctx.executor();
      }
      throw new IllegalStateException();
    }

    @Override
    protected void checkDeadLock() {
      // ctx stays null until handlerAdded(...) runs on the context's EventExecutor. A user may
      // call handshakeFuture().sync() from another thread before that happens; calling
      // super.checkDeadLock() then would invoke executor() and throw IllegalStateException, so
      // the check is skipped in that case.
      if (ctx != null) {
        super.checkDeadLock();
      }
    }
  }
}
Пример #12
0
/**
 * A {@link Timer} optimized for approximated I/O timeout scheduling.
 *
 * <h3>Tick Duration</h3>
 *
 * As described with 'approximated', this timer does not execute the scheduled {@link TimerTask} on
 * time. {@link HashedWheelTimer}, on every tick, will check if there are any {@link TimerTask}s
 * behind the schedule and execute them.
 *
 * <p>You can increase or decrease the accuracy of the execution timing by specifying smaller or
 * larger tick duration in the constructor. In most network applications, I/O timeout does not need
 * to be accurate. Therefore, the default tick duration is 100 milliseconds and you will not need to
 * try different configurations in most cases. The timer works at millisecond resolution: a tick
 * duration shorter than one millisecond is rounded up to one millisecond.
 *
 * <h3>Ticks per Wheel (Wheel Size)</h3>
 *
 * {@link HashedWheelTimer} maintains a data structure called 'wheel'. Put simply, a wheel is a hash
 * table of {@link TimerTask}s whose hash function is 'deadline of the task'. The default number of
 * ticks per wheel (i.e. the size of the wheel) is 512; any requested size is rounded up to a power
 * of two. You could specify a larger value if you are going to schedule a lot of timeouts.
 *
 * <h3>Do not create many instances.</h3>
 *
 * {@link HashedWheelTimer} creates a new thread whenever it is instantiated and started. Therefore,
 * you should make sure to create only one instance and share it across your application. One of the
 * common mistakes, that makes your application unresponsive, is to create a new instance for every
 * connection.
 *
 * <h3>Implementation Details</h3>
 *
 * {@link HashedWheelTimer} is based on <a href="http://cseweb.ucsd.edu/users/varghese/">George
 * Varghese</a> and Tony Lauck's paper, <a
 * href="http://cseweb.ucsd.edu/users/varghese/PAPERS/twheel.ps.Z">'Hashed and Hierarchical Timing
 * Wheels: data structures to efficiently implement a timer facility'</a>. More comprehensive slides
 * are located <a href="http://www.cse.wustl.edu/~cdgill/courses/cs6874/TimingWheels.ppt">here</a>.
 */
public class HashedWheelTimer implements Timer {

  static final InternalLogger logger = InternalLoggerFactory.getInstance(HashedWheelTimer.class);

  private static final ResourceLeakDetector<HashedWheelTimer> leakDetector =
      new ResourceLeakDetector<HashedWheelTimer>(
          HashedWheelTimer.class, 1, Runtime.getRuntime().availableProcessors() * 4);

  // Opened at the very end of the main constructor, after every argument has been validated, so
  // a constructor call that throws does not leave an open leak record behind.
  private final ResourceLeak leak;
  private final Worker worker = new Worker();
  final Thread workerThread;
  final AtomicInteger workerState = new AtomicInteger(); // 0 - init, 1 - started, 2 - shut down

  private final long roundDuration; // tickDuration * wheel.length, in milliseconds
  final long tickDuration; // in milliseconds, always >= 1
  final Set<HashedWheelTimeout>[] wheel;
  final int mask;
  final ReadWriteLock lock = new ReentrantReadWriteLock();
  volatile int wheelCursor;

  /**
   * Creates a new timer with the default thread factory ({@link Executors#defaultThreadFactory()}),
   * default tick duration, and default number of ticks per wheel.
   */
  public HashedWheelTimer() {
    this(Executors.defaultThreadFactory());
  }

  /**
   * Creates a new timer with the default thread factory ({@link Executors#defaultThreadFactory()})
   * and default number of ticks per wheel.
   *
   * @param tickDuration the duration between tick
   * @param unit the time unit of the {@code tickDuration}
   * @throws NullPointerException if {@code unit} is {@code null}
   * @throws IllegalArgumentException if {@code tickDuration} is <= 0
   */
  public HashedWheelTimer(long tickDuration, TimeUnit unit) {
    this(Executors.defaultThreadFactory(), tickDuration, unit);
  }

  /**
   * Creates a new timer with the default thread factory ({@link Executors#defaultThreadFactory()}).
   *
   * @param tickDuration the duration between tick
   * @param unit the time unit of the {@code tickDuration}
   * @param ticksPerWheel the size of the wheel
   * @throws NullPointerException if {@code unit} is {@code null}
   * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is
   *     <= 0
   */
  public HashedWheelTimer(long tickDuration, TimeUnit unit, int ticksPerWheel) {
    this(Executors.defaultThreadFactory(), tickDuration, unit, ticksPerWheel);
  }

  /**
   * Creates a new timer with the default tick duration and default number of ticks per wheel.
   *
   * @param threadFactory a {@link ThreadFactory} that creates a background {@link Thread} which is
   *     dedicated to {@link TimerTask} execution.
   * @throws NullPointerException if {@code threadFactory} is {@code null}
   */
  public HashedWheelTimer(ThreadFactory threadFactory) {
    this(threadFactory, 100, TimeUnit.MILLISECONDS);
  }

  /**
   * Creates a new timer with the default number of ticks per wheel.
   *
   * @param threadFactory a {@link ThreadFactory} that creates a background {@link Thread} which is
   *     dedicated to {@link TimerTask} execution.
   * @param tickDuration the duration between tick
   * @param unit the time unit of the {@code tickDuration}
   * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code
   *     null}
   * @throws IllegalArgumentException if {@code tickDuration} is <= 0
   */
  public HashedWheelTimer(ThreadFactory threadFactory, long tickDuration, TimeUnit unit) {
    this(threadFactory, tickDuration, unit, 512);
  }

  /**
   * Creates a new timer.
   *
   * <p>{@code tickDuration} is converted to milliseconds; a positive duration shorter than one
   * millisecond is rounded up to one millisecond (with a warning).
   *
   * @param threadFactory a {@link ThreadFactory} that creates a background {@link Thread} which is
   *     dedicated to {@link TimerTask} execution.
   * @param tickDuration the duration between tick
   * @param unit the time unit of the {@code tickDuration}
   * @param ticksPerWheel the size of the wheel
   * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code
   *     null}
   * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is
   *     <= 0, if {@code ticksPerWheel} is greater than 2^30, or if {@code tickDuration} is so long
   *     that the duration of a full wheel round would overflow
   */
  public HashedWheelTimer(
      ThreadFactory threadFactory, long tickDuration, TimeUnit unit, int ticksPerWheel) {

    if (threadFactory == null) {
      throw new NullPointerException("threadFactory");
    }
    if (unit == null) {
      throw new NullPointerException("unit");
    }
    if (tickDuration <= 0) {
      throw new IllegalArgumentException("tickDuration must be greater than 0: " + tickDuration);
    }
    if (ticksPerWheel <= 0) {
      throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel);
    }

    // Normalize ticksPerWheel to power of two and initialize the wheel.
    wheel = createWheel(ticksPerWheel);
    mask = wheel.length - 1;

    // Convert tickDuration to milliseconds. A positive duration below 1 ms would truncate to 0,
    // which makes roundDuration 0 and turns every scheduleTimeout() call into a division by zero;
    // round it up to the smallest supported tick instead.
    long tickMillis = unit.toMillis(tickDuration);
    if (tickMillis == 0) {
      logger.warn("Configured tickDuration {} {} is smaller than 1 ms; using 1 ms.", tickDuration, unit);
      tickMillis = 1;
    }

    // Prevent overflow of roundDuration. Report the value in the unit the caller used.
    if (tickMillis == Long.MAX_VALUE || tickMillis >= Long.MAX_VALUE / wheel.length) {
      throw new IllegalArgumentException("tickDuration is too long: " + tickDuration + ' ' + unit);
    }

    this.tickDuration = tickMillis;
    roundDuration = tickMillis * wheel.length;

    workerThread = threadFactory.newThread(worker);

    leak = leakDetector.open(this);
  }

  @SuppressWarnings("unchecked")
  private static Set<HashedWheelTimeout>[] createWheel(int ticksPerWheel) {
    if (ticksPerWheel <= 0) {
      throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel);
    }
    if (ticksPerWheel > 1073741824) {
      throw new IllegalArgumentException(
          "ticksPerWheel may not be greater than 2^30: " + ticksPerWheel);
    }

    ticksPerWheel = normalizeTicksPerWheel(ticksPerWheel);
    // Generic array creation is not allowed, hence the unchecked raw Set[].
    Set<HashedWheelTimeout>[] wheel = new Set[ticksPerWheel];
    for (int i = 0; i < wheel.length; i++) {
      wheel[i] =
          Collections.newSetFromMap(
              PlatformDependent.<HashedWheelTimeout, Boolean>newConcurrentHashMap());
    }
    return wheel;
  }

  /** Returns the smallest power of two that is >= {@code ticksPerWheel}. */
  private static int normalizeTicksPerWheel(int ticksPerWheel) {
    int normalizedTicksPerWheel = 1;
    while (normalizedTicksPerWheel < ticksPerWheel) {
      normalizedTicksPerWheel <<= 1;
    }
    return normalizedTicksPerWheel;
  }

  /**
   * Starts the background thread explicitly. The background thread will start automatically on
   * demand even if you did not call this method.
   *
   * @throws IllegalStateException if this timer has been {@linkplain #stop() stopped} already
   */
  public void start() {
    switch (workerState.get()) {
      case 0:
        if (workerState.compareAndSet(0, 1)) {
          workerThread.start();
        }
        break;
      case 1:
        break;
      case 2:
        throw new IllegalStateException("cannot be started once stopped");
      default:
        throw new Error("Invalid worker state: " + workerState.get());
    }
  }

  /**
   * Stops the worker thread (waiting for it to terminate) and returns the timeouts that were
   * still scheduled. Calling this on a timer that was never started, or that is already stopped,
   * returns an empty set.
   *
   * @throws IllegalStateException if called from a {@link TimerTask} running on this timer
   */
  @Override
  public Set<Timeout> stop() {
    if (Thread.currentThread() == workerThread) {
      throw new IllegalStateException(
          HashedWheelTimer.class.getSimpleName()
              + ".stop() cannot be called from "
              + TimerTask.class.getSimpleName());
    }

    if (!workerState.compareAndSet(1, 2)) {
      // workerState can be 0 or 2 at this moment - let it always be 2. When this call is the one
      // that moves a never-started timer to 2, the leak record must still be closed; otherwise a
      // correctly stopped timer would later be reported as leaked.
      if (workerState.getAndSet(2) != 2 && leak != null) {
        leak.close();
      }
      return Collections.emptySet();
    }

    boolean interrupted = false;
    while (workerThread.isAlive()) {
      workerThread.interrupt();
      try {
        workerThread.join(100);
      } catch (InterruptedException e) {
        interrupted = true;
      }
    }

    // Restore the caller's interrupt status that was swallowed while joining.
    if (interrupted) {
      Thread.currentThread().interrupt();
    }

    if (leak != null) {
      leak.close();
    }

    Set<Timeout> unprocessedTimeouts = new HashSet<Timeout>();
    for (Set<HashedWheelTimeout> bucket : wheel) {
      unprocessedTimeouts.addAll(bucket);
      bucket.clear();
    }

    return Collections.unmodifiableSet(unprocessedTimeouts);
  }

  /**
   * Schedules {@code task} to run once after {@code delay}, starting the worker thread if needed.
   *
   * @throws NullPointerException if {@code task} or {@code unit} is {@code null}
   * @throws IllegalStateException if this timer has been stopped
   */
  @Override
  public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit) {
    final long currentTime = System.currentTimeMillis();

    if (task == null) {
      throw new NullPointerException("task");
    }
    if (unit == null) {
      throw new NullPointerException("unit");
    }

    start();

    delay = unit.toMillis(delay);
    HashedWheelTimeout timeout = new HashedWheelTimeout(task, currentTime + delay);
    scheduleTimeout(timeout, delay);
    return timeout;
  }

  /** Places {@code timeout} into the wheel bucket that corresponds to {@code delay} milliseconds. */
  void scheduleTimeout(HashedWheelTimeout timeout, long delay) {
    // delay must be equal to or greater than tickDuration so that the
    // worker thread never misses the timeout.
    if (delay < tickDuration) {
      delay = tickDuration;
    }

    // Prepare the required parameters to schedule the timeout object.
    final long lastRoundDelay = delay % roundDuration;
    final long lastTickDelay = delay % tickDuration;
    final long relativeIndex = lastRoundDelay / tickDuration + (lastTickDelay != 0 ? 1 : 0);

    final long remainingRounds = delay / roundDuration - (delay % roundDuration == 0 ? 1 : 0);

    // Add the timeout to the wheel. The read lock only excludes the worker's cursor advance
    // (which holds the write lock); concurrent schedulers may proceed in parallel.
    lock.readLock().lock();
    try {
      int stopIndex = (int) (wheelCursor + relativeIndex & mask);
      timeout.stopIndex = stopIndex;
      timeout.remainingRounds = remainingRounds;

      wheel[stopIndex].add(timeout);
    } finally {
      lock.readLock().unlock();
    }
  }

  private final class Worker implements Runnable {

    private long startTime;
    private long tick;

    Worker() {}

    @Override
    public void run() {
      List<HashedWheelTimeout> expiredTimeouts = new ArrayList<HashedWheelTimeout>();

      startTime = System.currentTimeMillis();
      tick = 1;

      while (workerState.get() == 1) {
        final long deadline = waitForNextTick();
        if (deadline > 0) {
          fetchExpiredTimeouts(expiredTimeouts, deadline);
          notifyExpiredTimeouts(expiredTimeouts);
        }
      }
    }

    private void fetchExpiredTimeouts(List<HashedWheelTimeout> expiredTimeouts, long deadline) {

      // Find the expired timeouts and decrease the round counter
      // if necessary.  Note that we don't send the notification
      // immediately to make sure the listeners are called without
      // an exclusive lock.
      lock.writeLock().lock();
      try {
        int newWheelCursor = wheelCursor = wheelCursor + 1 & mask;
        fetchExpiredTimeouts(expiredTimeouts, wheel[newWheelCursor].iterator(), deadline);
      } finally {
        lock.writeLock().unlock();
      }
    }

    private void fetchExpiredTimeouts(
        List<HashedWheelTimeout> expiredTimeouts, Iterator<HashedWheelTimeout> i, long deadline) {

      List<HashedWheelTimeout> slipped = null;
      while (i.hasNext()) {
        HashedWheelTimeout timeout = i.next();
        if (timeout.remainingRounds <= 0) {
          i.remove();
          if (timeout.deadline <= deadline) {
            expiredTimeouts.add(timeout);
          } else {
            // Handle the case where the timeout is put into a wrong
            // place, usually one tick earlier.  For now, just add
            // it to a temporary list - we will reschedule it in a
            // separate loop.
            if (slipped == null) {
              slipped = new ArrayList<HashedWheelTimeout>();
            }
            slipped.add(timeout);
          }
        } else {
          timeout.remainingRounds--;
        }
      }

      // Reschedule the slipped timeouts. The caller holds the write lock; the read lock taken by
      // scheduleTimeout() can be acquired by the write-lock holder.
      if (slipped != null) {
        for (HashedWheelTimeout timeout : slipped) {
          scheduleTimeout(timeout, timeout.deadline - deadline);
        }
      }
    }

    private void notifyExpiredTimeouts(List<HashedWheelTimeout> expiredTimeouts) {
      // Notify the expired timeouts.
      for (int i = expiredTimeouts.size() - 1; i >= 0; i--) {
        expiredTimeouts.get(i).expire();
      }

      // Clean up the temporary list.
      expiredTimeouts.clear();
    }

    /**
     * Sleeps until the current tick's deadline and advances the tick counter.
     *
     * @return the deadline of the tick that was reached, or -1 if the timer was stopped while
     *     sleeping
     */
    private long waitForNextTick() {
      long deadline = startTime + tickDuration * tick;

      for (; ; ) {
        final long currentTime = System.currentTimeMillis();
        long sleepTime = tickDuration * tick - (currentTime - startTime);

        // Check if we run on windows, as if thats the case we will need
        // to round the sleepTime as workaround for a bug that only affect
        // the JVM if it runs on windows.
        //
        // See https://github.com/netty/netty/issues/356
        if (PlatformDependent.isWindows()) {
          sleepTime = sleepTime / 10 * 10;
        }

        if (sleepTime <= 0) {
          break;
        }

        try {
          Thread.sleep(sleepTime);
        } catch (InterruptedException e) {
          if (workerState.get() != 1) {
            return -1;
          }
        }
      }

      // Increase the tick.
      tick++;
      return deadline;
    }
  }

  private final class HashedWheelTimeout implements Timeout {

    private static final int ST_INIT = 0;
    private static final int ST_CANCELLED = 1;
    private static final int ST_EXPIRED = 2;

    private final TimerTask task;
    final long deadline;
    volatile int stopIndex;
    volatile long remainingRounds;
    private final AtomicInteger state = new AtomicInteger(ST_INIT);

    HashedWheelTimeout(TimerTask task, long deadline) {
      this.task = task;
      this.deadline = deadline;
    }

    @Override
    public Timer timer() {
      return HashedWheelTimer.this;
    }

    @Override
    public TimerTask task() {
      return task;
    }

    @Override
    public boolean cancel() {
      if (!state.compareAndSet(ST_INIT, ST_CANCELLED)) {
        return false;
      }

      wheel[stopIndex].remove(this);
      return true;
    }

    @Override
    public boolean isCancelled() {
      return state.get() == ST_CANCELLED;
    }

    /** Returns {@code true} once this timeout has either expired or been cancelled. */
    @Override
    public boolean isExpired() {
      return state.get() != ST_INIT;
    }

    /** Runs the task exactly once, unless it was cancelled; exceptions are logged, not thrown. */
    public void expire() {
      if (!state.compareAndSet(ST_INIT, ST_EXPIRED)) {
        return;
      }

      try {
        task.run(this);
      } catch (Throwable t) {
        if (logger.isWarnEnabled()) {
          logger.warn("An exception was thrown by " + TimerTask.class.getSimpleName() + '.', t);
        }
      }
    }

    /**
     * Returns e.g. {@code "HashedWheelTimeout(deadline: 5 ms later)"} or {@code
     * "HashedWheelTimeout(deadline: 3 ms ago, cancelled)"}.
     */
    @Override
    public String toString() {
      long currentTime = System.currentTimeMillis();
      long remaining = deadline - currentTime;

      StringBuilder buf = new StringBuilder(192);
      buf.append(getClass().getSimpleName());
      buf.append('(');

      buf.append("deadline: ");
      if (remaining > 0) {
        buf.append(remaining);
        buf.append(" ms later");
      } else if (remaining < 0) {
        buf.append(-remaining);
        buf.append(" ms ago");
      } else {
        buf.append("now");
      }

      if (isCancelled()) {
        buf.append(", cancelled");
      }

      return buf.append(')').toString();
    }
  }
}
Пример #13
0
/**
 * A {@link ChannelHandler} that aggregates an {@link HttpMessage} and its following {@link
 * HttpContent}s into a single {@link FullHttpRequest} or {@link FullHttpResponse} (depending on
 * whether it is used to handle requests or responses) with no following {@link HttpContent}s. It is useful
 * when you don't want to take care of HTTP messages whose transfer encoding is 'chunked'. Insert
 * this handler after {@link HttpObjectDecoder} in the {@link ChannelPipeline}:
 *
 * <pre>
 * {@link ChannelPipeline} p = ...;
 * ...
 * p.addLast("encoder", new {@link HttpResponseEncoder}());
 * p.addLast("decoder", new {@link HttpRequestDecoder}());
 * p.addLast("aggregator", <b>new {@link HttpObjectAggregator}(1048576)</b>);
 * ...
 * p.addLast("handler", new HttpRequestHandler());
 * </pre>
 *
 * Be aware that you need to have the {@link HttpResponseEncoder} or {@link HttpRequestEncoder}
 * before the {@link HttpObjectAggregator} in the {@link ChannelPipeline}.
 */
public class HttpObjectAggregator
    extends MessageAggregator<HttpObject, HttpMessage, HttpContent, FullHttpMessage> {
  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(HttpObjectAggregator.class);
  private static final FullHttpResponse CONTINUE =
      new DefaultFullHttpResponse(
          HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER);
  private static final FullHttpResponse EXPECTATION_FAILED =
      new DefaultFullHttpResponse(
          HttpVersion.HTTP_1_1, HttpResponseStatus.EXPECTATION_FAILED, Unpooled.EMPTY_BUFFER);
  private static final FullHttpResponse TOO_LARGE =
      new DefaultFullHttpResponse(
          HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER);

  static {
    EXPECTATION_FAILED.headers().set(CONTENT_LENGTH, "0");
    TOO_LARGE.headers().set(CONTENT_LENGTH, "0");
  }

  private final boolean closeOnExpectationFailed;

  /**
   * Creates a new instance.
   *
   * @param maxContentLength the maximum length of the aggregated content in bytes. If the length of
   *     the aggregated content exceeds this value, {@link
   *     #handleOversizedMessage(ChannelHandlerContext, HttpMessage)} will be called.
   */
  public HttpObjectAggregator(int maxContentLength) {
    this(maxContentLength, false);
  }

  /**
   * Creates a new instance.
   *
   * @param maxContentLength the maximum length of the aggregated content in bytes. If the length of
   *     the aggregated content exceeds this value, {@link
   *     #handleOversizedMessage(ChannelHandlerContext, HttpMessage)} will be called.
   * @param closeOnExpectationFailed If a 100-continue response is detected but the content length
   *     is too large then {@code true} means close the connection. otherwise the connection will
   *     remain open and data will be consumed and discarded until the next request is received.
   */
  public HttpObjectAggregator(int maxContentLength, boolean closeOnExpectationFailed) {
    super(maxContentLength);
    this.closeOnExpectationFailed = closeOnExpectationFailed;
  }

  @Override
  protected boolean isStartMessage(HttpObject msg) throws Exception {
    return msg instanceof HttpMessage;
  }

  @Override
  protected boolean isContentMessage(HttpObject msg) throws Exception {
    return msg instanceof HttpContent;
  }

  @Override
  protected boolean isLastContentMessage(HttpContent msg) throws Exception {
    return msg instanceof LastHttpContent;
  }

  @Override
  protected boolean isAggregated(HttpObject msg) throws Exception {
    return msg instanceof FullHttpMessage;
  }

  @Override
  protected boolean isContentLengthInvalid(HttpMessage start, int maxContentLength) {
    return getContentLength(start, -1) > maxContentLength;
  }

  @Override
  protected Object newContinueResponse(
      HttpMessage start, int maxContentLength, ChannelPipeline pipeline) {
    if (HttpUtil.is100ContinueExpected(start)) {
      if (getContentLength(start, -1) <= maxContentLength) {
        return CONTINUE.duplicate().retain();
      }

      pipeline.fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE);
      return EXPECTATION_FAILED.duplicate().retain();
    }
    return null;
  }

  @Override
  protected boolean closeAfterContinueResponse(Object msg) {
    return closeOnExpectationFailed && ignoreContentAfterContinueResponse(msg);
  }

  @Override
  protected boolean ignoreContentAfterContinueResponse(Object msg) {
    return msg instanceof HttpResponse
        && ((HttpResponse) msg).status().code() == HttpResponseStatus.EXPECTATION_FAILED.code();
  }

  @Override
  protected FullHttpMessage beginAggregation(HttpMessage start, ByteBuf content) throws Exception {
    assert !(start instanceof FullHttpMessage);

    HttpUtil.setTransferEncodingChunked(start, false);

    AggregatedFullHttpMessage ret;
    if (start instanceof HttpRequest) {
      ret = new AggregatedFullHttpRequest((HttpRequest) start, content, null);
    } else if (start instanceof HttpResponse) {
      ret = new AggregatedFullHttpResponse((HttpResponse) start, content, null);
    } else {
      throw new Error();
    }
    return ret;
  }

  @Override
  protected void aggregate(FullHttpMessage aggregated, HttpContent content) throws Exception {
    if (content instanceof LastHttpContent) {
      // Merge trailing headers into the message.
      ((AggregatedFullHttpMessage) aggregated)
          .setTrailingHeaders(((LastHttpContent) content).trailingHeaders());
    }
  }

  @Override
  protected void finishAggregation(FullHttpMessage aggregated) throws Exception {
    // Set the 'Content-Length' header. If one isn't already set.
    // This is important as HEAD responses will use a 'Content-Length' header which
    // does not match the actual body, but the number of bytes that would be
    // transmitted if a GET would have been used.
    //
    // See rfc2616 14.13 Content-Length
    if (!HttpUtil.isContentLengthSet(aggregated)) {
      aggregated
          .headers()
          .set(
              HttpHeaderNames.CONTENT_LENGTH, String.valueOf(aggregated.content().readableBytes()));
    }
  }

  @Override
  protected void handleOversizedMessage(final ChannelHandlerContext ctx, HttpMessage oversized)
      throws Exception {
    if (oversized instanceof HttpRequest) {
      // send back a 413 and close the connection
      ChannelFuture future =
          ctx.writeAndFlush(TOO_LARGE.duplicate().retain())
              .addListener(
                  new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                      if (!future.isSuccess()) {
                        logger.debug(
                            "Failed to send a 413 Request Entity Too Large.", future.cause());
                        ctx.close();
                      }
                    }
                  });

      // If the client started to send data already, close because it's impossible to recover.
      // If keep-alive is off and 'Expect: 100-continue' is missing, no need to leave the connection
      // open.
      if (oversized instanceof FullHttpMessage
          || !HttpUtil.is100ContinueExpected(oversized) && !HttpUtil.isKeepAlive(oversized)) {
        future.addListener(ChannelFutureListener.CLOSE);
      }

      // If an oversized request was handled properly and the connection is still alive
      // (i.e. rejected 100-continue). the decoder should prepare to handle a new message.
      HttpObjectDecoder decoder = ctx.pipeline().get(HttpObjectDecoder.class);
      if (decoder != null) {
        decoder.reset();
      }
    } else if (oversized instanceof HttpResponse) {
      ctx.close();
      throw new TooLongFrameException("Response entity too large: " + oversized);
    } else {
      throw new IllegalStateException();
    }
  }

  private abstract static class AggregatedFullHttpMessage
      implements ByteBufHolder, FullHttpMessage {
    protected final HttpMessage message;
    private final ByteBuf content;
    private HttpHeaders trailingHeaders;

    AggregatedFullHttpMessage(HttpMessage message, ByteBuf content, HttpHeaders trailingHeaders) {
      this.message = message;
      this.content = content;
      this.trailingHeaders = trailingHeaders;
    }

    @Override
    public HttpHeaders trailingHeaders() {
      HttpHeaders trailingHeaders = this.trailingHeaders;
      if (trailingHeaders == null) {
        return EmptyHttpHeaders.INSTANCE;
      } else {
        return trailingHeaders;
      }
    }

    void setTrailingHeaders(HttpHeaders trailingHeaders) {
      this.trailingHeaders = trailingHeaders;
    }

    @Override
    public HttpVersion protocolVersion() {
      return message.protocolVersion();
    }

    @Override
    public FullHttpMessage setProtocolVersion(HttpVersion version) {
      message.setProtocolVersion(version);
      return this;
    }

    @Override
    public HttpHeaders headers() {
      return message.headers();
    }

    @Override
    public DecoderResult decoderResult() {
      return message.decoderResult();
    }

    @Override
    public void setDecoderResult(DecoderResult result) {
      message.setDecoderResult(result);
    }

    @Override
    public ByteBuf content() {
      return content;
    }

    @Override
    public int refCnt() {
      return content.refCnt();
    }

    @Override
    public FullHttpMessage retain() {
      content.retain();
      return this;
    }

    @Override
    public FullHttpMessage retain(int increment) {
      content.retain(increment);
      return this;
    }

    @Override
    public FullHttpMessage touch(Object hint) {
      content.touch(hint);
      return this;
    }

    @Override
    public FullHttpMessage touch() {
      content.touch();
      return this;
    }

    @Override
    public boolean release() {
      return content.release();
    }

    @Override
    public boolean release(int decrement) {
      return content.release(decrement);
    }

    @Override
    public abstract FullHttpMessage copy();

    @Override
    public abstract FullHttpMessage duplicate();
  }

  private static final class AggregatedFullHttpRequest extends AggregatedFullHttpMessage
      implements FullHttpRequest {

    AggregatedFullHttpRequest(HttpRequest request, ByteBuf content, HttpHeaders trailingHeaders) {
      super(request, content, trailingHeaders);
    }

    /**
     * Copy this object
     *
     * @param copyContent
     *     <ul>
     *       <li>{@code true} if this object's {@link #content()} should be used to copy.
     *       <li>{@code false} if {@code newContent} should be used instead.
     *     </ul>
     *
     * @param newContent
     *     <ul>
     *       <li>if {@code copyContent} is false then this will be used in the copy's content.
     *       <li>if {@code null} then a default buffer of 0 size will be selected
     *     </ul>
     *
     * @return A copy of this object
     */
    private FullHttpRequest copy(boolean copyContent, ByteBuf newContent) {
      DefaultFullHttpRequest copy =
          new DefaultFullHttpRequest(
              protocolVersion(),
              method(),
              uri(),
              copyContent
                  ? content().copy()
                  : newContent == null ? Unpooled.buffer(0) : newContent);
      copy.headers().set(headers());
      copy.trailingHeaders().set(trailingHeaders());
      return copy;
    }

    @Override
    public FullHttpRequest copy(ByteBuf newContent) {
      return copy(false, newContent);
    }

    @Override
    public FullHttpRequest copy() {
      return copy(true, null);
    }

    @Override
    public FullHttpRequest duplicate() {
      DefaultFullHttpRequest duplicate =
          new DefaultFullHttpRequest(protocolVersion(), method(), uri(), content().duplicate());
      duplicate.headers().set(headers());
      duplicate.trailingHeaders().set(trailingHeaders());
      return duplicate;
    }

    @Override
    public FullHttpRequest retain(int increment) {
      super.retain(increment);
      return this;
    }

    @Override
    public FullHttpRequest retain() {
      super.retain();
      return this;
    }

    @Override
    public FullHttpRequest touch() {
      super.touch();
      return this;
    }

    @Override
    public FullHttpRequest touch(Object hint) {
      super.touch(hint);
      return this;
    }

    @Override
    public FullHttpRequest setMethod(HttpMethod method) {
      ((HttpRequest) message).setMethod(method);
      return this;
    }

    @Override
    public FullHttpRequest setUri(String uri) {
      ((HttpRequest) message).setUri(uri);
      return this;
    }

    @Override
    public HttpMethod method() {
      return ((HttpRequest) message).method();
    }

    @Override
    public String uri() {
      return ((HttpRequest) message).uri();
    }

    @Override
    public FullHttpRequest setProtocolVersion(HttpVersion version) {
      super.setProtocolVersion(version);
      return this;
    }

    @Override
    public String toString() {
      return HttpMessageUtil.appendFullRequest(new StringBuilder(256), this).toString();
    }
  }

  private static final class AggregatedFullHttpResponse extends AggregatedFullHttpMessage
      implements FullHttpResponse {

    AggregatedFullHttpResponse(HttpResponse message, ByteBuf content, HttpHeaders trailingHeaders) {
      super(message, content, trailingHeaders);
    }

    /**
     * Copy this object
     *
     * @param copyContent
     *     <ul>
     *       <li>{@code true} if this object's {@link #content()} should be used to copy.
     *       <li>{@code false} if {@code newContent} should be used instead.
     *     </ul>
     *
     * @param newContent
     *     <ul>
     *       <li>if {@code copyContent} is false then this will be used in the copy's content.
     *       <li>if {@code null} then a default buffer of 0 size will be selected
     *     </ul>
     *
     * @return A copy of this object
     */
    private FullHttpResponse copy(boolean copyContent, ByteBuf newContent) {
      DefaultFullHttpResponse copy =
          new DefaultFullHttpResponse(
              protocolVersion(),
              status(),
              copyContent
                  ? content().copy()
                  : newContent == null ? Unpooled.buffer(0) : newContent);
      copy.headers().set(headers());
      copy.trailingHeaders().set(trailingHeaders());
      return copy;
    }

    @Override
    public FullHttpResponse copy(ByteBuf newContent) {
      return copy(false, newContent);
    }

    @Override
    public FullHttpResponse copy() {
      return copy(true, null);
    }

    @Override
    public FullHttpResponse duplicate() {
      DefaultFullHttpResponse duplicate =
          new DefaultFullHttpResponse(protocolVersion(), status(), content().duplicate());
      duplicate.headers().set(headers());
      duplicate.trailingHeaders().set(trailingHeaders());
      return duplicate;
    }

    @Override
    public FullHttpResponse setStatus(HttpResponseStatus status) {
      ((HttpResponse) message).setStatus(status);
      return this;
    }

    @Override
    public HttpResponseStatus status() {
      return ((HttpResponse) message).status();
    }

    @Override
    public FullHttpResponse setProtocolVersion(HttpVersion version) {
      super.setProtocolVersion(version);
      return this;
    }

    @Override
    public FullHttpResponse retain(int increment) {
      super.retain(increment);
      return this;
    }

    @Override
    public FullHttpResponse retain() {
      super.retain();
      return this;
    }

    @Override
    public FullHttpResponse touch(Object hint) {
      super.touch(hint);
      return this;
    }

    @Override
    public FullHttpResponse touch() {
      super.touch();
      return this;
    }

    @Override
    public String toString() {
      return HttpMessageUtil.appendFullResponse(new StringBuilder(256), this).toString();
    }
  }
}
Пример #14
0
public class SpdySessionHandlerTest {

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(SpdySessionHandlerTest.class);

  private static final int closeSignal = SpdyCodecUtil.SPDY_SETTINGS_MAX_ID;
  private static final SpdySettingsFrame closeMessage = new DefaultSpdySettingsFrame();

  static {
    closeMessage.setValue(closeSignal, 0);
  }

  // All helpers below call JUnit's assertEquals as (expected, actual) so that failure messages
  // report the expected and observed values in the right places.

  private static void assertDataFrame(Object msg, int streamId, boolean last) {
    assertNotNull(msg);
    assertTrue(msg instanceof SpdyDataFrame);
    SpdyDataFrame spdyDataFrame = (SpdyDataFrame) msg;
    assertEquals(streamId, spdyDataFrame.getStreamId());
    assertEquals(last, spdyDataFrame.isLast());
  }

  private static void assertSynReply(Object msg, int streamId, boolean last, SpdyHeaders headers) {
    assertNotNull(msg);
    assertTrue(msg instanceof SpdySynReplyFrame);
    assertHeaders(msg, streamId, last, headers);
  }

  private static void assertRstStream(Object msg, int streamId, SpdyStreamStatus status) {
    assertNotNull(msg);
    assertTrue(msg instanceof SpdyRstStreamFrame);
    SpdyRstStreamFrame spdyRstStreamFrame = (SpdyRstStreamFrame) msg;
    assertEquals(streamId, spdyRstStreamFrame.getStreamId());
    assertEquals(status, spdyRstStreamFrame.getStatus());
  }

  private static void assertPing(Object msg, int id) {
    assertNotNull(msg);
    assertTrue(msg instanceof SpdyPingFrame);
    SpdyPingFrame spdyPingFrame = (SpdyPingFrame) msg;
    assertEquals(id, spdyPingFrame.getId());
  }

  private static void assertGoAway(Object msg, int lastGoodStreamId) {
    assertNotNull(msg);
    assertTrue(msg instanceof SpdyGoAwayFrame);
    SpdyGoAwayFrame spdyGoAwayFrame = (SpdyGoAwayFrame) msg;
    assertEquals(lastGoodStreamId, spdyGoAwayFrame.getLastGoodStreamId());
  }

  /**
   * Asserts that {@code msg} is a headers frame for {@code streamId} whose headers are exactly
   * {@code headers}. Matched names are removed from the received frame's headers as they are
   * checked; the final emptiness check ensures no unexpected header was received.
   */
  private static void assertHeaders(Object msg, int streamId, boolean last, SpdyHeaders headers) {
    assertNotNull(msg);
    assertTrue(msg instanceof SpdyHeadersFrame);
    SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg;
    assertEquals(streamId, spdyHeadersFrame.getStreamId());
    assertEquals(last, spdyHeadersFrame.isLast());
    for (String name : headers.names()) {
      List<String> expectedValues = headers.getAll(name);
      List<String> receivedValues = spdyHeadersFrame.headers().getAll(name);
      assertTrue(receivedValues.containsAll(expectedValues));
      receivedValues.removeAll(expectedValues);
      assertTrue(receivedValues.isEmpty());
      spdyHeadersFrame.headers().remove(name);
    }
    assertTrue(spdyHeadersFrame.headers().isEmpty());
  }

  /**
   * Creates an {@link EmbeddedChannel} with a {@link SpdySessionHandler} followed by an {@link
   * EchoHandler}, then drains any outbound frames already written (e.g. those EchoHandler writes
   * from {@code channelActive}) so each test starts from an empty outbound queue.
   */
  private static EmbeddedChannel newSessionHandler(SpdyVersion version, boolean server) {
    EmbeddedChannel sessionHandler =
        new EmbeddedChannel(
            new SpdySessionHandler(version, server), new EchoHandler(closeSignal, server));

    while (sessionHandler.readOutbound() != null) {
      continue;
    }
    return sessionHandler;
  }

  private static void testSpdySessionHandler(SpdyVersion version, boolean server) {
    EmbeddedChannel sessionHandler = newSessionHandler(version, server);

    int localStreamId = server ? 1 : 2;
    int remoteStreamId = server ? 2 : 1;

    SpdySynStreamFrame spdySynStreamFrame =
        new DefaultSpdySynStreamFrame(localStreamId, 0, (byte) 0);
    spdySynStreamFrame.headers().set("Compression", "test");

    SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(localStreamId);
    spdyDataFrame.setLast(true);

    // Check if session handler returns INVALID_STREAM if it receives
    // a data frame for a Stream-ID that is not open
    sessionHandler.writeInbound(new DefaultSpdyDataFrame(localStreamId));
    assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.INVALID_STREAM);
    assertNull(sessionHandler.readOutbound());

    // Check if session handler returns PROTOCOL_ERROR if it receives
    // a data frame for a Stream-ID before receiving a SYN_REPLY frame
    sessionHandler.writeInbound(new DefaultSpdyDataFrame(remoteStreamId));
    assertRstStream(sessionHandler.readOutbound(), remoteStreamId, SpdyStreamStatus.PROTOCOL_ERROR);
    assertNull(sessionHandler.readOutbound());
    remoteStreamId += 2;

    // Check if session handler returns PROTOCOL_ERROR if it receives
    // multiple SYN_REPLY frames for the same active Stream-ID
    sessionHandler.writeInbound(new DefaultSpdySynReplyFrame(remoteStreamId));
    assertNull(sessionHandler.readOutbound());
    sessionHandler.writeInbound(new DefaultSpdySynReplyFrame(remoteStreamId));
    assertRstStream(sessionHandler.readOutbound(), remoteStreamId, SpdyStreamStatus.STREAM_IN_USE);
    assertNull(sessionHandler.readOutbound());
    remoteStreamId += 2;

    // Check if frame codec correctly compresses/decompresses headers
    sessionHandler.writeInbound(spdySynStreamFrame);
    assertSynReply(
        sessionHandler.readOutbound(), localStreamId, false, spdySynStreamFrame.headers());
    assertNull(sessionHandler.readOutbound());
    SpdyHeadersFrame spdyHeadersFrame = new DefaultSpdyHeadersFrame(localStreamId);

    spdyHeadersFrame.headers().add("HEADER", "test1");
    spdyHeadersFrame.headers().add("HEADER", "test2");

    sessionHandler.writeInbound(spdyHeadersFrame);
    assertHeaders(sessionHandler.readOutbound(), localStreamId, false, spdyHeadersFrame.headers());
    assertNull(sessionHandler.readOutbound());
    localStreamId += 2;

    // Check if session handler closed the streams using the number
    // of concurrent streams and that it returns REFUSED_STREAM
    // if it receives a SYN_STREAM frame it does not wish to accept
    spdySynStreamFrame.setStreamId(localStreamId);
    spdySynStreamFrame.setLast(true);
    spdySynStreamFrame.setUnidirectional(true);

    sessionHandler.writeInbound(spdySynStreamFrame);
    assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.REFUSED_STREAM);
    assertNull(sessionHandler.readOutbound());

    // Check if session handler rejects HEADERS for closed streams
    int testStreamId = spdyDataFrame.getStreamId();
    sessionHandler.writeInbound(spdyDataFrame);
    assertDataFrame(sessionHandler.readOutbound(), testStreamId, spdyDataFrame.isLast());
    assertNull(sessionHandler.readOutbound());
    spdyHeadersFrame.setStreamId(testStreamId);

    sessionHandler.writeInbound(spdyHeadersFrame);
    assertRstStream(sessionHandler.readOutbound(), testStreamId, SpdyStreamStatus.INVALID_STREAM);
    assertNull(sessionHandler.readOutbound());

    // Check if session handler drops active streams if it receives
    // a RST_STREAM frame for that Stream-ID
    sessionHandler.writeInbound(new DefaultSpdyRstStreamFrame(remoteStreamId, 3));
    assertNull(sessionHandler.readOutbound());
    remoteStreamId += 2;

    // Check if session handler honors UNIDIRECTIONAL streams
    spdySynStreamFrame.setLast(false);
    sessionHandler.writeInbound(spdySynStreamFrame);
    assertNull(sessionHandler.readOutbound());
    spdySynStreamFrame.setUnidirectional(false);

    // Check if session handler returns PROTOCOL_ERROR if it receives
    // multiple SYN_STREAM frames for the same active Stream-ID
    sessionHandler.writeInbound(spdySynStreamFrame);
    assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.PROTOCOL_ERROR);
    assertNull(sessionHandler.readOutbound());
    localStreamId += 2;

    // Check if session handler returns PROTOCOL_ERROR if it receives
    // a SYN_STREAM frame with an invalid Stream-ID
    spdySynStreamFrame.setStreamId(localStreamId - 1);
    sessionHandler.writeInbound(spdySynStreamFrame);
    assertRstStream(
        sessionHandler.readOutbound(), localStreamId - 1, SpdyStreamStatus.PROTOCOL_ERROR);
    assertNull(sessionHandler.readOutbound());
    spdySynStreamFrame.setStreamId(localStreamId);

    // Check if session handler returns PROTOCOL_ERROR if it receives
    // an invalid HEADERS frame
    spdyHeadersFrame.setStreamId(localStreamId);

    spdyHeadersFrame.setInvalid();
    sessionHandler.writeInbound(spdyHeadersFrame);
    assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.PROTOCOL_ERROR);
    assertNull(sessionHandler.readOutbound());

    sessionHandler.finish();
  }

  private static void testSpdySessionHandlerPing(SpdyVersion version, boolean server) {
    EmbeddedChannel sessionHandler = newSessionHandler(version, server);

    int localStreamId = server ? 1 : 2;
    int remoteStreamId = server ? 2 : 1;

    SpdyPingFrame localPingFrame = new DefaultSpdyPingFrame(localStreamId);
    SpdyPingFrame remotePingFrame = new DefaultSpdyPingFrame(remoteStreamId);

    // Check if session handler returns identical local PINGs
    sessionHandler.writeInbound(localPingFrame);
    assertPing(sessionHandler.readOutbound(), localPingFrame.getId());
    assertNull(sessionHandler.readOutbound());

    // Check if session handler ignores un-initiated remote PINGs
    sessionHandler.writeInbound(remotePingFrame);
    assertNull(sessionHandler.readOutbound());

    sessionHandler.finish();
  }

  private static void testSpdySessionHandlerGoAway(SpdyVersion version, boolean server) {
    EmbeddedChannel sessionHandler = newSessionHandler(version, server);

    int localStreamId = server ? 1 : 2;

    SpdySynStreamFrame spdySynStreamFrame =
        new DefaultSpdySynStreamFrame(localStreamId, 0, (byte) 0);
    spdySynStreamFrame.headers().set("Compression", "test");

    SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(localStreamId);
    spdyDataFrame.setLast(true);

    // Send an initial request
    sessionHandler.writeInbound(spdySynStreamFrame);
    assertSynReply(
        sessionHandler.readOutbound(), localStreamId, false, spdySynStreamFrame.headers());
    assertNull(sessionHandler.readOutbound());
    sessionHandler.writeInbound(spdyDataFrame);
    assertDataFrame(sessionHandler.readOutbound(), localStreamId, true);
    assertNull(sessionHandler.readOutbound());

    // Check if session handler sends a GOAWAY frame when closing
    sessionHandler.writeInbound(closeMessage);
    assertGoAway(sessionHandler.readOutbound(), localStreamId);
    assertNull(sessionHandler.readOutbound());
    localStreamId += 2;

    // Check if session handler returns REFUSED_STREAM if it receives
    // SYN_STREAM frames after sending a GOAWAY frame
    spdySynStreamFrame.setStreamId(localStreamId);
    sessionHandler.writeInbound(spdySynStreamFrame);
    assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.REFUSED_STREAM);
    assertNull(sessionHandler.readOutbound());

    // Check if session handler ignores Data frames after sending
    // a GOAWAY frame
    spdyDataFrame.setStreamId(localStreamId);
    sessionHandler.writeInbound(spdyDataFrame);
    assertNull(sessionHandler.readOutbound());

    sessionHandler.finish();
  }

  @Test
  public void testSpdyClientSessionHandler() {
    logger.info("Running: testSpdyClientSessionHandler v3.1");
    testSpdySessionHandler(SpdyVersion.SPDY_3_1, false);
  }

  @Test
  public void testSpdyClientSessionHandlerPing() {
    logger.info("Running: testSpdyClientSessionHandlerPing v3.1");
    testSpdySessionHandlerPing(SpdyVersion.SPDY_3_1, false);
  }

  @Test
  public void testSpdyClientSessionHandlerGoAway() {
    logger.info("Running: testSpdyClientSessionHandlerGoAway v3.1");
    testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3_1, false);
  }

  @Test
  public void testSpdyServerSessionHandler() {
    logger.info("Running: testSpdyServerSessionHandler v3.1");
    testSpdySessionHandler(SpdyVersion.SPDY_3_1, true);
  }

  @Test
  public void testSpdyServerSessionHandlerPing() {
    logger.info("Running: testSpdyServerSessionHandlerPing v3.1");
    testSpdySessionHandlerPing(SpdyVersion.SPDY_3_1, true);
  }

  @Test
  public void testSpdyServerSessionHandlerGoAway() {
    logger.info("Running: testSpdyServerSessionHandlerGoAway v3.1");
    testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3_1, true);
  }

  /**
   * Test peer. On activation it opens 4 half-closed streams (SYN_STREAM frames with the last flag
   * set) and then limits the number of concurrent streams to 1 via a SETTINGS frame. Afterwards it
   * replies to bidirectional SYN_STREAMs with a SYN_REPLY carrying the same headers, echoes DATA,
   * PING and HEADERS frames, and closes the channel when it reads a SETTINGS frame containing
   * {@code closeSignal}.
   */
  private static class EchoHandler extends ChannelHandlerAdapter {
    private final int closeSignal;
    private final boolean server;

    EchoHandler(int closeSignal, boolean server) {
      this.closeSignal = closeSignal;
      this.server = server;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
      // Initiate 4 new streams
      int streamId = server ? 2 : 1;
      SpdySynStreamFrame spdySynStreamFrame = new DefaultSpdySynStreamFrame(streamId, 0, (byte) 0);
      spdySynStreamFrame.setLast(true);
      ctx.writeAndFlush(spdySynStreamFrame);
      spdySynStreamFrame.setStreamId(spdySynStreamFrame.getStreamId() + 2);
      ctx.writeAndFlush(spdySynStreamFrame);
      spdySynStreamFrame.setStreamId(spdySynStreamFrame.getStreamId() + 2);
      ctx.writeAndFlush(spdySynStreamFrame);
      spdySynStreamFrame.setStreamId(spdySynStreamFrame.getStreamId() + 2);
      ctx.writeAndFlush(spdySynStreamFrame);

      // Limit the number of concurrent streams to 1
      SpdySettingsFrame spdySettingsFrame = new DefaultSpdySettingsFrame();
      spdySettingsFrame.setValue(SpdySettingsFrame.SETTINGS_MAX_CONCURRENT_STREAMS, 1);
      ctx.writeAndFlush(spdySettingsFrame);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
      if (msg instanceof SpdySynStreamFrame) {

        SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg;
        if (!spdySynStreamFrame.isUnidirectional()) {
          int streamId = spdySynStreamFrame.getStreamId();
          SpdySynReplyFrame spdySynReplyFrame = new DefaultSpdySynReplyFrame(streamId);
          spdySynReplyFrame.setLast(spdySynStreamFrame.isLast());
          for (Map.Entry<String, String> entry : spdySynStreamFrame.headers()) {
            spdySynReplyFrame.headers().add(entry.getKey(), entry.getValue());
          }

          ctx.writeAndFlush(spdySynReplyFrame);
        }
        return;
      }

      if (msg instanceof SpdySynReplyFrame) {
        return;
      }

      if (msg instanceof SpdyDataFrame
          || msg instanceof SpdyPingFrame
          || msg instanceof SpdyHeadersFrame) {

        ctx.writeAndFlush(msg);
        return;
      }

      if (msg instanceof SpdySettingsFrame) {
        SpdySettingsFrame spdySettingsFrame = (SpdySettingsFrame) msg;
        if (spdySettingsFrame.isSet(closeSignal)) {
          ctx.close();
        }
      }
    }
  }
}
Пример #15
0
/**
 * 微信请求处理类
 *
 * @className WeixinRequestHandler
 * @author jy
 * @date 2014年11月16日
 * @since JDK 1.7
 * @see com.foxinmy.weixin4j.dispatcher.WeixinMessageDispatcher
 */
public class WeixinRequestHandler extends SimpleChannelInboundHandler<WeixinRequest> {
  private final InternalLogger logger = InternalLoggerFactory.getInstance(getClass());

  private final WeixinMessageDispatcher messageDispatcher;

  public WeixinRequestHandler(WeixinMessageDispatcher messageDispatcher) {
    this.messageDispatcher = messageDispatcher;
  }

  public void channelReadComplete(ChannelHandlerContext ctx) {
    ctx.flush();
  }

  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
    ctx.close();
    logger.error("catch the exception:", cause);
  }

  @Override
  protected void channelRead0(ChannelHandlerContext ctx, WeixinRequest request)
      throws WeixinException {
    final AesToken aesToken = request.getAesToken();
    if (aesToken == null
        || (StringUtil.isBlank(request.getSignature())
            && StringUtil.isBlank(request.getMsgSignature()))) {
      ctx.writeAndFlush(HttpUtil.createHttpResponse(BAD_REQUEST))
          .addListener(ChannelFutureListener.CLOSE);
      return;
    }
    /** 公众平台:无论Get,Post都带signature参数,当开启aes模式时带msg_signature参数 企业号:无论Get,Post都带msg_signature参数 */
    if (request.getMethod().equals(HttpMethod.GET.name())) {
      if (!StringUtil.isBlank(request.getSignature())
          && MessageUtil.signature(aesToken.getToken(), request.getTimeStamp(), request.getNonce())
              .equals(request.getSignature())) {
        ctx.write(new SingleResponse(request.getEchoStr()));
        return;
      }
      if (!StringUtil.isBlank(request.getMsgSignature())
          && MessageUtil.signature(
                  aesToken.getToken(),
                  request.getTimeStamp(),
                  request.getNonce(),
                  request.getEchoStr())
              .equals(request.getMsgSignature())) {
        ctx.write(
            new SingleResponse(
                MessageUtil.aesDecrypt(null, aesToken.getAesKey(), request.getEchoStr())));
        return;
      }
      ctx.writeAndFlush(HttpUtil.createHttpResponse(FORBIDDEN))
          .addListener(ChannelFutureListener.CLOSE);
      return;
    } else if (request.getMethod().equals(HttpMethod.POST.name())) {
      if (!StringUtil.isBlank(request.getSignature())
          && !MessageUtil.signature(aesToken.getToken(), request.getTimeStamp(), request.getNonce())
              .equals(request.getSignature())) {
        ctx.writeAndFlush(HttpUtil.createHttpResponse(FORBIDDEN))
            .addListener(ChannelFutureListener.CLOSE);
        return;
      }
      if (request.getEncryptType() == EncryptType.AES
          && !MessageUtil.signature(
                  aesToken.getToken(),
                  request.getTimeStamp(),
                  request.getNonce(),
                  request.getEncryptContent())
              .equals(request.getMsgSignature())) {
        ctx.writeAndFlush(HttpUtil.createHttpResponse(FORBIDDEN))
            .addListener(ChannelFutureListener.CLOSE);
        return;
      }
    } else {
      ctx.writeAndFlush(HttpUtil.createHttpResponse(METHOD_NOT_ALLOWED))
          .addListener(ChannelFutureListener.CLOSE);
      return;
    }
    WeixinMessageTransfer messageTransfer = MessageTransferHandler.parser(request);
    ctx.channel().attr(Consts.MESSAGE_TRANSFER_KEY).set(messageTransfer);
    messageDispatcher.doDispatch(ctx, request, messageTransfer);
  }
}
Пример #16
0
/**
 * Adds <a href="http://en.wikipedia.org/wiki/Transport_Layer_Security">SSL &middot; TLS</a> and
 * StartTLS support to a {@link Channel}. Please refer to the <strong>"SecureChat"</strong> example
 * in the distribution or the web site for the detailed usage.
 *
 * <h3>Beginning the handshake</h3>
 *
 * <p>You must make sure not to write a message while the handshake is in progress unless you are
 * renegotiating. You will be notified by the {@link Future} which is returned by the {@link
 * #handshakeFuture()} method when the handshake process succeeds or fails.
 *
 * <p>Besides using the handshake {@link Future} to get notified about the completion of the
 * handshake, it's also possible to detect it by implementing the {@link
 * ChannelHandler#userEventTriggered(ChannelHandlerContext, Object)} method and checking for a {@link
 * SslHandshakeCompletionEvent}.
 *
 * <h3>Handshake</h3>
 *
 * <p>The handshake will be automatically issued for you once the {@link Channel} is active and {@link
 * SSLEngine#getUseClientMode()} returns {@code true}, so there is no need to bother with it yourself.
 *
 * <h3>Closing the session</h3>
 *
 * <p>To close the SSL session, the {@link #close()} method should be called to send the {@code
 * close_notify} message to the remote peer. One exception is when you close the {@link Channel}:
 * {@link SslHandler} intercepts the close request and sends the {@code close_notify} message before
 * the channel closure automatically. Once the SSL session is closed, it is not reusable, and
 * consequently you should create a new {@link SslHandler} with a new {@link SSLEngine} as explained
 * in the following section.
 *
 * <h3>Restarting the session</h3>
 *
 * <p>To restart the SSL session, you must remove the existing closed {@link SslHandler} from the
 * {@link ChannelPipeline}, insert a new {@link SslHandler} with a new {@link SSLEngine} into the
 * pipeline, and start the handshake process as described in the first section.
 *
 * <h3>Implementing StartTLS</h3>
 *
 * <p><a href="http://en.wikipedia.org/wiki/STARTTLS">StartTLS</a> is the communication pattern that
 * secures the wire in the middle of the plaintext connection. Please note that it is different from
 * SSL &middot; TLS, that secures the wire from the beginning of the connection. Typically, StartTLS
 * is composed of three steps:
 *
 * <ol>
 *   <li>Client sends a StartTLS request to server.
 *   <li>Server sends a StartTLS response to client.
 *   <li>Client begins SSL handshake.
 * </ol>
 *
 * If you implement a server, you need to:
 *
 * <ol>
 *   <li>create a new {@link SslHandler} instance with {@code startTls} flag set to {@code true},
 *   <li>insert the {@link SslHandler} to the {@link ChannelPipeline}, and
 *   <li>write a StartTLS response.
 * </ol>
 *
 * Please note that you must insert {@link SslHandler} <em>before</em> sending the StartTLS
 * response. Otherwise the client can begin the SSL handshake before {@link SslHandler} is inserted
 * into the {@link ChannelPipeline}, causing data corruption.
 *
 * <p>The client-side implementation is much simpler.
 *
 * <ol>
 *   <li>Write a StartTLS request,
 *   <li>wait for the StartTLS response,
 *   <li>create a new {@link SslHandler} instance with {@code startTls} flag set to {@code false},
 *   <li>insert the {@link SslHandler} to the {@link ChannelPipeline}, and
 *   <li>Initiate SSL handshake.
 * </ol>
 *
 * <h3>Known issues</h3>
 *
 * <p>Because of a known issue with the current implementation of the {@link SSLEngine} that comes
 * with Java, you may see blocked I/O threads while a full GC is done.
 *
 * <p>If you are affected, you can work around this problem by adjusting the cache settings as shown
 * below:
 *
 * <pre>
 *     SSLContext context = ...;
 *     context.getServerSessionContext().setSessionCacheSize(someSaneSize);
 *     context.getServerSessionContext().setSessionTimeout(someSaneTimeout);
 * </pre>
 *
 * <p>What values to use here depends on the nature of your application and should be set based on
 * monitoring and debugging of it. For more details see <a
 * href="https://github.com/netty/netty/issues/832">#832</a> in our issue tracker.
 */
public class SslHandler extends ByteToMessageDecoder {

  private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslHandler.class);

  // Used by ignoreException(..): class names of "read" stack frames that indicate an I/O error
  // raised while reading from a socket/datagram/SCTP/UDT channel.
  private static final Pattern IGNORABLE_CLASS_IN_STACK =
      Pattern.compile("^.*(?:Socket|Datagram|Sctp|Udt)Channel.*$");
  // Used by ignoreException(..): exception messages describing a reset/closed/aborted/broken
  // connection or a broken pipe, matched case-insensitively.
  private static final Pattern IGNORABLE_ERROR_MESSAGE =
      Pattern.compile(
          "^.*(?:connection.*(?:reset|closed|abort|broken)|broken.*pipe).*$",
          Pattern.CASE_INSENSITIVE);

  // Shared, pre-allocated exception instances. They are created once during class initialization,
  // so the static block below clears their stack traces (which would only show that
  // initialization). HANDSHAKE_TIMED_OUT is presumably used by handshake-timeout code outside this
  // view.
  private static final SSLException SSLENGINE_CLOSED = new SSLException("SSLEngine closed already");
  private static final SSLException HANDSHAKE_TIMED_OUT = new SSLException("handshake timed out");
  private static final ClosedChannelException CHANNEL_CLOSED = new ClosedChannelException();

  static {
    SSLENGINE_CLOSED.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
    HANDSHAKE_TIMED_OUT.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
    CHANNEL_CLOSED.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
  }

  // Context of this handler; read by close() / close(ChannelPromise). Presumably assigned when the
  // handler is added to a pipeline (not visible here).
  private volatile ChannelHandlerContext ctx;
  private final SSLEngine engine;
  // Packet buffer size of the engine's session; initial capacity of wrap output buffers and the
  // amount by which they grow on BUFFER_OVERFLOW.
  private final int maxPacketBufferSize;

  // StartTLS mode: the messages queued before the first flush() are written unencrypted.
  private final boolean startTls;
  private boolean sentFirstMessage;
  // Set by flush() before the handshake completes. unwrapSingle(..) clears it and re-runs wrap(..)
  // once the engine is no longer handshaking.
  private boolean flushedBeforeHandshakeDone;
  private final LazyChannelPromise handshakePromise = new LazyChannelPromise();
  // Completed when the engine reports the inbound side as CLOSED (see unwrapSingle).
  private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
  // Outbound messages queued by write(..) and consumed by flush() / wrap(..).
  private final Deque<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();

  /**
   * Set by the wrap methods when they write output while called from the unwrap path
   * ({@code inUnwrap == true}). {@link #channelReadComplete(ChannelHandlerContext)} will check this
   * flag, clear it, and call ctx.flush().
   */
  private boolean needsFlush;

  // Length of an SSL record whose bytes have not all arrived yet (0 if none). Remembered by
  // decode(..) so that the header does not have to be parsed again.
  private int packetLength;
  // Accumulates decrypted application data produced by unwrapSingle(..).
  private ByteBuf decodeOut;

  // Timeouts in milliseconds (10 s for the handshake, 3 s for close_notify by default).
  private volatile long handshakeTimeoutMillis = 10000;
  private volatile long closeNotifyTimeoutMillis = 3000;

  /**
   * Creates a new instance with StartTLS mode disabled, so the first write request is encrypted
   * like every other one.
   *
   * @param engine the {@link SSLEngine} this handler will use
   */
  public SslHandler(SSLEngine engine) {
    this(engine, /* startTls= */ false);
  }

  /**
   * Creates a new instance.
   *
   * @param engine the {@link SSLEngine} this handler will use
   * @param startTls {@code true} if the first write request shouldn't be encrypted by the {@link
   *     SSLEngine}
   */
  public SslHandler(SSLEngine engine, boolean startTls) {
    if (engine == null) {
      throw new NullPointerException("engine");
    }
    this.engine = engine;
    this.startTls = startTls;
    maxPacketBufferSize = engine.getSession().getPacketBufferSize();
  }

  /**
   * Returns the handshake timeout.
   *
   * @return the timeout in milliseconds
   */
  public long getHandshakeTimeoutMillis() {
    return this.handshakeTimeoutMillis;
  }

  /**
   * Sets the handshake timeout.
   *
   * @param handshakeTimeout the timeout expressed in {@code unit}
   * @param unit the unit of {@code handshakeTimeout}; must not be {@code null}
   * @throws NullPointerException if {@code unit} is {@code null}
   * @throws IllegalArgumentException if the timeout converted to milliseconds is negative
   */
  public void setHandshakeTimeout(long handshakeTimeout, TimeUnit unit) {
    if (unit != null) {
      setHandshakeTimeoutMillis(unit.toMillis(handshakeTimeout));
    } else {
      throw new NullPointerException("unit");
    }
  }

  /**
   * Sets the handshake timeout in milliseconds.
   *
   * @param handshakeTimeoutMillis the timeout; must be {@code >= 0}
   * @throws IllegalArgumentException if {@code handshakeTimeoutMillis} is negative
   */
  public void setHandshakeTimeoutMillis(long handshakeTimeoutMillis) {
    if (handshakeTimeoutMillis >= 0) {
      this.handshakeTimeoutMillis = handshakeTimeoutMillis;
      return;
    }
    throw new IllegalArgumentException(
        "handshakeTimeoutMillis: " + handshakeTimeoutMillis + " (expected: >= 0)");
  }

  /**
   * Returns the close_notify timeout.
   *
   * @return the timeout in milliseconds
   */
  public long getCloseNotifyTimeoutMillis() {
    return this.closeNotifyTimeoutMillis;
  }

  /**
   * Sets the close_notify timeout.
   *
   * @param closeNotifyTimeout the timeout expressed in {@code unit}
   * @param unit the unit of {@code closeNotifyTimeout}; must not be {@code null}
   * @throws NullPointerException if {@code unit} is {@code null}
   * @throws IllegalArgumentException if the timeout converted to milliseconds is negative
   */
  public void setCloseNotifyTimeout(long closeNotifyTimeout, TimeUnit unit) {
    if (unit != null) {
      setCloseNotifyTimeoutMillis(unit.toMillis(closeNotifyTimeout));
    } else {
      throw new NullPointerException("unit");
    }
  }

  /**
   * Sets the close_notify timeout in milliseconds.
   *
   * @param closeNotifyTimeoutMillis the timeout; must be {@code >= 0}
   * @throws IllegalArgumentException if {@code closeNotifyTimeoutMillis} is negative
   */
  public void setCloseNotifyTimeoutMillis(long closeNotifyTimeoutMillis) {
    if (closeNotifyTimeoutMillis >= 0) {
      this.closeNotifyTimeoutMillis = closeNotifyTimeoutMillis;
      return;
    }
    throw new IllegalArgumentException(
        "closeNotifyTimeoutMillis: " + closeNotifyTimeoutMillis + " (expected: >= 0)");
  }

  /**
   * Returns the {@link SSLEngine} which is used by this handler.
   *
   * @return the engine passed to the constructor
   */
  public SSLEngine engine() {
    return this.engine;
  }

  /**
   * Returns a {@link Future} that will get notified once the handshake completes.
   *
   * @return the handshake future; the same instance is returned on every call
   */
  public Future<Channel> handshakeFuture() {
    return this.handshakePromise;
  }

  /**
   * Closes the outbound side of the {@link SSLEngine} and sends an SSL {@code close_notify} message
   * to the remote peer, using a new promise from the stored handler context.
   *
   * <p>NOTE(review): this dereferences the stored context, so it presumably must only be called
   * after the handler has been added to a pipeline.
   *
   * @return the future of the {@code close_notify} write
   */
  public ChannelFuture close() {
    final ChannelPromise promise = ctx.newPromise();
    return close(promise);
  }

  /**
   * Closes the outbound side of the {@link SSLEngine} and writes and flushes the resulting
   * {@code close_notify}, completing {@code future}. The work runs on the context's executor.
   *
   * <p>If writing fails and {@code future} was already completed, the failure is only logged.
   *
   * @param future the promise to complete with the outcome of the {@code close_notify} write
   * @return {@code future}
   * @see #close()
   */
  public ChannelFuture close(final ChannelPromise future) {
    final ChannelHandlerContext context = this.ctx;
    final Runnable closeTask =
        new Runnable() {
          @Override
          public void run() {
            engine.closeOutbound();
            try {
              write(context, Unpooled.EMPTY_BUFFER, future);
              flush(context);
            } catch (Exception e) {
              if (!future.tryFailure(e)) {
                logger.warn("flush() raised a masked exception.", e);
              }
            }
          }
        };
    context.executor().execute(closeTask);
    return future;
  }

  /**
   * Returns the {@link Future} that is notified when the inbound side of the {@link SSLEngine} is
   * reported as closed.
   *
   * <p>The same instance is returned on every call. For more information see the API docs of
   * {@link SSLEngine}.
   *
   * @return the SSL close future
   */
  public Future<Channel> sslCloseFuture() {
    return this.sslCloseFuture;
  }

  /**
   * Releases the partially filled decode buffer and fails every write that is still queued, since
   * this handler can no longer encrypt it.
   */
  @Override
  public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
    if (decodeOut != null) {
      decodeOut.release();
      decodeOut = null;
    }
    PendingWrite write;
    while ((write = pendingUnencryptedWrites.poll()) != null) {
      write.failAndRecycle(new ChannelException("Pending write on removal of SslHandler"));
    }
  }

  /**
   * Intercepts a disconnect request and delegates to {@code closeOutboundAndChannel} with its
   * third argument set to {@code true}. That method is presumably responsible for sending
   * close_notify first; it is not visible here.
   */
  @Override
  public void disconnect(final ChannelHandlerContext ctx, final ChannelPromise promise)
      throws Exception {
    final boolean disconnect = true;
    closeOutboundAndChannel(ctx, promise, disconnect);
  }

  /**
   * Intercepts a close request and delegates to {@code closeOutboundAndChannel} with its third
   * argument set to {@code false}. That method is presumably responsible for sending close_notify
   * first; it is not visible here.
   */
  @Override
  public void close(final ChannelHandlerContext ctx, final ChannelPromise promise)
      throws Exception {
    final boolean disconnect = false;
    closeOutboundAndChannel(ctx, promise, disconnect);
  }

  /**
   * Queues {@code msg} together with its promise. Nothing is written to the channel here. The
   * queue is drained later by {@link #flush(ChannelHandlerContext)} or the wrap logic.
   */
  @Override
  public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
      throws Exception {
    final PendingWrite pending = PendingWrite.newInstance(msg, promise);
    pendingUnencryptedWrites.add(pending);
  }

  /**
   * Encrypts and writes every queued message, then flushes the channel.
   *
   * <p>In StartTLS mode, the messages queued before the first flush are written unencrypted. If
   * nothing is queued, an empty buffer is queued so that the {@link SSLEngine} is still invoked
   * (presumably so it can emit pending handshake data).
   */
  @Override
  public void flush(ChannelHandlerContext ctx) throws Exception {
    if (startTls && !sentFirstMessage) {
      // StartTLS: the first batch of writes goes out in plaintext.
      sentFirstMessage = true;
      PendingWrite plain;
      while ((plain = pendingUnencryptedWrites.poll()) != null) {
        ctx.write(plain.msg(), (ChannelPromise) plain.recycleAndGet());
      }
      ctx.flush();
      return;
    }

    if (pendingUnencryptedWrites.isEmpty()) {
      pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null));
    }
    // Remember an early flush so that unwrapSingle re-runs wrap once the handshake has finished.
    if (!handshakePromise.isDone()) {
      flushedBeforeHandshakeDone = true;
    }
    wrap(ctx, false);
    ctx.flush();
  }

  /**
   * Encrypts the queued {@link #pendingUnencryptedWrites} and writes the produced records to
   * {@code ctx} without flushing.
   *
   * <p>Messages that are not {@link ByteBuf}s are written through unencrypted. A pending write is
   * removed from the queue only after the engine has fully consumed its buffer. Its promise is then
   * attached to the write of the produced output. If the engine reports {@link Status#CLOSED},
   * every remaining pending write is failed with {@code SSLENGINE_CLOSED}.
   *
   * @param ctx the context to write encrypted output to
   * @param inUnwrap {@code true} when called from the unwrap path. {@code finishWrap} then sets
   *     {@link #needsFlush} so that {@link #channelReadComplete(ChannelHandlerContext)} flushes
   * @throws SSLException if the engine fails; the handshake is marked as failed before rethrowing
   */
  private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
    // Output buffer for the current ciphertext; handed off (and reset to null) via finishWrap.
    ByteBuf out = null;
    // Promise of the pending write whose buffer was fully consumed; attached to the next write.
    ChannelPromise promise = null;
    try {
      for (; ; ) {
        PendingWrite pending = pendingUnencryptedWrites.peek();
        if (pending == null) {
          break;
        }
        if (out == null) {
          out = ctx.alloc().buffer(maxPacketBufferSize);
        }

        // Non-ByteBuf messages cannot be encrypted here: pass them through as-is.
        if (!(pending.msg() instanceof ByteBuf)) {
          ctx.write(pending.msg(), (ChannelPromise) pending.recycleAndGet());
          pendingUnencryptedWrites.remove();
          continue;
        }
        ByteBuf buf = (ByteBuf) pending.msg();
        SSLEngineResult result = wrap(engine, buf, out);

        if (!buf.isReadable()) {
          // Fully consumed: release the plaintext and take over the write's promise.
          buf.release();
          promise = (ChannelPromise) pending.recycleAndGet();
          pendingUnencryptedWrites.remove();
        } else {
          // Partially consumed: the write stays at the head of the queue.
          promise = null;
        }

        if (result.getStatus() == Status.CLOSED) {
          // SSLEngine has been closed already.
          // Any further write attempts should be denied.
          for (; ; ) {
            PendingWrite w = pendingUnencryptedWrites.poll();
            if (w == null) {
              break;
            }
            w.failAndRecycle(SSLENGINE_CLOSED);
          }
          return;
        } else {
          switch (result.getHandshakeStatus()) {
            case NEED_TASK:
              // Run the tasks and loop again; out is kept and written later.
              runDelegatedTasks();
              break;
            case FINISHED:
              setHandshakeSuccess();
              // deliberate fall-through
            case NOT_HANDSHAKING:
              setHandshakeSuccessIfStillHandshaking();
              // deliberate fall-through
            case NEED_WRAP:
              finishWrap(ctx, out, promise, inUnwrap);
              promise = null;
              out = null;
              break;
            case NEED_UNWRAP:
              // The engine needs inbound data first; the finally block writes what was produced.
              return;
            default:
              throw new IllegalStateException(
                  "Unknown handshake status: " + result.getHandshakeStatus());
          }
        }
      }
    } catch (SSLException e) {
      setHandshakeFailure(e);
      throw e;
    } finally {
      // Write any remaining output (an empty buffer if there is none) with the pending promise.
      finishWrap(ctx, out, promise, inUnwrap);
    }
  }

  /**
   * Writes the output of a wrap step to {@code ctx} without flushing.
   *
   * <p>A {@code null} or empty {@code out} is replaced by {@link Unpooled#EMPTY_BUFFER}; an empty
   * {@code out} is released first. A write is therefore always issued, with {@code promise} if one
   * is given.
   *
   * @param ctx the context to write to
   * @param out the produced ciphertext; may be {@code null}
   * @param promise the promise to attach to the write, or {@code null}
   * @param inUnwrap if {@code true}, {@link #needsFlush} is set so that {@link
   *     #channelReadComplete(ChannelHandlerContext)} flushes later
   */
  private void finishWrap(
      ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap) {
    final ByteBuf payload;
    if (out != null && out.isReadable()) {
      payload = out;
    } else {
      if (out != null) {
        out.release();
      }
      payload = Unpooled.EMPTY_BUFFER;
    }

    if (promise == null) {
      ctx.write(payload);
    } else {
      ctx.write(payload, promise);
    }

    needsFlush |= inUnwrap;
  }

  /**
   * Wraps empty input repeatedly so that the {@link SSLEngine} can produce non-application records
   * (for example handshake messages). Each produced record is written to {@code ctx} without
   * flushing. The loop stops as soon as a wrap produces no bytes.
   *
   * @param ctx the context to write produced records to
   * @param inUnwrap {@code true} when called from the unwrap path. Then {@link #needsFlush} is set
   *     instead of recursing into {@code unwrapNonApp}
   * @throws SSLException if the engine fails; the handshake is marked as failed before rethrowing
   */
  private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
    ByteBuf out = null;
    try {
      for (; ; ) {
        if (out == null) {
          out = ctx.alloc().buffer(maxPacketBufferSize);
        }
        SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out);

        if (result.bytesProduced() > 0) {
          // Ownership of out passes to the pipeline; allocate a fresh buffer next time.
          ctx.write(out);
          if (inUnwrap) {
            needsFlush = true;
          }
          out = null;
        }

        switch (result.getHandshakeStatus()) {
          case FINISHED:
            setHandshakeSuccess();
            break;
          case NEED_TASK:
            runDelegatedTasks();
            break;
          case NEED_UNWRAP:
            // Avoid re-entering unwrap when we were called from it.
            if (!inUnwrap) {
              unwrapNonApp(ctx);
            }
            break;
          case NEED_WRAP:
            break;
          case NOT_HANDSHAKING:
            setHandshakeSuccessIfStillHandshaking();
            // Workaround for TLS False Start problem reported at:
            // https://github.com/netty/netty/issues/1108#issuecomment-14266970
            if (!inUnwrap) {
              unwrapNonApp(ctx);
            }
            break;
          default:
            throw new IllegalStateException(
                "Unknown handshake status: " + result.getHandshakeStatus());
        }

        if (result.bytesProduced() == 0) {
          break;
        }
      }
    } catch (SSLException e) {
      setHandshakeFailure(e);
      throw e;
    } finally {
      // Release a buffer that was allocated but never handed to the pipeline.
      if (out != null) {
        out.release();
      }
    }
  }

  /**
   * Encrypts the readable bytes of {@code in} into the writable area of {@code out}.
   *
   * <p>The reader index of {@code in} and the writer index of {@code out} are advanced by the
   * amounts the engine consumed and produced. On {@code BUFFER_OVERFLOW}, {@code out} is grown by
   * {@code maxPacketBufferSize} and the call is retried.
   *
   * @return the result of the last {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} call that did
   *     not overflow
   */
  private SSLEngineResult wrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException {
    final ByteBuffer source = in.nioBuffer();
    while (true) {
      final ByteBuffer target = out.nioBuffer(out.writerIndex(), out.writableBytes());
      final SSLEngineResult result = engine.wrap(source, target);
      in.skipBytes(result.bytesConsumed());
      out.writerIndex(out.writerIndex() + result.bytesProduced());

      if (result.getStatus() != Status.BUFFER_OVERFLOW) {
        return result;
      }
      out.ensureWritable(maxPacketBufferSize);
    }
  }

  /**
   * Marks the handshake as failed with {@code CHANNEL_CLOSED} before propagating the event.
   */
  @Override
  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
    // Release the SSLEngine and notify the handshake future in case the connection was closed
    // while the handshake was still in progress.
    setHandshakeFailure(CHANNEL_CLOSED);
    super.channelInactive(ctx);
  }

  /**
   * Swallows harmless connection-reset / broken-pipe errors raised after close_notify (see
   * {@code ignoreException}) and closes the channel if it is still active. Every other exception
   * is propagated.
   */
  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    if (!ignoreException(cause)) {
      ctx.fireExceptionCaught(cause);
      return;
    }

    // It is safe to ignore the 'connection reset by peer' or 'broken pipe' error after sending
    // close_notify.
    if (logger.isDebugEnabled()) {
      logger.debug(
          "Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred "
              + "while writing close_notify in response to the peer's close_notify",
          cause);
    }

    // Close explicitly in case the transport did not close the connection automatically.
    if (ctx.channel().isActive()) {
      ctx.close();
    }
  }

  /**
   * Checks whether the given {@link Throwable} can be ignored and just "swallowed".
   *
   * <p>When an SSL connection is closed, a close_notify message is sent. The peer is expected to
   * answer with its own close_notify, but receiving it is not mandatory. The party that sent the
   * initial close_notify may close the connection immediately, in which case the peer gets a
   * connection reset error.
   *
   * <p>Only {@link IOException}s that are not {@link SSLException}s, and that occur after {@link
   * #sslCloseFuture()} is done, are candidates. Such an exception is ignorable in either case:
   *
   * <ul>
   *   <li>its message looks like a connection reset / broken pipe, or
   *   <li>its stack trace contains a {@code read} frame of a socket, datagram, SCTP or UDT
   *       channel class.
   * </ul>
   *
   * @param t the exception to inspect
   * @return {@code true} if {@code t} can be safely swallowed
   */
  private boolean ignoreException(Throwable t) {
    if (!(t instanceof SSLException) && t instanceof IOException && sslCloseFuture.isDone()) {
      // First try to match connection reset / broken pipe based on the message. This is the
      // fastest way but may fail on different JDK implementations or operating systems.
      // IGNORABLE_ERROR_MESSAGE is already CASE_INSENSITIVE, so the message is matched as-is.
      // Lower-casing it with the default locale can break the match: under a Turkish locale,
      // for example, 'I' becomes a dotless 'ı'.
      String message = t.getMessage();
      if (message != null && IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) {
        return true;
      }

      // Inspect the StackTraceElements to see if it was a connection reset / broken pipe or not.
      StackTraceElement[] elements = t.getStackTrace();
      for (StackTraceElement element : elements) {
        String classname = element.getClassName();
        String methodname = element.getMethodName();

        // Skip all classes that belong to the io.netty package.
        if (classname.startsWith("io.netty.")) {
          continue;
        }

        // Only frames of a method named "read" are of interest.
        if (!"read".equals(methodname)) {
          continue;
        }

        // Matches class names containing SocketChannel, DatagramChannel, SctpChannel or
        // UdtChannel (for example NIO channel implementation classes).
        if (IGNORABLE_CLASS_IN_STACK.matcher(classname).matches()) {
          return true;
        }

        try {
          // No match by now. Other JDK implementations may name their implementation classes
          // differently, so load the class via the class loader and inspect its type.
          Class<?> clazz = PlatformDependent.getClassLoader(getClass()).loadClass(classname);

          if (SocketChannel.class.isAssignableFrom(clazz)
              || DatagramChannel.class.isAssignableFrom(clazz)) {
            return true;
          }

          // Also match against SctpChannel via string matching, as the class may not be present.
          // getSuperclass() returns null for interfaces and Object, so guard against that.
          Class<?> superclass = clazz.getSuperclass();
          if (PlatformDependent.javaVersion() >= 7
              && superclass != null
              && "com.sun.nio.sctp.SctpChannel".equals(superclass.getName())) {
            return true;
          }
        } catch (ClassNotFoundException e) {
          // Should not happen, since the class appeared in the stack trace; just ignore it.
        }
      }
    }

    return false;
  }

  /**
   * Tells whether the bytes at the reader index of {@code buffer} look like an SSL/TLS record. The
   * reader index is not modified.
   *
   * @param buffer the {@link ByteBuf} to inspect; it must have at least 5 readable bytes
   * @return {@code true} if the {@link ByteBuf} is encrypted, {@code false} otherwise
   * @throws IllegalArgumentException if {@code buffer} has fewer than 5 readable bytes
   */
  public static boolean isEncrypted(ByteBuf buffer) {
    final int readable = buffer.readableBytes();
    if (readable >= 5) {
      final int recordLength = getEncryptedPacketLength(buffer, buffer.readerIndex());
      return recordLength != -1;
    }
    throw new IllegalArgumentException("buffer must have at least 5 readable bytes");
  }

  /**
   * Returns the length, header included, of the SSL/TLS record that starts at {@code offset}. The
   * readerIndex of the given {@link ByteBuf} is not modified.
   *
   * <p>The caller must make sure that at least 5 bytes are readable from {@code offset}. This
   * method does not check that itself; {@link #isEncrypted(ByteBuf)} and {@code decode} do so
   * before calling it.
   *
   * @param buffer the {@link ByteBuf} to read from
   * @param offset the absolute index at which the record header starts
   * @return the length of the encrypted packet, or {@code -1} if the data at {@code offset} does
   *     not look like an SSLv2, SSLv3 or TLS record
   */
  private static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
    int packetLength = 0;

    // SSLv3 or TLS - Check ContentType
    boolean tls;
    switch (buffer.getUnsignedByte(offset)) {
      case 20: // change_cipher_spec
      case 21: // alert
      case 22: // handshake
      case 23: // application_data
        tls = true;
        break;
      default:
        // SSLv2 or bad data
        tls = false;
    }

    if (tls) {
      // SSLv3 or TLS - Check ProtocolVersion
      int majorVersion = buffer.getUnsignedByte(offset + 1);
      if (majorVersion == 3) {
        // SSLv3 or TLS: 2-byte length at offset 3, plus the 5-byte record header.
        packetLength = buffer.getUnsignedShort(offset + 3) + 5;
        if (packetLength <= 5) {
          // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
          tls = false;
        }
      } else {
        // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
        tls = false;
      }
    }

    if (!tls) {
      // SSLv2 or bad data - Check the version
      boolean sslv2 = true;
      // High bit of the first byte set: 2-byte header, otherwise a 3-byte header.
      int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
      int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
      if (majorVersion == 2 || majorVersion == 3) {
        // SSLv2
        if (headerLength == 2) {
          packetLength = (buffer.getShort(offset) & 0x7FFF) + 2;
        } else {
          packetLength = (buffer.getShort(offset) & 0x3FFF) + 3;
        }
        if (packetLength <= headerLength) {
          sslv2 = false;
        }
      } else {
        sslv2 = false;
      }

      if (!sslv2) {
        return -1;
      }
    }
    return packetLength;
  }

  /**
   * Splits the readable bytes of {@code in} into complete SSL/TLS records and unwraps them.
   *
   * <p>All complete records are consumed from {@code in} and passed to {@code unwrapMultiple},
   * which adds decrypted buffers to {@code out}. If the last record is incomplete, its length is
   * remembered in {@link #packetLength} until more data arrives.
   *
   * <p>If data that is not an SSL/TLS record is found, the complete records before it are still
   * unwrapped. The remaining bytes are then discarded, a {@link NotSslRecordException} is fired
   * through the pipeline and the handshake is marked as failed.
   */
  @Override
  protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
      throws SSLException {

    // Keeps the list of the length of every SSL record in the input buffer.
    int[] recordLengths = null;
    int nRecords = 0;

    final int startOffset = in.readerIndex();
    final int endOffset = in.writerIndex();
    int offset = startOffset;

    // If we calculated the length of the current SSL record before, use that information.
    if (packetLength > 0) {
      if (endOffset - startOffset < packetLength) {
        return;
      } else {
        recordLengths = new int[4];
        recordLengths[0] = packetLength;
        nRecords = 1;

        offset += packetLength;
        packetLength = 0;
      }
    }

    boolean nonSslRecord = false;

    for (; ; ) {
      final int readableBytes = endOffset - offset;
      // A record header needs 5 bytes.
      if (readableBytes < 5) {
        break;
      }

      // This local shadows the field; the field is only updated below when the record is
      // incomplete.
      final int packetLength = getEncryptedPacketLength(in, offset);
      if (packetLength == -1) {
        nonSslRecord = true;
        break;
      }

      assert packetLength > 0;

      if (packetLength > readableBytes) {
        // wait until the whole packet can be read
        this.packetLength = packetLength;
        break;
      }

      // We have a whole packet.
      // Remember the length of the current packet.
      if (recordLengths == null) {
        recordLengths = new int[4];
      }
      if (nRecords == recordLengths.length) {
        recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1);
      }
      recordLengths[nRecords++] = packetLength;

      // Increment the offset to handle the next packet.
      offset += packetLength;
    }

    final int totalLength = offset - startOffset;
    if (totalLength > 0) {
      // The buffer contains one or more full SSL records.
      // Slice out the whole packet so unwrap will only be called with complete packets.
      // Also directly reset the packetLength. This is needed as unwrap(..) may trigger
      // decode(...) again via:
      // 1) unwrap(..) is called
      // 2) wrap(...) is called from within unwrap(...)
      // 3) wrap(...) calls unwrapLater(...)
      // 4) unwrapLater(...) calls decode(...)
      //
      // NOTE(review): unwrapLater(...) does not appear in the visible code; this call chain may
      // be outdated. Confirm before relying on it.
      //
      // See https://github.com/netty/netty/issues/1534
      in.skipBytes(totalLength);
      ByteBuffer buffer = in.nioBuffer(startOffset, totalLength);
      unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out);
    }

    if (nonSslRecord) {
      // Not an SSL/TLS packet
      NotSslRecordException e =
          new NotSslRecordException("not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
      in.skipBytes(in.readableBytes());
      ctx.fireExceptionCaught(e);
      setHandshakeFailure(e);
    }
  }

  /**
   * Flushes output that was written while unwrapping (signalled by {@link #needsFlush}), then
   * passes the event on.
   */
  @Override
  public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
    final boolean flushPending = needsFlush;
    needsFlush = false;
    if (flushPending) {
      ctx.flush();
    }
    super.channelReadComplete(ctx);
  }

  /**
   * Calls {@code unwrapSingle} with an empty buffer to let the {@link SSLEngine} make progress, for
   * example during the handshake. Any decrypted data left in {@link #decodeOut} is fired through
   * the pipeline afterwards, even if unwrapping throws.
   */
  private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException {
    try {
      final ByteBuffer emptyInput = Unpooled.EMPTY_BUFFER.nioBuffer();
      unwrapSingle(ctx, emptyInput, 0);
    } finally {
      final ByteBuf decoded = decodeOut;
      if (decoded != null && decoded.isReadable()) {
        decodeOut = null;
        ctx.fireChannelRead(decoded);
      }
    }
  }

  /**
   * Unwraps {@code nRecords} consecutive inbound SSL records from {@code packet}.
   *
   * <p>Before each record is unwrapped, the limit of {@code packet} is set to that record's end.
   * After each record, any decrypted data in {@link #decodeOut} is moved to {@code out}, even if
   * unwrapping throws.
   *
   * @param totalLength the combined length of all records, used as the initial decode capacity
   * @param recordLengths the length of each record; only the first {@code nRecords} are used
   */
  private void unwrapMultiple(
      ChannelHandlerContext ctx,
      ByteBuffer packet,
      int totalLength,
      int[] recordLengths,
      int nRecords,
      List<Object> out)
      throws SSLException {
    int index = 0;
    while (index < nRecords) {
      packet.limit(packet.position() + recordLengths[index]);
      try {
        unwrapSingle(ctx, packet, totalLength);
        assert !packet.hasRemaining();
      } finally {
        final ByteBuf decoded = decodeOut;
        if (decoded != null && decoded.isReadable()) {
          decodeOut = null;
          out.add(decoded);
        }
      }
      index++;
    }
  }

  /**
   * Unwraps a single SSL record.
   *
   * <p>Calls the engine repeatedly until one of these happens:
   *
   * <ul>
   *   <li>it reports {@link Status#CLOSED}, which completes {@link #sslCloseFuture};
   *   <li>it reports {@link Status#BUFFER_UNDERFLOW};
   *   <li>it makes no progress (nothing consumed and nothing produced).
   * </ul>
   *
   * <p>Decrypted bytes accumulate in {@link #decodeOut}. A NEED_WRAP status triggers
   * {@code wrapNonAppData}. A finished handshake, or a flush issued before the handshake was done,
   * triggers {@code wrap(ctx, true)} after the loop.
   *
   * @param ctx the handler context
   * @param packet the record to unwrap; its position is advanced by the engine
   * @param initialOutAppBufCapacity initial capacity if {@link #decodeOut} has to be allocated
   * @throws SSLException if the engine fails; the handshake is marked as failed before rethrowing
   */
  private void unwrapSingle(
      ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity)
      throws SSLException {

    // Set when queued application data must be wrapped once the loop ends.
    boolean wrapLater = false;
    try {
      for (; ; ) {
        if (decodeOut == null) {
          decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity);
        }

        final SSLEngineResult result = unwrap(engine, packet, decodeOut);
        final Status status = result.getStatus();
        final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
        final int produced = result.bytesProduced();
        final int consumed = result.bytesConsumed();

        if (status == Status.CLOSED) {
          // notify about the CLOSED state of the SSLEngine. See #137
          sslCloseFuture.trySuccess(ctx.channel());
          break;
        }

        switch (handshakeStatus) {
          case NEED_UNWRAP:
            break;
          case NEED_WRAP:
            wrapNonAppData(ctx, true);
            break;
          case NEED_TASK:
            runDelegatedTasks();
            break;
          case FINISHED:
            setHandshakeSuccess();
            wrapLater = true;
            // Skip the progress check below and unwrap again.
            continue;
          case NOT_HANDSHAKING:
            if (setHandshakeSuccessIfStillHandshaking()) {
              wrapLater = true;
              continue;
            }
            if (flushedBeforeHandshakeDone) {
              // We need to call wrap(...) in case there was a flush done before the handshake
              // completed.
              //
              // See https://github.com/netty/netty/pull/2437
              flushedBeforeHandshakeDone = false;
              wrapLater = true;
            }

            break;
          default:
            throw new IllegalStateException("Unknown handshake status: " + handshakeStatus);
        }

        // Stop when more input is needed or the engine made no progress.
        if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) {
          break;
        }
      }

      if (wrapLater) {
        wrap(ctx, true);
      }
    } catch (SSLException e) {
      setHandshakeFailure(e);
      throw e;
    }
  }

  /**
   * Runs {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} directly into the writable region of
   * {@code out}, advancing its writer index by the number of bytes produced.
   *
   * <p>Whenever the engine answers {@code BUFFER_OVERFLOW}, {@code out} is grown and the call is
   * retried: the first retry grows it by at most the remaining input (capped at the session's
   * application buffer size), every later retry by the full application buffer size.
   *
   * @return the first engine result whose status is not {@code BUFFER_OVERFLOW}
   */
  private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out)
      throws SSLException {
    for (int attempt = 0; ; attempt++) {
      final int writerIndex = out.writerIndex();
      final ByteBuffer target = out.nioBuffer(writerIndex, out.writableBytes());
      final SSLEngineResult result = engine.unwrap(in, target);
      out.writerIndex(writerIndex + result.bytesProduced());

      if (result.getStatus() != Status.BUFFER_OVERFLOW) {
        return result;
      }

      final int appBufferSize = engine.getSession().getApplicationBufferSize();
      out.ensureWritable(attempt == 0 ? Math.min(appBufferSize, in.remaining()) : appBufferSize);
    }
  }

  /**
   * Runs, on the calling thread, every delegated task the {@link SSLEngine} hands out until
   * {@link SSLEngine#getDelegatedTask()} returns {@code null}.
   */
  private void runDelegatedTasks() {
    Runnable pending;
    while ((pending = engine.getDelegatedTask()) != null) {
      pending.run();
    }
  }

  /**
   * Marks the handshake as successful if {@link #handshakePromise} is not complete yet.
   *
   * <p>This compensates for some Android {@link SSLEngine} implementations that report {@link
   * HandshakeStatus#NOT_HANDSHAKING} without ever reporting {@link HandshakeStatus#FINISHED}.
   *
   * @return {@code true} if this call completed {@link #handshakePromise}; {@code false} if it was
   *     already done
   */
  private boolean setHandshakeSuccessIfStillHandshaking() {
    if (handshakePromise.isDone()) {
      return false;
    }
    setHandshakeSuccess();
    return true;
  }

  /**
   * Completes {@link #handshakePromise} successfully and, only if this call was the one that
   * completed it, fires {@code SslHandshakeCompletionEvent.SUCCESS} down the pipeline.
   */
  private void setHandshakeSuccess() {
    final boolean completedNow = handshakePromise.trySuccess(ctx.channel());
    if (completedNow) {
      ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
    }
  }

  /**
   * Fails the handshake with {@code cause}.
   *
   * <p>Closes both directions of the {@link SSLEngine} so it can release its internal resources,
   * notifies the handshake failure, and finally fails and recycles every pending unencrypted write.
   */
  private void setHandshakeFailure(Throwable cause) {
    engine.closeOutbound();

    try {
      engine.closeInbound();
    } catch (SSLException e) {
      // The "possible truncation attack" complaint is most likely harmless (recent Chrome versions
      // trigger it constantly), so it is not logged at all; anything else is logged at debug.
      //
      // See https://github.com/netty/netty/issues/1340
      final String detail = e.getMessage();
      final boolean truncationWarning =
          detail != null && detail.contains("possible truncation attack");
      if (!truncationWarning) {
        logger.debug("SSLEngine.closeInbound() raised an exception.", e);
      }
    }

    notifyHandshakeFailure(cause);

    PendingWrite pendingWrite;
    while ((pendingWrite = pendingUnencryptedWrites.poll()) != null) {
      pendingWrite.failAndRecycle(cause);
    }
  }

  /**
   * Fails {@link #handshakePromise} with {@code cause}. Only if this call was the one that failed
   * it, fires a failed {@code SslHandshakeCompletionEvent} and closes the channel.
   */
  private void notifyHandshakeFailure(Throwable cause) {
    if (!handshakePromise.tryFailure(cause)) {
      return;
    }
    ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause));
    ctx.close();
  }

  /**
   * Closes the outbound side of the {@link SSLEngine} and then closes or disconnects the channel.
   *
   * <p>For an active channel an empty buffer is written and flushed after {@code closeOutbound()}.
   * Presumably this lets the wrap path emit the close_notify alert (TODO confirm against
   * {@code write}). The channel is then closed via {@code safeClose} once that write completes.
   * An inactive channel is disconnected or closed right away, depending on {@code disconnect}.
   */
  private void closeOutboundAndChannel(
      final ChannelHandlerContext ctx, final ChannelPromise promise, boolean disconnect)
      throws Exception {
    if (ctx.channel().isActive()) {
      engine.closeOutbound();

      final ChannelPromise closeNotifyPromise = ctx.newPromise();
      write(ctx, Unpooled.EMPTY_BUFFER, closeNotifyPromise);
      flush(ctx);
      safeClose(ctx, closeNotifyPromise, promise);
    } else if (disconnect) {
      ctx.disconnect(promise);
    } else {
      ctx.close(promise);
    }
  }

  /**
   * Remembers {@code ctx} and, if the channel is already active, starts the handshake right away.
   */
  @Override
  public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
    this.ctx = ctx;

    if (!ctx.channel().isActive()) {
      // channelActive() has not fired yet; any handshake will be started from
      // this.channelActive() instead.
      return;
    }

    // channelActive() already fired before this handler was added, so this.channelActive() will
    // never run for this channel; initialize here.
    handshake();
  }

  /**
   * Begins the TLS handshake by calling {@code engine.beginHandshake()} and
   * {@code wrapNonAppData(ctx, false)}, then flushes the context.
   *
   * <p>When {@code handshakeTimeoutMillis} is positive, a task is scheduled on the channel's
   * executor. It fails the handshake with {@code HANDSHAKE_TIMED_OUT} if {@link #handshakePromise}
   * is still incomplete when it runs, and it is cancelled as soon as the promise completes. An
   * exception while starting the handshake fails the promise instead of propagating.
   *
   * @return {@link #handshakePromise}
   */
  private Future<Channel> handshake() {
    final ScheduledFuture<?> timeoutFuture =
        handshakeTimeoutMillis > 0
            ? ctx.executor()
                .schedule(
                    new Runnable() {
                      @Override
                      public void run() {
                        if (!handshakePromise.isDone()) {
                          notifyHandshakeFailure(HANDSHAKE_TIMED_OUT);
                        }
                      }
                    },
                    handshakeTimeoutMillis,
                    TimeUnit.MILLISECONDS)
            : null;

    // Once the handshake completes either way, the timeout task is no longer needed.
    handshakePromise.addListener(
        new GenericFutureListener<Future<Channel>>() {
          @Override
          public void operationComplete(Future<Channel> completed) throws Exception {
            if (timeoutFuture != null) {
              timeoutFuture.cancel(false);
            }
          }
        });

    try {
      engine.beginHandshake();
      wrapNonAppData(ctx, false);
      ctx.flush();
    } catch (Exception cause) {
      notifyHandshakeFailure(cause);
    }
    return handshakePromise;
  }

  /**
   * Starts the handshake when the channel becomes active, but only in client mode and only when
   * {@code startTls} is not set. If the handshake fails, the failure is logged at debug level and
   * the channel is closed. The {@code channelActive} event is always propagated.
   */
  @Override
  public void channelActive(final ChannelHandlerContext ctx) throws Exception {
    final boolean initiateHandshake = !startTls && engine.getUseClientMode();
    if (initiateHandshake) {
      final Future<Channel> handshakeResult = handshake();
      handshakeResult.addListener(
          new GenericFutureListener<Future<Channel>>() {
            @Override
            public void operationComplete(Future<Channel> outcome) throws Exception {
              if (outcome.isSuccess()) {
                return;
              }
              logger.debug("Failed to complete handshake", outcome.cause());
              ctx.close();
            }
          });
    }
    ctx.fireChannelActive();
  }

  /**
   * Closes the channel, completing {@code promise}, once {@code flushFuture} completes.
   *
   * <p>An inactive channel is closed immediately. Otherwise, when {@code closeNotifyTimeoutMillis}
   * is positive, a task is scheduled that logs a warning and force-closes the channel if
   * {@code flushFuture} has not completed in time. That task is cancelled when the flush completes.
   */
  private void safeClose(
      final ChannelHandlerContext ctx, ChannelFuture flushFuture, final ChannelPromise promise) {
    if (!ctx.channel().isActive()) {
      ctx.close(promise);
      return;
    }

    final ScheduledFuture<?> forceCloseFuture =
        closeNotifyTimeoutMillis > 0
            ? ctx.executor()
                .schedule(
                    new Runnable() {
                      @Override
                      public void run() {
                        logger.warn(
                            ctx.channel()
                                + " last write attempt timed out."
                                + " Force-closing the connection.");
                        ctx.close(promise);
                      }
                    },
                    closeNotifyTimeoutMillis,
                    TimeUnit.MILLISECONDS)
            : null;

    flushFuture.addListener(
        new ChannelFutureListener() {
          @Override
          public void operationComplete(ChannelFuture done) throws Exception {
            if (forceCloseFuture != null) {
              forceCloseFuture.cancel(false);
            }
            // Close whatever the flush outcome was, so that the promise is always notified.
            // See https://github.com/netty/netty/issues/2358
            ctx.close(promise);
          }
        });
  }

  /**
   * A {@link DefaultPromise} whose executor is looked up from the handler's {@code ctx} each time
   * it is needed. It throws {@link IllegalStateException} if the handler has not been added to a
   * pipeline yet.
   */
  private final class LazyChannelPromise extends DefaultPromise<Channel> {

    @Override
    protected EventExecutor executor() {
      if (ctx != null) {
        return ctx.executor();
      }
      throw new IllegalStateException();
    }
  }
}
Пример #17
0
/**
 * A portable CDI extension which registers beans for lettuce. If there are no RedisURIs there are
 * also no registrations for RedisClients.
 *
 * <p>For each distinct set of qualifiers found on a {@link RedisURI} bean, one {@link RedisClient}
 * bean and one {@link RedisClusterClient} bean are registered, constructed with those qualifiers.
 *
 * @author <a href="mailto:[email protected]">Mark Paluch</a>
 */
public class LettuceCdiExtension implements Extension {

  private static final InternalLogger LOGGER =
      InternalLoggerFactory.getInstance(LettuceCdiExtension.class);

  /** Discovered RedisURI beans, keyed by their qualifier set. */
  private final Map<Set<Annotation>, Bean<RedisURI>> redisUris =
      new HashMap<Set<Annotation>, Bean<RedisURI>>();

  public LettuceCdiExtension() {
    LOGGER.info("Activating CDI extension for lettuce.");
  }

  /**
   * Observer that checks every processed bean for {@link RedisURI} types and stores matching
   * beans in {@link #redisUris} for later registration of client beans.
   *
   * <p>For a given qualifier set the first RedisURI bean wins, unless a later one is an
   * alternative, in which case it replaces the earlier entry.
   *
   * @param <T> The bean type.
   * @param processBean The CDI {@code ProcessBean} event for the bean being processed.
   */
  @SuppressWarnings("unchecked")
  <T> void processBean(@Observes ProcessBean<T> processBean) {
    Bean<T> bean = processBean.getBean();
    for (Type type : bean.getTypes()) {
      // Check if the bean is a RedisURI.
      if (type instanceof Class<?> && RedisURI.class.isAssignableFrom((Class<?>) type)) {
        Set<Annotation> qualifiers = new HashSet<Annotation>(bean.getQualifiers());
        if (bean.isAlternative() || !redisUris.containsKey(qualifiers)) {
          // Guarded so the message is only formatted when it will actually be logged.
          if (LOGGER.isDebugEnabled()) {
            LOGGER.debug(
                String.format(
                    "Discovered '%s' with qualifiers %s.", RedisURI.class.getName(), qualifiers));
          }
          // Safe: the type check above established that the bean produces a RedisURI.
          redisUris.put(qualifiers, (Bean<RedisURI>) bean);
        }
      }
    }
  }

  /**
   * Observer that registers a {@link RedisClient} and a {@link RedisClusterClient} bean with the
   * CDI container for every RedisURI discovered by {@link #processBean(ProcessBean)}.
   *
   * <p>Each client bean receives the qualifiers of the RedisURI bean it was created from. Beans
   * whose qualifiers include {@link Default} use the plain simple class name; all others get a
   * numeric suffix to keep names unique.
   *
   * @param afterBeanDiscovery The CDI event used to add the beans.
   * @param beanManager The BeanManager instance.
   */
  void afterBeanDiscovery(
      @Observes AfterBeanDiscovery afterBeanDiscovery, BeanManager beanManager) {

    int counter = 0;
    for (Entry<Set<Annotation>, Bean<RedisURI>> entry : redisUris.entrySet()) {

      Bean<RedisURI> redisUri = entry.getValue();
      Set<Annotation> qualifiers = entry.getKey();

      String clientBeanName = RedisClient.class.getSimpleName();
      String clusterClientBeanName = RedisClusterClient.class.getSimpleName();
      if (!contains(qualifiers, Default.class)) {
        clientBeanName += counter;
        clusterClientBeanName += counter;
        counter++;
      }

      RedisClientCdiBean clientBean =
          new RedisClientCdiBean(beanManager, qualifiers, redisUri, clientBeanName);
      register(afterBeanDiscovery, qualifiers, clientBean);

      RedisClusterClientCdiBean clusterClientBean =
          new RedisClusterClientCdiBean(beanManager, qualifiers, redisUri, clusterClientBeanName);
      register(afterBeanDiscovery, qualifiers, clusterClientBean);
    }
  }

  /**
   * Returns whether {@code qualifiers} contains an annotation of the given type.
   *
   * @param qualifiers the qualifier set to search
   * @param annotationType the annotation type to look for (e.g. {@code Default.class})
   * @return {@code true} if any qualifier is an instance of {@code annotationType}
   */
  private boolean contains(
      Set<Annotation> qualifiers, Class<? extends Annotation> annotationType) {
    for (Annotation qualifier : qualifiers) {
      if (annotationType.isInstance(qualifier)) {
        return true;
      }
    }
    return false;
  }

  /** Logs and adds {@code bean} to the container via {@code afterBeanDiscovery}. */
  private void register(
      AfterBeanDiscovery afterBeanDiscovery, Set<Annotation> qualifiers, Bean<?> bean) {
    LOGGER.info(
        String.format(
            "Registering bean '%s' with qualifiers %s.",
            bean.getBeanClass().getName(), qualifiers));
    afterBeanDiscovery.addBean(bean);
  }
}
Пример #18
0
/**
 * Reads a PEM file and converts it into a list of DERs so that they are imported into a {@link
 * KeyStore} easily.
 *
 * <p>All returned buffers are newly allocated; ownership passes to the caller, who must release
 * them. On failure, no buffer allocated by this class is leaked.
 */
final class PemReader {

  private static final InternalLogger logger = InternalLoggerFactory.getInstance(PemReader.class);

  private static final Pattern CERT_PATTERN =
      Pattern.compile(
          "-+BEGIN\\s+.*CERTIFICATE[^-]*-+(?:\\s|\\r|\\n)+" // Header
              + "([a-z0-9+/=\\r\\n]+)" // Base64 text
              + "-+END\\s+.*CERTIFICATE[^-]*-+", // Footer
          Pattern.CASE_INSENSITIVE);
  private static final Pattern KEY_PATTERN =
      Pattern.compile(
          "-+BEGIN\\s+.*PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+" // Header
              + "([a-z0-9+/=\\r\\n]+)" // Base64 text
              + "-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", // Footer
          Pattern.CASE_INSENSITIVE);

  /**
   * Decodes every certificate block in {@code file}.
   *
   * @return the DER-encoded certificates, in file order
   * @throws CertificateException if the file cannot be read or contains no certificate
   */
  static ByteBuf[] readCertificates(File file) throws CertificateException {
    String content;
    try {
      content = readContent(file);
    } catch (IOException e) {
      throw new CertificateException("failed to read a file: " + file, e);
    }

    List<ByteBuf> certs = new ArrayList<ByteBuf>();
    boolean success = false;
    try {
      Matcher m = CERT_PATTERN.matcher(content);
      int start = 0;
      while (m.find(start)) {
        certs.add(decodeBase64(m.group(1)));
        start = m.end();
      }

      if (certs.isEmpty()) {
        throw new CertificateException("found no certificates: " + file);
      }

      success = true;
      return certs.toArray(new ByteBuf[certs.size()]);
    } finally {
      // If decoding a later block fails, release the certificates decoded so far.
      if (!success) {
        for (ByteBuf cert : certs) {
          cert.release();
        }
      }
    }
  }

  /**
   * Decodes the first private key block in {@code file}.
   *
   * @return the DER-encoded private key
   * @throws KeyException if the file cannot be read or contains no private key
   */
  static ByteBuf readPrivateKey(File file) throws KeyException {
    String content;
    try {
      content = readContent(file);
    } catch (IOException e) {
      throw new KeyException("failed to read a file: " + file, e);
    }

    Matcher m = KEY_PATTERN.matcher(content);
    if (!m.find()) {
      throw new KeyException("found no private key: " + file);
    }

    return decodeBase64(m.group(1));
  }

  /**
   * Decodes Base64 text into a new buffer. The temporary ASCII copy is always released, even if
   * decoding fails (for example on malformed input).
   */
  private static ByteBuf decodeBase64(String base64Text) {
    ByteBuf base64 = Unpooled.copiedBuffer(base64Text, CharsetUtil.US_ASCII);
    try {
      return Base64.decode(base64);
    } finally {
      base64.release();
    }
  }

  /** Reads the whole file as US-ASCII text. */
  private static String readContent(File file) throws IOException {
    InputStream in = new FileInputStream(file);
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    try {
      byte[] buf = new byte[8192];
      for (; ; ) {
        int ret = in.read(buf);
        if (ret < 0) {
          break;
        }
        out.write(buf, 0, ret);
      }
      return out.toString(CharsetUtil.US_ASCII.name());
    } finally {
      safeClose(in);
      safeClose(out);
    }
  }

  private static void safeClose(InputStream in) {
    try {
      in.close();
    } catch (IOException e) {
      logger.warn("Failed to close a stream.", e);
    }
  }

  private static void safeClose(OutputStream out) {
    try {
      out.close();
    } catch (IOException e) {
      logger.warn("Failed to close a stream.", e);
    }
  }

  private PemReader() {}
}
Пример #19
0
/** @author circlespainter */
public class AutoWebActorHandler extends WebActorHandler {
  private static final InternalLogger log =
      InternalLoggerFactory.getInstance(AutoWebActorHandler.class);
  private static final List<Class<?>> actorClasses = new ArrayList<>(32);
  private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0];

  public AutoWebActorHandler() {
    this(null, null, null, null);
  }

  public AutoWebActorHandler(List<String> packagePrefixes) {
    this(null, null, packagePrefixes, null);
  }

  public AutoWebActorHandler(String httpResponseEncoderName, List<String> packagePrefixes) {
    this(httpResponseEncoderName, null, packagePrefixes, null);
  }

  public AutoWebActorHandler(String httpResponseEncoderName) {
    this(httpResponseEncoderName, null, null, null);
  }

  public AutoWebActorHandler(String httpResponseEncoderName, ClassLoader userClassLoader) {
    this(httpResponseEncoderName, userClassLoader, null, null);
  }

  public AutoWebActorHandler(
      String httpResponseEncoderName, ClassLoader userClassLoader, List<String> packagePrefixes) {
    this(httpResponseEncoderName, userClassLoader, packagePrefixes, null);
  }

  public AutoWebActorHandler(String httpResponseEncoderName, Map<Class<?>, Object[]> actorParams) {
    this(httpResponseEncoderName, null, null, actorParams);
  }

  public AutoWebActorHandler(
      String httpResponseEncoderName,
      List<String> packagePrefixes,
      Map<Class<?>, Object[]> actorParams) {
    this(httpResponseEncoderName, null, packagePrefixes, actorParams);
  }

  public AutoWebActorHandler(
      String httpResponseEncoderName,
      ClassLoader userClassLoader,
      List<String> packagePrefixes,
      Map<Class<?>, Object[]> actorParams) {
    super(null, httpResponseEncoderName);
    super.contextProvider =
        newContextProvider(
            userClassLoader != null ? userClassLoader : ClassLoader.getSystemClassLoader(),
            packagePrefixes,
            actorParams);
  }

  public AutoWebActorHandler(String httpResponseEncoderName, AutoContextProvider prov) {
    super(prov, httpResponseEncoderName);
  }

  protected AutoContextProvider newContextProvider(
      ClassLoader userClassLoader,
      List<String> packagePrefixes,
      Map<Class<?>, Object[]> actorParams) {
    return new AutoContextProvider(userClassLoader, packagePrefixes, actorParams);
  }

  public static class AutoContextProvider implements WebActorContextProvider {
    private final ClassLoader userClassLoader;
    private final List<String> packagePrefixes;
    private final Map<Class<?>, Object[]> actorParams;
    private final Long defaultContextValidityMS;

    public AutoContextProvider(
        ClassLoader userClassLoader,
        List<String> packagePrefixes,
        Map<Class<?>, Object[]> actorParams) {
      this(userClassLoader, packagePrefixes, actorParams, null);
    }

    public AutoContextProvider(
        ClassLoader userClassLoader,
        List<String> packagePrefixes,
        Map<Class<?>, Object[]> actorParams,
        Long defaultContextValidityMS) {
      this.userClassLoader = userClassLoader;
      this.packagePrefixes = packagePrefixes;
      this.actorParams = actorParams;
      this.defaultContextValidityMS = defaultContextValidityMS;
    }

    @Override
    public final Context get(final FullHttpRequest req) {
      final String sessionId = getSessionId(req);
      if (sessionId != null && sessionsEnabled()) {
        final Context actorContext = sessions.get(sessionId);
        if (actorContext != null) {
          if (actorContext.renew()) return actorContext;
          else sessions.remove(sessionId); // Evict session
        }
      }
      return newActorContext(req);
    }

    protected AutoContext newActorContext(FullHttpRequest req) {
      final AutoContext c = new AutoContext(req, packagePrefixes, actorParams, userClassLoader);
      if (defaultContextValidityMS != null) c.setValidityMS(defaultContextValidityMS);
      return c;
    }

    private String getSessionId(FullHttpRequest req) {
      final Set<Cookie> cookies = NettyHttpRequest.getNettyCookies(req);
      if (cookies != null) {
        for (final Cookie c : cookies) {
          if (c != null && SESSION_COOKIE_KEY.equals(c.name())) return c.value();
        }
      }
      return null;
    }
  }

  private static class AutoContext extends DefaultContextImpl {
    private String id;

    private final List<String> packagePrefixes;
    private final Map<Class<?>, Object[]> actorParams;
    private final ClassLoader userClassLoader;
    private Class<? extends ActorImpl<? extends WebMessage>> actorClass;
    private ActorRef<? extends WebMessage> actorRef;

    public AutoContext(
        FullHttpRequest req,
        List<String> packagePrefixes,
        Map<Class<?>, Object[]> actorParams,
        ClassLoader userClassLoader) {
      this.packagePrefixes = packagePrefixes;
      this.actorParams = actorParams;
      this.userClassLoader = userClassLoader;
      fillActor(req);
    }

    private void fillActor(FullHttpRequest req) {
      final Pair<ActorRef<? extends WebMessage>, Class<? extends ActorImpl<? extends WebMessage>>>
          p = autoCreateActor(req);
      if (p != null) {
        actorRef = p.getFirst();
        actorClass = p.getSecond();
      }
    }

    @Override
    public final String getId() {
      return id != null ? id : (id = UUID.randomUUID().toString());
    }

    @Override
    public final void restart(FullHttpRequest req) {
      renewed = new Date().getTime();
      fillActor(req);
    }

    @Override
    public final ActorRef<? extends WebMessage> getWebActor() {
      return actorRef;
    }

    @Override
    public final boolean handlesWithHttp(String uri) {
      return WebActorHandler.handlesWithHttp(uri, actorClass);
    }

    @Override
    public final boolean handlesWithWebSocket(String uri) {
      return WebActorHandler.handlesWithWebSocket(uri, actorClass);
    }

    @Override
    public WatchPolicy watch() {
      return WatchPolicy.DIE_IF_EXCEPTION_ELSE_RESTART;
    }

    @SuppressWarnings("unchecked")
    private Pair<ActorRef<? extends WebMessage>, Class<? extends ActorImpl<? extends WebMessage>>>
        autoCreateActor(FullHttpRequest req) {
      registerActorClasses();
      final String uri = req.getUri();
      for (final Class<?> c : actorClasses) {
        if (WebActorHandler.handlesWithHttp(uri, c) || WebActorHandler.handlesWithWebSocket(uri, c))
          return new Pair<
              ActorRef<? extends WebMessage>, Class<? extends ActorImpl<? extends WebMessage>>>(
              Actor.newActor(
                      new ActorSpec(
                          c, actorParams != null ? actorParams.get(c) : EMPTY_OBJECT_ARRAY))
                  .spawn(),
              (Class<? extends ActorImpl<? extends WebMessage>>) c);
      }
      return null;
    }

    private synchronized void registerActorClasses() {
      if (actorClasses.isEmpty()) {
        try {
          final ClassLoader classLoader =
              userClassLoader != null ? userClassLoader : this.getClass().getClassLoader();
          ClassLoaderUtil.accept(
              (URLClassLoader) classLoader,
              new ClassLoaderUtil.Visitor() {
                @Override
                public final void visit(String resource, URL url, ClassLoader cl) {
                  if (packagePrefixes != null) {
                    boolean found = false;
                    for (final String packagePrefix : packagePrefixes) {
                      if (packagePrefix != null
                          && resource.startsWith(packagePrefix.replace('.', '/'))) {
                        found = true;
                        break;
                      }
                    }
                    if (!found) return;
                  }
                  if (!ClassLoaderUtil.isClassFile(resource)) return;
                  final String className = ClassLoaderUtil.resourceToClass(resource);
                  try (final InputStream is = cl.getResourceAsStream(resource)) {
                    if (AnnotationUtil.hasClassAnnotation(WebActor.class, is))
                      registerWebActor(cl.loadClass(className));
                  } catch (final IOException | ClassNotFoundException e) {
                    log.error(
                        "Exception while scanning class " + className + " for WebActor annotation",
                        e);
                    throw new RuntimeException(e);
                  }
                }

                private void registerWebActor(Class<?> c) {
                  actorClasses.add(c);
                }
              });
        } catch (final IOException e) {
          log.error("IOException while scanning classes for WebActor annotation", e);
        }
      }
    }
  }
}
Пример #20
0
/**
 * A {@link ServerSocketChannel} backed by a blocking {@link ServerSocket} (old blocking I/O). Every
 * accepted connection is wrapped in a new {@link OioSocketChannel}.
 */
public class OioServerSocketChannel extends AbstractOioMessageChannel
    implements ServerSocketChannel {

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(OioServerSocketChannel.class);

  private static final ChannelMetadata METADATA = new ChannelMetadata(false);

  /** Opens an unbound {@link ServerSocket}, reporting failure as a {@link ChannelException}. */
  private static ServerSocket newServerSocket() {
    final ServerSocket created;
    try {
      created = new ServerSocket();
    } catch (IOException e) {
      throw new ChannelException("failed to create a server socket", e);
    }
    return created;
  }

  final ServerSocket socket;
  final Lock shutdownLock = new ReentrantLock();
  private final OioServerSocketChannelConfig config;

  /** Creates a channel around a freshly opened, unbound {@link ServerSocket}. */
  public OioServerSocketChannel() {
    this(newServerSocket());
  }

  /**
   * Creates a channel around an existing server socket.
   *
   * @param socket the {@link ServerSocket} used by this instance; it is closed if it cannot be
   *     configured
   */
  public OioServerSocketChannel(ServerSocket socket) {
    super(null);
    if (socket == null) {
      throw new NullPointerException("socket");
    }

    applySoTimeoutOrClose(socket);
    this.socket = socket;
    config = new DefaultOioServerSocketChannelConfig(this, socket);
  }

  /**
   * Sets {@code SO_TIMEOUT} on {@code target}. If that fails for any reason the socket is closed,
   * and an {@link IOException} is reported as a {@link ChannelException}.
   */
  private void applySoTimeoutOrClose(ServerSocket target) {
    boolean configured = false;
    try {
      target.setSoTimeout(SO_TIMEOUT);
      configured = true;
    } catch (IOException e) {
      throw new ChannelException("Failed to set the server socket timeout.", e);
    } finally {
      if (!configured) {
        closePartiallyInitializedSocket(target);
      }
    }
  }

  /** Closes {@code target}, logging (not throwing) any {@link IOException}. */
  private static void closePartiallyInitializedSocket(ServerSocket target) {
    try {
      target.close();
    } catch (IOException e) {
      if (logger.isWarnEnabled()) {
        logger.warn("Failed to close a partially initialized socket.", e);
      }
    }
  }

  @Override
  public InetSocketAddress localAddress() {
    return (InetSocketAddress) super.localAddress();
  }

  @Override
  public ChannelMetadata metadata() {
    return METADATA;
  }

  @Override
  public OioServerSocketChannelConfig config() {
    return config;
  }

  /** A server channel has no remote peer. */
  @Override
  public InetSocketAddress remoteAddress() {
    return null;
  }

  @Override
  public boolean isOpen() {
    return !socket.isClosed();
  }

  @Override
  public boolean isActive() {
    return isOpen() && socket.isBound();
  }

  @Override
  protected SocketAddress localAddress0() {
    return socket.getLocalSocketAddress();
  }

  @Override
  protected void doBind(SocketAddress localAddress) throws Exception {
    socket.bind(localAddress, config.getBacklog());
  }

  @Override
  protected void doClose() throws Exception {
    socket.close();
  }

  /**
   * Accepts at most one connection.
   *
   * @return {@code -1} if the server socket is closed, {@code 1} if a child channel was added to
   *     {@code buf}, and {@code 0} if accept timed out or the child channel could not be created
   */
  @Override
  protected int doReadMessages(MessageList<Object> buf) throws Exception {
    if (socket.isClosed()) {
      return -1;
    }

    final Socket accepted;
    try {
      accepted = socket.accept();
    } catch (SocketTimeoutException e) {
      // Expected: SO_TIMEOUT elapsed without a new connection.
      return 0;
    }

    if (accepted == null) {
      return 0;
    }

    try {
      buf.add(new OioSocketChannel(this, accepted));
      return 1;
    } catch (Throwable t) {
      logger.warn("Failed to create a new channel from an accepted socket.", t);
      try {
        accepted.close();
      } catch (Throwable t2) {
        logger.warn("Failed to close a socket.", t2);
      }
    }
    return 0;
  }

  @Override
  protected int doWrite(MessageList<Object> msgs, int index) throws Exception {
    throw new UnsupportedOperationException();
  }

  @Override
  protected void doConnect(SocketAddress remoteAddress, SocketAddress localAddress)
      throws Exception {
    throw new UnsupportedOperationException();
  }

  @Override
  protected SocketAddress remoteAddress0() {
    return null;
  }

  @Override
  protected void doDisconnect() throws Exception {
    throw new UnsupportedOperationException();
  }
}
Пример #21
0
/**
 * Netty handler that dispatches HTTP requests and WebSocket frames to web actors.
 *
 * <p>For every {@link FullHttpRequest} the configured {@link WebActorContextProvider} selects a
 * {@link Context}. If the context has a user actor and handles the URI as a WebSocket endpoint,
 * the connection is upgraded and later frames are forwarded to that actor. If it handles the URI
 * as an HTTP endpoint, the request is forwarded to the actor and the actor's response is written
 * back. Otherwise a 404 is returned. Each instance keeps per-connection WebSocket state in
 * instance fields, so it is presumably created per channel; confirm against the pipeline
 * initializer.
 *
 * @author circlespainter
 */
public class WebActorHandler extends SimpleChannelInboundHandler<Object> {
  /** Selects the actor {@link Context} responsible for a given HTTP request. */
  // @FunctionalInterface
  public interface WebActorContextProvider {
    /** Returns the context serving {@code req}; the handler asserts the result is non-null. */
    Context get(ChannelHandlerContext ctx, FullHttpRequest req);
  }

  /** State binding a user actor to incoming requests. */
  public interface Context {
    /** Returns whether this context may still be used. */
    boolean isValid();

    /** Marks this context as no longer usable. */
    void invalidate();

    /** Returns the user actor, or {@code null} if there is none (the request then gets a 404). */
    ActorRef<? extends WebMessage> getRef();

    /** Returns the lock held by the handler while it inspects and updates this context. */
    ReentrantLock getLock();

    /** Returns mutable attachments; the handler stores its internal adapter actor here. */
    Map<String, Object> getAttachments();

    /** Returns whether requests for {@code uri} are forwarded to the actor as HTTP requests. */
    boolean handlesWithHttp(String uri);

    /** Returns whether requests for {@code uri} are upgraded to a WebSocket connection. */
    boolean handlesWithWebSocket(String uri);

    /** Returns whether the HTTP adapter should watch the user actor for termination. */
    boolean watch();
  }

  /**
   * Base {@link Context} that becomes invalid a fixed time after creation.
   *
   * <p>The lifetime is 60 seconds unless the system property {@code
   * <this class name>.durationMillis} overrides it. Invalidation clears the attachments.
   */
  public abstract static class DefaultContextImpl implements Context {
    private static final String durationProp =
        System.getProperty(DefaultContextImpl.class.getName() + ".durationMillis");
    private static final long DURATION =
        durationProp != null ? Long.parseLong(durationProp) : 60_000L;
    private final ReentrantLock lock = new ReentrantLock();
    private final long created;
    private final Map<String, Object> attachments = new HashMap<>();

    private boolean valid = true;

    public DefaultContextImpl() {
      created = System.currentTimeMillis();
    }

    @Override
    public void invalidate() {
      attachments.clear();
      valid = false;
    }

    /** Returns whether the context is still valid, invalidating it once its lifetime elapsed. */
    @Override
    public final boolean isValid() {
      final boolean ret = valid && (System.currentTimeMillis() - created) <= DURATION;
      if (!ret) invalidate();
      return ret;
    }

    @Override
    public final Map<String, Object> getAttachments() {
      return attachments;
    }

    @Override
    public final ReentrantLock getLock() {
      return lock;
    }

    @Override
    public boolean watch() {
      return true;
    }
  }

  public WebActorHandler(WebActorContextProvider selector) {
    this(selector, null);
  }

  /**
   * @param selector provides the actor context for each request
   * @param httpResponseEncoderName pipeline name of the HTTP response encoder to remove when a
   *     stream is started; if {@code null}, every {@link HttpResponseEncoder} is removed instead
   */
  public WebActorHandler(WebActorContextProvider selector, String httpResponseEncoderName) {
    this.selector = selector;
    this.httpResponseEncoderName = httpResponseEncoderName;
  }

  @Override
  public final void channelReadComplete(ChannelHandlerContext ctx) {
    ctx.flush();
  }

  @Override
  public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
    if (ctx.channel().isOpen()) ctx.close();
    log.error("Exception caught", cause);
  }

  @Override
  protected final void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
    if (msg instanceof FullHttpRequest) {
      handleHttpRequest(ctx, (FullHttpRequest) msg);
    } else if (msg instanceof WebSocketFrame) {
      handleWebSocketFrame(ctx, (WebSocketFrame) msg);
    } else {
      throw new AssertionError("Unexpected message " + msg);
    }
  }

  protected static boolean sessionsEnabled() {
    return "always".equals(trackSession) || "sse".equals(trackSession);
  }

  protected static final String SESSION_COOKIE_KEY = "JSESSIONID";
  // NOTE(review): keys are freshly generated session-id Strings that nothing else appears to hold
  // strongly, so WeakHashMap entries may be dropped at the next GC; confirm this is intended.
  protected static final Map<String, Context> sessions =
      Collections.synchronizedMap(new WeakHashMap<String, Context>());
  protected static final String TRACK_SESSION_PROP =
      HttpChannelAdapter.class.getName() + ".trackSession";
  protected static final String trackSession = System.getProperty(TRACK_SESSION_PROP, "sse");

  protected static final String OMIT_DATE_HEADER_PROP =
      HttpChannelAdapter.class.getName() + ".omitDateHeader";
  protected static final Boolean omitDateHeader =
      SystemProperties.isEmptyOrTrue(OMIT_DATE_HEADER_PROP);

  private static final String ACTOR_KEY = "co.paralleluniverse.comsat.webactors.sessionActor";

  // Internal labels stored in classToUrlPatterns for the two kinds of URL patterns.
  private static final String HTTP_PATTERN = "http";
  private static final String WEBSOCKET_PATTERN = "ws";

  // Cache of @WebActor URL patterns per actor class. Not thread-safe by itself: every access goes
  // through lookupOrInsert, which synchronizes on the map.
  private static final WeakHashMap<Class<?>, List<Pair<String, String>>> classToUrlPatterns =
      new WeakHashMap<>();
  private static final InternalLogger log =
      InternalLoggerFactory.getInstance(WebActorHandler.class);

  private final WebActorContextProvider selector;
  private final String httpResponseEncoderName;

  private WebSocketServerHandshaker handshaker;
  private WebSocketActorAdapter webSocketActor;

  /** Handles control frames directly and forwards data frames to the WebSocket adapter actor. */
  private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
    // Check for closing frame
    if (frame instanceof CloseWebSocketFrame) {
      handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
      return;
    }

    if (frame instanceof PingWebSocketFrame) {
      ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
      return;
    }

    // NOTE(review): continuation frames are dropped, so fragmented messages never reach the actor.
    if (frame instanceof ContinuationWebSocketFrame) return;

    if (frame instanceof TextWebSocketFrame) {
      webSocketActor.onMessage(((TextWebSocketFrame) frame).text());
    } else {
      // SimpleChannelInboundHandler releases the frame as soon as channelRead0 returns, while the
      // actor consumes the message asynchronously. Hand it a private copy of the payload rather
      // than a view (nioBuffer()) over memory that may be recycled.
      final ByteBuf content = frame.content();
      final byte[] payload = new byte[content.readableBytes()];
      content.getBytes(content.readerIndex(), payload);
      webSocketActor.onMessage(ByteBuffer.wrap(payload));
    }
  }

  /**
   * Routes an HTTP request to the context's user actor.
   *
   * <p>The request is either upgraded to a WebSocket or forwarded as an HTTP request. A failed
   * decode gets a 400 response; a request no actor handles gets a 404.
   */
  private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req)
      throws SuspendExecution {
    // Handle a bad request.
    if (!req.getDecoderResult().isSuccess()) {
      sendHttpResponse(
          ctx, req, new DefaultFullHttpResponse(req.getProtocolVersion(), BAD_REQUEST), false);
      return;
    }

    final String uri = req.getUri();

    final Context actorCtx = selector.get(ctx, req);
    assert actorCtx != null;

    final ReentrantLock lock = actorCtx.getLock();
    assert lock != null;

    lock.lock();

    try {
      final ActorRef<? extends WebMessage> userActorRef = actorCtx.getRef();
      ActorImpl internalActor = (ActorImpl) actorCtx.getAttachments().get(ACTOR_KEY);

      if (userActorRef != null) {
        if (actorCtx.handlesWithWebSocket(uri)) {
          // NOTE(review): if the context already holds a WebSocketActorAdapter (e.g. from another
          // connection), this handler's webSocketActor field is not set from it and may be null.
          if (!(internalActor instanceof WebSocketActorAdapter)) {
            //noinspection unchecked
            webSocketActor =
                new WebSocketActorAdapter(ctx, (ActorRef<? super WebMessage>) userActorRef);
            addActorToContextAndUnlock(actorCtx, webSocketActor, lock);
          }
          // Handshake
          final WebSocketServerHandshakerFactory wsFactory =
              new WebSocketServerHandshakerFactory(uri, null, true);
          handshaker = wsFactory.newHandshaker(req);
          if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
          } else {
            @SuppressWarnings("unchecked")
            final ActorRef<WebMessage> userActorRef0 =
                (ActorRef<WebMessage>) webSocketActor.userActor;
            handshaker
                .handshake(ctx.channel(), req)
                .addListener(
                    new GenericFutureListener<ChannelFuture>() {
                      @Override
                      public void operationComplete(ChannelFuture future) throws Exception {
                        // Notify the user actor from a fiber, since send() may suspend.
                        FiberUtil.runInFiber(
                            new SuspendableRunnable() {
                              @Override
                              public void run() throws SuspendExecution, InterruptedException {
                                userActorRef0.send(
                                    new WebSocketOpened(WebActorHandler.this.webSocketActor.ref()));
                              }
                            });
                      }
                    });
          }
          return;
        } else if (actorCtx.handlesWithHttp(uri)) {
          if (!(internalActor instanceof HttpActorAdapter)) {
            //noinspection unchecked
            internalActor =
                new HttpActorAdapter(
                    (ActorRef<HttpRequest>) userActorRef, actorCtx, httpResponseEncoderName);
            addActorToContextAndUnlock(actorCtx, internalActor, lock);
          }
          //noinspection unchecked
          ((HttpActorAdapter) internalActor).service(ctx, req);
          return;
        }
      }
    } finally {
      // The lock may already have been released early by addActorToContextAndUnlock.
      if (lock.isHeldByCurrentStrand() && lock.isLocked()) lock.unlock();
    }

    sendHttpResponse(
        ctx, req, new DefaultFullHttpResponse(req.getProtocolVersion(), NOT_FOUND), false);
  }

  /** Stores {@code actor} as the context's internal adapter, then releases {@code lock}. */
  private void addActorToContextAndUnlock(
      Context actorContext, ActorImpl actor, ReentrantLock lock) {
    actorContext.getAttachments().put(ACTOR_KEY, actor);
    lock.unlock();
  }

  /**
   * Actor standing in for a WebSocket connection.
   *
   * <p>Received frames are forwarded to the user actor. Messages sent to this actor's mailbox are
   * written to the channel through {@link WebSocketChannelAdapter}. The adapter dies when the
   * watched user actor exits.
   */
  private static final class WebSocketActorAdapter extends FakeActor<WebDataMessage> {
    ActorRef<? super WebMessage> userActor;

    private ChannelHandlerContext ctx;

    private volatile boolean dead;

    public WebSocketActorAdapter(
        ChannelHandlerContext ctx, ActorRef<? super WebMessage> userActor) {
      super(userActor.getName(), new WebSocketChannelAdapter(ctx));
      ((WebSocketChannelAdapter) (SendPort) getMailbox()).actor = this;
      this.ctx = ctx;
      this.userActor = userActor;
      watch(userActor);
    }

    @Override
    public final void interrupt() {
      die(new InterruptedException());
    }

    @Override
    public final String toString() {
      return "WebSocketActorAdapter{" + "userActor=" + userActor + '}';
    }

    private void onMessage(final ByteBuffer message) {
      try {
        userActor.send(new WebDataMessage(ref(), message));
      } catch (SuspendExecution ex) {
        throw new AssertionError(ex);
      }
    }

    private void onMessage(final String message) {
      try {
        userActor.send(new WebDataMessage(ref(), message));
      } catch (SuspendExecution ex) {
        throw new AssertionError(ex);
      }
    }

    @Override
    protected final WebDataMessage handleLifecycleMessage(LifecycleMessage m) {
      if (m instanceof ExitMessage) {
        ExitMessage em = (ExitMessage) m;
        if (em.getActor() != null && em.getActor().equals(userActor)) die(em.getCause());
      }
      return null;
    }

    @Override
    protected final void throwIn(RuntimeException e) {
      die(e);
    }

    /**
     * Terminates the adapter and closes the channel. Only the first call has an effect.
     *
     * <p>Calls after the first are no-ops. {@link WebSocketChannelAdapter#close(Throwable)} calls
     * {@code die(t)} and then {@code close()}, which calls {@code die(null)} again after the
     * references below have already been cleared.
     */
    @Override
    protected final void die(Throwable cause) {
      if (dead) return;
      dead = true;
      super.die(cause);
      final ChannelHandlerContext channelCtx = ctx;
      if (channelCtx != null && channelCtx.channel().isOpen()) channelCtx.close();

      // Ensure to release server references
      userActor = null;
      ctx = null;
    }
  }

  /** Mailbox of a {@link WebSocketActorAdapter}: each message is written as a WebSocket frame. */
  private static final class WebSocketChannelAdapter implements SendPort<WebDataMessage> {
    private final ChannelHandlerContext ctx;

    WebSocketActorAdapter actor;

    public WebSocketChannelAdapter(ChannelHandlerContext ctx) {
      this.ctx = ctx;
    }

    @Override
    public final void send(WebDataMessage message) throws SuspendExecution, InterruptedException {
      trySend(message);
    }

    @Override
    public final boolean send(WebDataMessage message, long timeout, TimeUnit unit)
        throws SuspendExecution, InterruptedException {
      return trySend(message);
    }

    @Override
    public final boolean send(WebDataMessage message, Timeout timeout)
        throws SuspendExecution, InterruptedException {
      return send(message, timeout.nanosLeft(), TimeUnit.NANOSECONDS);
    }

    /** Writes the message as a text frame, or as a binary frame if it is binary. */
    @Override
    public final boolean trySend(WebDataMessage message) {
      if (!message.isBinary()) ctx.writeAndFlush(new TextWebSocketFrame(message.getStringBody()));
      else
        ctx.writeAndFlush(
            new BinaryWebSocketFrame(Unpooled.wrappedBuffer(message.getByteBufferBody())));
      return true;
    }

    @Override
    public final void close() {
      if (ctx.channel().isOpen()) ctx.close();
      if (actor != null) actor.die(null);
    }

    @Override
    public final void close(Throwable t) {
      if (actor != null) actor.die(t);
      close();
    }
  }

  /**
   * Actor standing in for the HTTP side of a context.
   *
   * <p>Requests are forwarded to the user actor. Responses sent back to this actor are written by
   * {@link HttpChannelAdapter}.
   */
  private static final class HttpActorAdapter extends FakeActor<HttpResponse> {
    private ActorRef<? super HttpRequest> userActor;
    private Context context;

    private volatile boolean dead;

    private volatile ChannelHandlerContext ctx;
    private volatile FullHttpRequest req;
    private volatile Object watchToken;

    HttpActorAdapter(
        ActorRef<? super HttpRequest> userActor,
        Context actorContext,
        String httpResponseEncoderName) {
      super("HttpActorAdapter", new HttpChannelAdapter(actorContext, httpResponseEncoderName));

      if (actorContext.watch()) ((HttpChannelAdapter) (SendPort) getMailbox()).actor = this;

      this.userActor = userActor;
      this.context = actorContext;
    }

    /** Forwards a request to the user actor, or answers 500 right away if this adapter is done. */
    final void service(ChannelHandlerContext ctx, FullHttpRequest req) throws SuspendExecution {
      if (context.watch()) watchToken = watch(userActor);

      this.ctx = ctx;
      this.req = req;

      if (isDone()) {
        handleDeath(getDeathCause());
        return;
      }

      userActor.send(new HttpRequestWrapper(ref(), ctx, req));
    }

    /** Stops watching the user actor, if a watch is active. */
    final void unwatch() {
      if (watchToken != null && userActor != null) {
        unwatch(userActor, watchToken);
        watchToken = null;
      }
    }

    /** Answers the pending request with a 500 describing why the actor is gone. */
    private void handleDeath(Throwable cause) {
      if (cause != null)
        sendHttpResponse(
            ctx,
            req,
            new DefaultFullHttpResponse(
                req.getProtocolVersion(),
                INTERNAL_SERVER_ERROR,
                Unpooled.wrappedBuffer(
                    ("Actor is dead because of " + cause.getMessage()).getBytes())),
            false);
      else
        sendHttpResponse(
            ctx,
            req,
            new DefaultFullHttpResponse(
                req.getProtocolVersion(),
                INTERNAL_SERVER_ERROR,
                Unpooled.wrappedBuffer(("Actor has terminated.").getBytes())),
            false);
    }

    @Override
    protected final HttpResponse handleLifecycleMessage(LifecycleMessage m) {
      if (m instanceof ExitMessage) {
        final ExitMessage em = (ExitMessage) m;
        if (em.getActor() != null && em.getActor().equals(userActor)) {
          handleDeath(em.getCause());
          die(em.getCause());
        }
      }
      return null;
    }

    /** Terminates the adapter, invalidates its context and drops server references (once). */
    @Override
    protected final void die(Throwable cause) {
      if (dead) return;
      dead = true;
      super.die(cause);
      try {
        context.invalidate();
      } catch (final Exception ignored) {
        // Best effort: a failing invalidation must not prevent releasing the references below.
      }

      // Ensure to release references to server objects
      unwatch();
      userActor = null;
      watchToken = null;
      context = null;
      ctx = null;
      req = null;
    }

    @Override
    protected final void throwIn(RuntimeException e) {
      die(e);
    }

    @Override
    protected final void interrupt() {
      die(new InterruptedException());
    }

    @Override
    public final String toString() {
      return "HttpActorAdapter{" + userActor + "}";
    }
  }

  /** Mailbox of an {@link HttpActorAdapter}: turns user-actor responses into Netty responses. */
  private static final class HttpChannelAdapter implements SendPort<HttpResponse> {
    HttpActorAdapter actor;

    private final String httpResponseEncoderName;

    private Context actorContext;

    public HttpChannelAdapter(Context actorContext, String httpResponseEncoderName) {
      this.actorContext = actorContext;
      this.httpResponseEncoderName = httpResponseEncoderName;
    }

    @Override
    public final void send(HttpResponse message) throws SuspendExecution, InterruptedException {
      trySend(message);
    }

    @Override
    public final boolean send(HttpResponse message, long timeout, TimeUnit unit)
        throws SuspendExecution, InterruptedException {
      send(message);
      return true;
    }

    @Override
    public final boolean send(HttpResponse message, Timeout timeout)
        throws SuspendExecution, InterruptedException {
      return send(message, timeout.nanosLeft(), TimeUnit.NANOSECONDS);
    }

    /**
     * Writes {@code message} to the channel the originating request came from.
     *
     * <p>Error statuses (4xx/5xx) and redirects are answered and close the adapter. Otherwise the
     * body, cookies, headers and content type are copied. If the response starts a stream, the
     * HTTP response encoder is removed and the requester receives an {@link HttpStreamOpened}.
     */
    @Override
    public final boolean trySend(HttpResponse message) {
      try {
        final HttpRequestWrapper nettyRequest = (HttpRequestWrapper) message.getRequest();
        final FullHttpRequest req = nettyRequest.req;
        final ChannelHandlerContext ctx = nettyRequest.ctx;

        final HttpResponseStatus status = HttpResponseStatus.valueOf(message.getStatus());

        if (message.getStatus() >= 400 && message.getStatus() < 600) {
          sendHttpResponse(
              ctx, req, new DefaultFullHttpResponse(req.getProtocolVersion(), status), false);
          close();
          return true;
        }

        if (message.getRedirectPath() != null) {
          sendHttpRedirect(ctx, req, message.getRedirectPath());
          close();
          return true;
        }

        // Encode a string body once, in the charset advertised in Content-Type when one is set
        // (platform default otherwise, as before), so body bytes and Content-Length agree.
        final Charset bodyCharset =
            message.getCharacterEncoding() != null
                ? message.getCharacterEncoding()
                : Charset.defaultCharset();
        final String stringBody = message.getStringBody();
        final byte[] stringBodyBytes = stringBody != null ? stringBody.getBytes(bodyCharset) : null;
        final ByteBuffer byteBufferBody = message.getByteBufferBody();

        FullHttpResponse res;
        if (stringBodyBytes != null)
          res =
              new DefaultFullHttpResponse(
                  req.getProtocolVersion(), status, Unpooled.wrappedBuffer(stringBodyBytes));
        else if (byteBufferBody != null)
          res =
              new DefaultFullHttpResponse(
                  req.getProtocolVersion(), status, Unpooled.wrappedBuffer(byteBufferBody));
        else res = new DefaultFullHttpResponse(req.getProtocolVersion(), status);

        if (message.getCookies() != null) {
          final ServerCookieEncoder enc = ServerCookieEncoder.STRICT;
          // Response cookies go in Set-Cookie, one header per cookie: add, never overwrite.
          for (final Cookie c : message.getCookies())
            res.headers().add(SET_COOKIE, enc.encode(getNettyCookie(c)));
        }
        if (message.getHeaders() != null) {
          for (final Map.Entry<String, String> h : message.getHeaders().entries())
            HttpHeaders.setHeader(res, h.getKey(), h.getValue());
        }

        if (message.getContentType() != null) {
          String ct = message.getContentType();
          if (message.getCharacterEncoding() != null)
            ct = ct + "; charset=" + message.getCharacterEncoding().name();
          HttpHeaders.setHeader(res, CONTENT_TYPE, ct);
        }

        final boolean sseStarted = message.shouldStartActor();
        if (trackSession(sseStarted)) {
          final String sessionId = UUID.randomUUID().toString();
          res.headers()
              .add(SET_COOKIE, ServerCookieEncoder.STRICT.encode(SESSION_COOKIE_KEY, sessionId));
          startSession(sessionId, actorContext);
        }
        if (!sseStarted) {
          long contentLength = 0L;
          if (stringBodyBytes != null) contentLength = stringBodyBytes.length;
          else if (byteBufferBody != null) contentLength = byteBufferBody.remaining();
          res.headers().add(CONTENT_LENGTH, contentLength);
        }

        final HttpStreamActorAdapter httpStreamActorAdapter;
        if (sseStarted)
          // This copies the request content, which must still be referenceable, so it is done
          // before the request handler releases it (Netty requests are explicitly
          // reference-counted).
          httpStreamActorAdapter = new HttpStreamActorAdapter(ctx, req);
        else httpStreamActorAdapter = null;

        sendHttpResponse(ctx, req, res, false);

        if (sseStarted) {
          // From now on raw bytes are streamed, so the HTTP response encoder must be removed.
          if (httpResponseEncoderName != null) {
            ctx.pipeline().remove(httpResponseEncoderName);
          } else {
            final ChannelPipeline pl = ctx.pipeline();
            final List<String> handlerKeysToBeRemoved = new ArrayList<>();
            for (final Map.Entry<String, ChannelHandler> e : pl) {
              if (e.getValue() instanceof HttpResponseEncoder)
                handlerKeysToBeRemoved.add(e.getKey());
            }
            for (final String k : handlerKeysToBeRemoved) pl.remove(k);
          }

          try {
            message.getFrom().send(new HttpStreamOpened(httpStreamActorAdapter.ref(), message));
          } catch (SuspendExecution e) {
            throw new AssertionError(e);
          }
        }

        return true;
      } finally {
        if (actor != null) actor.unwatch();
      }
    }

    /** Converts a web-actor cookie into the equivalent Netty cookie. */
    private io.netty.handler.codec.http.cookie.Cookie getNettyCookie(Cookie c) {
      io.netty.handler.codec.http.cookie.Cookie ret =
          new io.netty.handler.codec.http.cookie.DefaultCookie(c.getName(), c.getValue());
      ret.setDomain(c.getDomain());
      ret.setHttpOnly(c.isHttpOnly());
      ret.setMaxAge(c.getMaxAge());
      ret.setPath(c.getPath());
      ret.setSecure(c.isSecure());
      return ret;
    }

    @Override
    public final void close() {
      if (actor != null) actor.die(null);
      actorContext = null;
    }

    @Override
    public final void close(Throwable t) {
      log.error("Exception while closing HTTP adapter", t);
      if (actor != null) actor.die(t);
    }
  }

  /**
   * Returns whether a session cookie should be issued.
   *
   * <p>This is the case always in "always" mode, or in "sse" mode when a stream has been started.
   */
  protected static boolean trackSession(boolean sseStarted) {
    return trackSession != null
        && ("always".equals(trackSession) || sseStarted && "sse".equals(trackSession));
  }

  /** Returns whether {@code uri} matches one of the actor class's HTTP URL patterns. */
  protected static boolean handlesWithHttp(String uri, Class<?> actorClass) {
    return match(uri, actorClass).equals(HTTP_PATTERN);
  }

  /** Returns whether {@code uri} matches one of the actor class's WebSocket URL patterns. */
  protected static boolean handlesWithWebSocket(String uri, Class<?> actorClass) {
    return match(uri, actorClass).equals(WEBSOCKET_PATTERN);
  }

  /** Actor representing a streamed HTTP response; its messages are written as raw bytes. */
  private static final class HttpStreamActorAdapter extends FakeActor<WebDataMessage> {
    private volatile boolean dead;

    public HttpStreamActorAdapter(final ChannelHandlerContext ctx, final FullHttpRequest req) {
      super(req.toString(), new HttpStreamChannelAdapter(ctx, req));
      ((HttpStreamChannelAdapter) (SendPort) getMailbox()).actor = this;
    }

    @Override
    protected WebDataMessage handleLifecycleMessage(LifecycleMessage m) {
      if (m instanceof ShutdownMessage) {
        die(null);
      }
      return null;
    }

    @Override
    protected void throwIn(RuntimeException e) {
      die(e);
    }

    @Override
    public void interrupt() {
      die(new InterruptedException());
    }

    @Override
    protected void die(Throwable cause) {
      if (dead) return;
      this.dead = true;
      mailbox().close();
      super.die(cause);
    }

    @Override
    public String toString() {
      return "HttpStreamActorAdapter{request + " + getName() + "}";
    }
  }

  /** Mailbox of an {@link HttpStreamActorAdapter}: writes each message body to the channel. */
  private static final class HttpStreamChannelAdapter implements SendPort<WebDataMessage> {
    private final Charset encoding;
    private final ChannelHandlerContext ctx;

    HttpStreamActorAdapter actor;

    public HttpStreamChannelAdapter(ChannelHandlerContext ctx, FullHttpRequest req) {
      this.ctx = ctx;
      this.encoding = HttpRequestWrapper.extractCharacterEncodingOrDefault(req.headers());
    }

    @Override
    public final void send(WebDataMessage message) throws SuspendExecution, InterruptedException {
      trySend(message);
    }

    @Override
    public final boolean send(WebDataMessage message, long timeout, TimeUnit unit)
        throws SuspendExecution, InterruptedException {
      send(message);
      return true;
    }

    @Override
    public final boolean send(WebDataMessage message, Timeout timeout)
        throws SuspendExecution, InterruptedException {
      return send(message, timeout.nanosLeft(), TimeUnit.NANOSECONDS);
    }

    /** Writes the string body (encoded with the request's charset) or the binary body. */
    @Override
    public final boolean trySend(WebDataMessage res) {
      final ByteBuf buf;
      final String stringBody = res.getStringBody();
      if (stringBody != null) {
        byte[] bs = stringBody.getBytes(encoding);
        buf = Unpooled.wrappedBuffer(bs);
      } else {
        final ByteBuffer byteBufferBody = res.getByteBufferBody();
        // A message without any body writes nothing instead of failing on a null buffer.
        buf =
            byteBufferBody != null ? Unpooled.wrappedBuffer(byteBufferBody) : Unpooled.EMPTY_BUFFER;
      }
      ctx.writeAndFlush(buf);
      return true;
    }

    @Override
    public final void close() {
      if (ctx.channel().isOpen()) ctx.close();
      if (actor != null) actor.die(null);
    }

    @Override
    public final void close(Throwable t) {
      if (actor != null) actor.die(t);
      close();
    }
  }

  private static void sendHttpResponse(
      ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res, Boolean close) {
    writeHttpResponse(ctx, req, res, close);
  }

  /** Sends a 302 redirect to {@code newUri} and closes the connection. */
  private static void sendHttpRedirect(
      ChannelHandlerContext ctx, FullHttpRequest req, String newUri) {
    final FullHttpResponse res = new DefaultFullHttpResponse(req.getProtocolVersion(), FOUND);
    HttpHeaders.setHeader(res, LOCATION, newUri);
    writeHttpResponse(ctx, req, res, true);
  }

  /**
   * Writes {@code res}, adding a Date header unless disabled.
   *
   * <p>The connection is closed afterwards unless the request is keep-alive, the status is 200
   * and {@code close} is {@code false}.
   */
  private static void writeHttpResponse(
      ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res, Boolean close) {
    if (!omitDateHeader && !res.headers().contains(DefaultHttpHeaders.Names.DATE))
      DefaultHttpHeaders.addDateHeader(res, DefaultHttpHeaders.Names.DATE, new Date());

    // Send the response and close the connection if necessary.
    if (!HttpHeaders.isKeepAlive(req) || res.getStatus().code() != 200 || close == null || close) {
      res.headers().set(CONNECTION, HttpHeaders.Values.CLOSE);
      ctx.writeAndFlush(res).addListener(ChannelFutureListener.CLOSE);
    } else {
      res.headers().set(CONNECTION, HttpHeaders.Values.KEEP_ALIVE);
      write(ctx, res);
    }
  }

  private static ChannelFuture write(ChannelHandlerContext ctx, Object res) {
    return ctx.writeAndFlush(res);
  }

  /**
   * Returns the label of the first URL pattern of {@code actorClass} matching {@code uri}.
   *
   * <p>The label is {@link #HTTP_PATTERN} or {@link #WEBSOCKET_PATTERN}; {@code ""} is returned
   * when nothing matches.
   */
  private static String match(String uri, Class<?> actorClass) {
    if (uri != null && actorClass != null) {
      for (final Pair<String, String> e : lookupOrInsert(actorClass)) {
        if (servletMatch(e.getFirst(), uri)) return e.getSecond();
      }
    }
    return "";
  }

  /** Returns the cached URL patterns of {@code actorClass}, computing them on first use. */
  private static List<Pair<String, String>> lookupOrInsert(Class<?> actorClass) {
    if (actorClass != null) {
      // The cache is a plain WeakHashMap shared by all handlers (and event-loop threads).
      synchronized (classToUrlPatterns) {
        final List<Pair<String, String>> lookup = classToUrlPatterns.get(actorClass);
        if (lookup != null) return lookup;
        return insert(actorClass);
      }
    }
    return null;
  }

  /**
   * Builds and caches the URL patterns declared by the {@code @WebActor} annotation.
   *
   * <p>Exact patterns come first and wildcards last. A class without the annotation gets an empty
   * list. Callers must hold the {@code classToUrlPatterns} monitor.
   */
  private static List<Pair<String, String>> insert(Class<?> actorClass) {
    if (actorClass != null) {
      final WebActor wa = actorClass.getAnnotation(WebActor.class);
      final List<Pair<String, String>> ret = new ArrayList<>(4);
      if (wa != null) {
        for (String httpP : wa.httpUrlPatterns()) addPattern(ret, httpP, HTTP_PATTERN);
        for (String wsP : wa.webSocketUrlPatterns()) addPattern(ret, wsP, WEBSOCKET_PATTERN);
      }
      classToUrlPatterns.put(actorClass, ret);
      return ret;
    }
    return null;
  }

  /** Adds a pattern: wildcard and default ("/") patterns at the end, exact ones at the front. */
  private static void addPattern(List<Pair<String, String>> ret, String p, String type) {
    if (p != null) {
      final Pair<String, String> entry = new Pair<>(p, type);
      if (p.endsWith("*") || p.startsWith("*.") || p.equals("/")) {
        ret.add(entry); // Wildcard -> end
      } else {
        ret.add(0, entry); // Exact -> beginning
      }
    }
  }

  /** Matches {@code uri} against a servlet-style URL pattern (prefix, extension, default). */
  private static boolean servletMatch(String pattern, String uri) {
    // As per servlet spec
    if (pattern != null && uri != null) {
      if (pattern.startsWith("/") && pattern.endsWith("*"))
        return uri.startsWith(pattern.substring(0, pattern.length() - 1));
      if (pattern.startsWith("*.")) return uri.endsWith(pattern.substring(2));
      if (pattern.isEmpty()) return uri.equals("/");
      return pattern.equals("/") || pattern.equals(uri);
    }
    return false;
  }

  private static void startSession(String sessionId, Context actorContext) {
    sessions.put(sessionId, actorContext);
  }
}
/**
 * Socket test permutations for the native epoll transport.
 *
 * <p>Stream-socket combinations pair epoll and NIO servers with epoll and NIO clients; the last
 * combination (NIO x NIO) is removed. Datagram combinations pair NIO and epoll datagram channels,
 * and domain-socket combinations use epoll on both ends.
 */
class EpollSocketTestPermutation extends SocketTestPermutation {

  static final EpollSocketTestPermutation INSTANCE = new EpollSocketTestPermutation();

  static final EventLoopGroup EPOLL_BOSS_GROUP =
      new EpollEventLoopGroup(BOSSES, new DefaultThreadFactory("testsuite-epoll-boss", true));
  static final EventLoopGroup EPOLL_WORKER_GROUP =
      new EpollEventLoopGroup(WORKERS, new DefaultThreadFactory("testsuite-epoll-worker", true));

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(EpollSocketTestPermutation.class);

  /** Linux sysctl holding the TCP Fast Open mode bitmask. */
  private static final String TCP_FASTOPEN_FILE = "/proc/sys/net/ipv4/tcp_fastopen";

  /** Bit of {@link #TCP_FASTOPEN_FILE} that enables Fast Open for listening sockets. */
  private static final int TCP_FASTOPEN_SERVER_MASK = 0x2;

  @Override
  public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> socket() {

    List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> list =
        combo(serverSocket(), clientSocket());

    list.remove(list.size() - 1); // Exclude NIO x NIO test

    return list;
  }

  /**
   * Returns the server bootstraps under test: epoll, epoll with TCP Fast Open (only when the
   * kernel enables it for servers), and NIO last. {@link #socket()} relies on NIO being last.
   */
  @SuppressWarnings("unchecked")
  @Override
  public List<BootstrapFactory<ServerBootstrap>> serverSocket() {
    List<BootstrapFactory<ServerBootstrap>> toReturn =
        new ArrayList<BootstrapFactory<ServerBootstrap>>();
    toReturn.add(
        new BootstrapFactory<ServerBootstrap>() {
          @Override
          public ServerBootstrap newInstance() {
            return new ServerBootstrap()
                .group(EPOLL_BOSS_GROUP, EPOLL_WORKER_GROUP)
                .channel(EpollServerSocketChannel.class);
          }
        });
    if (isServerFastOpen()) {
      toReturn.add(
          new BootstrapFactory<ServerBootstrap>() {
            @Override
            public ServerBootstrap newInstance() {
              ServerBootstrap serverBootstrap =
                  new ServerBootstrap()
                      .group(EPOLL_BOSS_GROUP, EPOLL_WORKER_GROUP)
                      .channel(EpollServerSocketChannel.class);
              serverBootstrap.option(EpollChannelOption.TCP_FASTOPEN, 5);
              return serverBootstrap;
            }
          });
    }
    toReturn.add(
        new BootstrapFactory<ServerBootstrap>() {
          @Override
          public ServerBootstrap newInstance() {
            return new ServerBootstrap()
                .group(nioBossGroup, nioWorkerGroup)
                .channel(NioServerSocketChannel.class);
          }
        });

    return toReturn;
  }

  /** Returns the client bootstraps under test: epoll first, NIO last. */
  @SuppressWarnings("unchecked")
  @Override
  public List<BootstrapFactory<Bootstrap>> clientSocket() {
    return Arrays.asList(
        new BootstrapFactory<Bootstrap>() {
          @Override
          public Bootstrap newInstance() {
            return new Bootstrap().group(EPOLL_WORKER_GROUP).channel(EpollSocketChannel.class);
          }
        },
        new BootstrapFactory<Bootstrap>() {
          @Override
          public Bootstrap newInstance() {
            return new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class);
          }
        });
  }

  @Override
  public List<TestsuitePermutation.BootstrapComboFactory<Bootstrap, Bootstrap>> datagram() {
    // Make the list of Bootstrap factories.
    @SuppressWarnings("unchecked")
    List<BootstrapFactory<Bootstrap>> bfs =
        Arrays.asList(
            new BootstrapFactory<Bootstrap>() {
              @Override
              public Bootstrap newInstance() {
                return new Bootstrap()
                    .group(nioWorkerGroup)
                    .channelFactory(
                        new ChannelFactory<Channel>() {
                          @Override
                          public Channel newChannel() {
                            return new NioDatagramChannel(InternetProtocolFamily.IPv4);
                          }

                          @Override
                          public String toString() {
                            return NioDatagramChannel.class.getSimpleName() + ".class";
                          }
                        });
              }
            },
            new BootstrapFactory<Bootstrap>() {
              @Override
              public Bootstrap newInstance() {
                return new Bootstrap()
                    .group(EPOLL_WORKER_GROUP)
                    .channel(EpollDatagramChannel.class);
              }
            });
    return combo(bfs, bfs);
  }

  public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>>
      domainSocket() {

    List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> list =
        combo(serverDomainSocket(), clientDomainSocket());
    return list;
  }

  public List<BootstrapFactory<ServerBootstrap>> serverDomainSocket() {
    return Collections.<BootstrapFactory<ServerBootstrap>>singletonList(
        new BootstrapFactory<ServerBootstrap>() {
          @Override
          public ServerBootstrap newInstance() {
            return new ServerBootstrap()
                .group(EPOLL_BOSS_GROUP, EPOLL_WORKER_GROUP)
                .channel(EpollServerDomainSocketChannel.class);
          }
        });
  }

  public List<BootstrapFactory<Bootstrap>> clientDomainSocket() {
    return Collections.<BootstrapFactory<Bootstrap>>singletonList(
        new BootstrapFactory<Bootstrap>() {
          @Override
          public Bootstrap newInstance() {
            return new Bootstrap()
                .group(EPOLL_WORKER_GROUP)
                .channel(EpollDomainSocketChannel.class);
          }
        });
  }

  /**
   * Returns whether the kernel enables TCP Fast Open for listening sockets.
   *
   * <p>{@code /proc/sys/net/ipv4/tcp_fastopen} is a bitmask. Bit {@code 0x2} enables server-side
   * Fast Open, and other bits (client enable, cookie-less modes, ...) may be set at the same time.
   * The server bit is therefore tested instead of comparing the whole value against a constant. A
   * missing or unreadable file counts as disabled.
   */
  public boolean isServerFastOpen() {
    final int fastopen =
        AccessController.doPrivileged(
            new PrivilegedAction<Integer>() {
              @Override
              public Integer run() {
                return readTcpFastOpenMode();
              }
            });
    return (fastopen & TCP_FASTOPEN_SERVER_MASK) != 0;
  }

  /** Reads the tcp_fastopen sysctl, returning 0 when it is absent, empty or unparsable. */
  private static int readTcpFastOpenMode() {
    int fastopen = 0;
    File file = new File(TCP_FASTOPEN_FILE);
    if (file.exists()) {
      BufferedReader in = null;
      try {
        // The file holds a decimal number; decode it explicitly rather than with the platform
        // default charset used by FileReader.
        in =
            new BufferedReader(
                new java.io.InputStreamReader(new java.io.FileInputStream(file), "US-ASCII"));
        final String line = in.readLine();
        fastopen = line == null ? 0 : Integer.parseInt(line.trim());
        if (logger.isDebugEnabled()) {
          logger.debug("{}: {}", file, fastopen);
        }
      } catch (Exception e) {
        logger.debug("Failed to get TCP_FASTOPEN from: {}", file, e);
      } finally {
        if (in != null) {
          try {
            in.close();
          } catch (Exception e) {
            // Ignored.
          }
        }
      }
    } else {
      if (logger.isDebugEnabled()) {
        logger.debug("{}: {} (non-existent)", file, fastopen);
      }
    }
    return fastopen;
  }

  /**
   * Returns an address backed by a fresh temporary file path.
   *
   * <p>The temp file is deleted immediately, so only the unique path is used.
   */
  public static DomainSocketAddress newSocketAddress() {
    try {
      File file = File.createTempFile("netty", "dsocket");
      file.delete();
      return new DomainSocketAddress(file);
    } catch (IOException e) {
      throw new IllegalStateException(e);
    }
  }
}
Пример #23
0
/**
 * Provides the default implementation for processing inbound frame events and delegates to a {@link
 * Http2FrameListener}
 *
 * <p>This class will read HTTP/2 frames and delegate the events to a {@link Http2FrameListener}
 *
 * <p>This class enforces inbound flow control functionality through {@link
 * Http2LocalFlowController}
 */
public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http2LifecycleManager {
  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(Http2ConnectionHandler.class);
  private final Http2ConnectionDecoder decoder;
  private final Http2ConnectionEncoder encoder;
  // Settings for the initial SETTINGS frame; null means decoder.localSettings() is sent instead.
  private final Http2Settings initialSettings;
  // Set by close() when it must wait for active streams to finish; invoked at most once by
  // checkCloseConnection().
  private ChannelFutureListener closeListener;
  // Current byte-level decoding strategy: a PrefaceDecoder until the preface has been handled,
  // then a FrameDecoder.
  private BaseDecoder byteDecoder;

  /**
   * Creates a handler for a new {@link DefaultHttp2Connection}, using the default frame reader and
   * writer.
   *
   * @param server {@code true} if this endpoint is the server side of the connection.
   * @param listener the listener to which decoded inbound frame events are delegated.
   */
  public Http2ConnectionHandler(boolean server, Http2FrameListener listener) {
    this(new DefaultHttp2Connection(server), listener);
  }

  /**
   * Creates a handler for the given connection, using the default frame reader and writer.
   *
   * @param connection the HTTP/2 connection shared by the encoder and decoder.
   * @param listener the listener to which decoded inbound frame events are delegated.
   */
  public Http2ConnectionHandler(Http2Connection connection, Http2FrameListener listener) {
    this(connection, new DefaultHttp2FrameReader(), new DefaultHttp2FrameWriter(), listener);
  }

  /**
   * Creates a handler that builds its default encoder and decoder from the given frame reader and
   * writer.
   *
   * @param connection the HTTP/2 connection shared by the encoder and decoder.
   * @param frameReader reads inbound frames.
   * @param frameWriter writes outbound frames.
   * @param listener the listener to which decoded inbound frame events are delegated.
   */
  public Http2ConnectionHandler(
      Http2Connection connection,
      Http2FrameReader frameReader,
      Http2FrameWriter frameWriter,
      Http2FrameListener listener) {
    initialSettings = null;
    encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter);
    decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader, listener);
  }

  /**
   * Constructor for pre-configured encoder and decoder. Just sets the {@code this} as the {@link
   * Http2LifecycleManager} and builds them.
   *
   * @throws IllegalArgumentException if the encoder and decoder do not share the same connection.
   */
  public Http2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder) {
    this.initialSettings = null;
    this.decoder = checkNotNull(decoder, "decoder");
    this.encoder = checkNotNull(encoder, "encoder");
    if (encoder.connection() != decoder.connection()) {
      throw new IllegalArgumentException(
          "Encoder and Decoder do not share the same connection object");
    }
  }

  /**
   * Creates a handler for the given connection with explicit initial settings, using the default
   * frame reader and writer.
   *
   * @param connection the HTTP/2 connection shared by the encoder and decoder.
   * @param listener the listener to which decoded inbound frame events are delegated.
   * @param initialSettings the settings sent in the first SETTINGS frame, or {@code null} to send
   *     the decoder's local settings.
   */
  public Http2ConnectionHandler(
      Http2Connection connection, Http2FrameListener listener, Http2Settings initialSettings) {
    this(
        connection,
        new DefaultHttp2FrameReader(),
        new DefaultHttp2FrameWriter(),
        listener,
        initialSettings);
  }

  /**
   * Creates a handler that builds its default encoder and decoder from the given frame reader and
   * writer, with explicit initial settings.
   *
   * @param connection the HTTP/2 connection shared by the encoder and decoder.
   * @param frameReader reads inbound frames.
   * @param frameWriter writes outbound frames.
   * @param listener the listener to which decoded inbound frame events are delegated.
   * @param initialSettings the settings sent in the first SETTINGS frame, or {@code null} to send
   *     the decoder's local settings.
   */
  public Http2ConnectionHandler(
      Http2Connection connection,
      Http2FrameReader frameReader,
      Http2FrameWriter frameWriter,
      Http2FrameListener listener,
      Http2Settings initialSettings) {
    this.initialSettings = initialSettings;
    encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter);
    decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader, listener);
  }

  /**
   * Constructor for a pre-configured encoder and decoder with explicit initial settings.
   *
   * @param initialSettings the settings sent in the first SETTINGS frame, or {@code null} to send
   *     the decoder's local settings.
   * @throws IllegalArgumentException if the encoder and decoder do not share the same connection.
   */
  public Http2ConnectionHandler(
      Http2ConnectionDecoder decoder,
      Http2ConnectionEncoder encoder,
      Http2Settings initialSettings) {
    this.initialSettings = initialSettings;
    this.decoder = checkNotNull(decoder, "decoder");
    this.encoder = checkNotNull(encoder, "encoder");
    if (encoder.connection() != decoder.connection()) {
      throw new IllegalArgumentException(
          "Encoder and Decoder do not share the same connection object");
    }
  }

  /** Returns the HTTP/2 connection, as seen by the encoder. */
  public Http2Connection connection() {
    return encoder.connection();
  }

  /** Returns the decoder used for inbound frames. */
  public Http2ConnectionDecoder decoder() {
    return decoder;
  }

  /** Returns the encoder used for outbound frames. */
  public Http2ConnectionEncoder encoder() {
    return encoder;
  }

  /**
   * Returns whether the local connection preface has been sent. This is {@code false} while no
   * byte decoder is installed.
   */
  private boolean prefaceSent() {
    return byteDecoder != null && byteDecoder.prefaceSent();
  }

  /**
   * Handles the client-side (cleartext) upgrade from HTTP to HTTP/2. Reserves local stream 1 for
   * the HTTP/2 response.
   *
   * @throws Http2Exception with {@code PROTOCOL_ERROR} if called on a server, or after the preface
   *     has been sent or received.
   */
  public void onHttpClientUpgrade() throws Http2Exception {
    if (connection().isServer()) {
      throw connectionError(PROTOCOL_ERROR, "Client-side HTTP upgrade requested for a server");
    }
    if (prefaceSent() || decoder.prefaceReceived()) {
      throw connectionError(
          PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received");
    }

    // Create a local stream used for the HTTP cleartext upgrade.
    connection().local().createStream(HTTP_UPGRADE_STREAM_ID, true);
  }

  /**
   * Handles the server-side (cleartext) upgrade from HTTP to HTTP/2.
   *
   * @param settings the settings for the remote endpoint.
   * @throws Http2Exception with {@code PROTOCOL_ERROR} if called on a client, or after the preface
   *     has been sent or received.
   */
  public void onHttpServerUpgrade(Http2Settings settings) throws Http2Exception {
    if (!connection().isServer()) {
      throw connectionError(PROTOCOL_ERROR, "Server-side HTTP upgrade requested for a client");
    }
    if (prefaceSent() || decoder.prefaceReceived()) {
      throw connectionError(
          PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received");
    }

    // Apply the settings but no ACK is necessary.
    encoder.remoteSettings(settings);

    // Create a stream in the half-closed state.
    connection().remote().createStream(HTTP_UPGRADE_STREAM_ID, true);
  }

  /**
   * Writes any bytes pending in the remote flow controller, then flushes the channel. A failure
   * during the flush is rethrown as an {@link Http2Exception} with {@code INTERNAL_ERROR}.
   */
  @Override
  public void flush(ChannelHandlerContext ctx) throws Http2Exception {
    // Trigger pending writes in the remote flow controller.
    connection().remote().flowController().writePendingBytes();
    try {
      super.flush(ctx);
    } catch (Throwable t) {
      throw new Http2Exception(INTERNAL_ERROR, "Error flushing", t);
    }
  }

  /**
   * Strategy for turning inbound bytes into HTTP/2 events. {@link PrefaceDecoder} handles the
   * connection preface; {@link FrameDecoder} handles everything after it.
   */
  private abstract class BaseDecoder {
    public abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
        throws Exception;

    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {}

    public void channelActive(ChannelHandlerContext ctx) throws Exception {}

    /**
     * Closes every active stream with an already-succeeded future. It then closes the encoder and
     * the decoder; the decoder is closed even if closing the encoder throws.
     */
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
      try {
        final Http2Connection connection = connection();
        // Check if there are streams to avoid the overhead of creating the ChannelFuture.
        if (connection.numActiveStreams() > 0) {
          final ChannelFuture future = ctx.newSucceededFuture();
          connection.forEachActiveStream(
              new Http2StreamVisitor() {
                @Override
                public boolean visit(Http2Stream stream) throws Http2Exception {
                  closeStream(stream, future);
                  return true;
                }
              });
        }
      } finally {
        try {
          encoder().close();
        } finally {
          decoder().close();
        }
      }
    }

    /** Determine if the HTTP/2 connection preface been sent. */
    public boolean prefaceSent() {
      return true;
    }
  }

  /**
   * Decoder used until the connection preface has been handled. A server reads and verifies the
   * client preface string and the first SETTINGS frame. Both sides send their own preface once
   * the channel is active.
   */
  private final class PrefaceDecoder extends BaseDecoder {
    // Remaining bytes of the expected client preface; null on clients or once fully matched.
    private ByteBuf clientPrefaceString;
    private boolean prefaceSent;

    public PrefaceDecoder(ChannelHandlerContext ctx) {
      clientPrefaceString = clientPrefaceString(encoder.connection());
      // This handler was just added to the context. In case it was handled after
      // the connection became active, send the connection preface now.
      sendPreface(ctx);
    }

    @Override
    public boolean prefaceSent() {
      return prefaceSent;
    }

    @Override
    public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
      try {
        if (readClientPrefaceString(in) && verifyFirstFrameIsSettings(in)) {
          // After the preface is read, it is time to hand over control to the post initialized
          // decoder.
          byteDecoder = new FrameDecoder();
          byteDecoder.decode(ctx, in, out);
        }
      } catch (Throwable e) {
        onException(ctx, e);
      }
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
      // The channel just became active - send the connection preface to the remote endpoint.
      sendPreface(ctx);
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
      cleanup();
      super.channelInactive(ctx);
    }

    /** Releases the {@code clientPrefaceString}. Any active streams will be left in the open. */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
      cleanup();
    }

    /** Releases the {@code clientPrefaceString}. Any active streams will be left in the open. */
    private void cleanup() {
      if (clientPrefaceString != null) {
        clientPrefaceString.release();
        clientPrefaceString = null;
      }
    }

    /**
     * Decodes the client connection preface string from the input buffer.
     *
     * @return {@code true} if processing of the client preface string is complete. Since client
     *     preface strings can only be received by servers, returns true immediately for client
     *     endpoints.
     * @throws Http2Exception with {@code PROTOCOL_ERROR} if the received bytes do not match the
     *     expected preface.
     */
    private boolean readClientPrefaceString(ByteBuf in) throws Http2Exception {
      if (clientPrefaceString == null) {
        return true;
      }

      int prefaceRemaining = clientPrefaceString.readableBytes();
      int bytesRead = min(in.readableBytes(), prefaceRemaining);

      // If the input so far doesn't match the preface, break the connection.
      if (bytesRead == 0
          || !ByteBufUtil.equals(
              in,
              in.readerIndex(),
              clientPrefaceString,
              clientPrefaceString.readerIndex(),
              bytesRead)) {
        String receivedBytes =
            hexDump(
                in, in.readerIndex(), min(in.readableBytes(), clientPrefaceString.readableBytes()));
        throw connectionError(
            PROTOCOL_ERROR,
            "HTTP/2 client preface string missing or corrupt. " + "Hex dump for received bytes: %s",
            receivedBytes);
      }
      in.skipBytes(bytesRead);
      clientPrefaceString.skipBytes(bytesRead);

      if (!clientPrefaceString.isReadable()) {
        // Entire preface has been read.
        clientPrefaceString.release();
        clientPrefaceString = null;
        return true;
      }
      return false;
    }

    /**
     * Peeks at that the next frame in the buffer and verifies that it is a {@code SETTINGS} frame.
     *
     * @param in the inbound buffer.
     * @return {@code} true if the next frame is a {@code SETTINGS} frame, {@code false} if more
     *     data is required before we can determine the next frame type.
     * @throws Http2Exception thrown if the next frame is NOT a {@code SETTINGS} frame.
     */
    private boolean verifyFirstFrameIsSettings(ByteBuf in) throws Http2Exception {
      if (in.readableBytes() < 4) {
        // Need more data before we can see the frame type for the first frame.
        return false;
      }

      // The frame type is the fourth byte of the frame header.
      byte frameType = in.getByte(in.readerIndex() + 3);
      if (frameType != SETTINGS) {
        throw connectionError(
            PROTOCOL_ERROR,
            "First received frame was not SETTINGS. " + "Hex dump for first 4 bytes: %s",
            hexDump(in, in.readerIndex(), 4));
      }
      return true;
    }

    /**
     * Sends the HTTP/2 connection preface upon establishment of the connection, if not already
     * sent.
     */
    private void sendPreface(ChannelHandlerContext ctx) {
      if (prefaceSent || !ctx.channel().isActive()) {
        return;
      }

      prefaceSent = true;

      if (!connection().isServer()) {
        // Clients must send the preface string as the first bytes on the connection.
        ctx.write(connectionPrefaceBuf()).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
      }

      // Both client and server must send their initial settings.
      encoder
          .writeSettings(ctx, initialSettings(), ctx.newPromise())
          .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
    }
  }

  /**
   * Decoder used after the preface has been handled. It hands bytes to the connection decoder and
   * routes any failure to {@link #onException}.
   */
  private final class FrameDecoder extends BaseDecoder {
    @Override
    public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
      try {
        decoder.decodeFrame(ctx, in, out);
      } catch (Throwable e) {
        onException(ctx, e);
      }
    }
  }

  /**
   * Registers this handler as the lifecycle manager of the encoder and decoder, and starts in the
   * preface phase. The preface is sent immediately if the channel is already active.
   */
  @Override
  public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
    // Initialize the encoder and decoder.
    encoder.lifecycleManager(this);
    decoder.lifecycleManager(this);
    byteDecoder = new PrefaceDecoder(ctx);
  }

  /** Lets the current byte decoder release its resources, then drops it. */
  @Override
  protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
    if (byteDecoder != null) {
      byteDecoder.handlerRemoved(ctx);
      byteDecoder = null;
    }
  }

  /** Ensures a byte decoder exists and lets it react to activation (e.g. send the preface). */
  @Override
  public void channelActive(ChannelHandlerContext ctx) throws Exception {
    if (byteDecoder == null) {
      byteDecoder = new PrefaceDecoder(ctx);
    }
    byteDecoder.channelActive(ctx);
    super.channelActive(ctx);
  }

  /**
   * Tears down the byte decoder (closing active streams, encoder and decoder) and then propagates
   * the event. Nothing happens, and the event is not propagated, if no byte decoder is installed.
   */
  @Override
  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
    if (byteDecoder != null) {
      byteDecoder.channelInactive(ctx);
      super.channelInactive(ctx);
      byteDecoder = null;
    }
  }

  /** Delegates to the current byte decoder (preface phase or frame phase). */
  @Override
  protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
    byteDecoder.decode(ctx, in, out);
  }

  /**
   * Sends a GOAWAY and flushes it. The channel is closed once the GOAWAY write completes if no
   * streams are active; otherwise it is closed once graceful shutdown completes.
   */
  @Override
  public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
    // Avoid NotYetConnectedException
    if (!ctx.channel().isActive()) {
      ctx.close(promise);
      return;
    }

    ChannelFuture future = goAway(ctx, null);
    ctx.flush();

    // If there are no active streams, close immediately after the send is complete.
    // Otherwise wait until all streams are inactive.
    if (isGracefulShutdownComplete()) {
      future.addListener(new ClosingChannelFutureListener(ctx, promise));
    } else {
      closeListener = new ClosingChannelFutureListener(ctx, promise);
    }
  }

  /**
   * Propagates the read-complete event through {@link ByteToMessageDecoder} and then flushes.
   *
   * <p>The flush is triggered on the assumption that it is cheap when there is nothing to write.
   * It is also needed for flow control: the read may have released window, which lets pending
   * data be written and flushed now.
   */
  @Override
  public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
    try {
      // Let the base decoder do its read-complete handling and forward the event down the
      // pipeline instead of swallowing it here.
      super.channelReadComplete(ctx);
    } finally {
      flush(ctx);
    }
  }

  /**
   * Handles {@link Http2Exception} objects that were thrown from other handlers. All other
   * exceptions are passed on to the superclass.
   */
  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    if (getEmbeddedHttp2Exception(cause) != null) {
      // Some exception in the causality chain is an Http2Exception - handle it.
      onException(ctx, cause);
    } else {
      super.exceptionCaught(ctx, cause);
    }
  }

  /**
   * Closes the local side of the given stream. If this causes the stream to be closed, adds a hook
   * to close the channel after the given future completes.
   *
   * @param stream the stream to be half closed.
   * @param future If closing, the future after which to close the channel.
   */
  @Override
  public void closeStreamLocal(Http2Stream stream, ChannelFuture future) {
    switch (stream.state()) {
      case HALF_CLOSED_LOCAL:
      case OPEN:
        stream.closeLocalSide();
        break;
      default:
        closeStream(stream, future);
        break;
    }
  }

  /**
   * Closes the remote side of the given stream. If this causes the stream to be closed, adds a hook
   * to close the channel after the given future completes.
   *
   * @param stream the stream to be half closed.
   * @param future If closing, the future after which to close the channel.
   */
  @Override
  public void closeStreamRemote(Http2Stream stream, ChannelFuture future) {
    switch (stream.state()) {
      case HALF_CLOSED_REMOTE:
      case OPEN:
        stream.closeRemoteSide();
        break;
      default:
        closeStream(stream, future);
        break;
    }
  }

  /**
   * Closes the stream. Once {@code future} completes (immediately if it is already done), checks
   * whether a pending graceful close of the connection can now proceed.
   */
  @Override
  public void closeStream(final Http2Stream stream, ChannelFuture future) {
    stream.close();

    if (future.isDone()) {
      checkCloseConnection(future);
    } else {
      future.addListener(
          new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
              checkCloseConnection(future);
            }
          });
    }
  }

  /** Central handler for all exceptions caught during HTTP/2 processing. */
  @Override
  public void onException(ChannelHandlerContext ctx, Throwable cause) {
    Http2Exception embedded = getEmbeddedHttp2Exception(cause);
    if (isStreamError(embedded)) {
      onStreamError(ctx, cause, (StreamException) embedded);
    } else if (embedded instanceof CompositeStreamException) {
      CompositeStreamException compositException = (CompositeStreamException) embedded;
      for (StreamException streamException : compositException) {
        onStreamError(ctx, cause, streamException);
      }
    } else {
      onConnectionError(ctx, cause, embedded);
    }
    ctx.flush();
  }

  /**
   * Called by the graceful shutdown logic to determine when it is safe to close the connection.
   * Returns {@code true} if the graceful shutdown has completed and the connection can be safely
   * closed. This implementation just guarantees that there are no active streams. Subclasses may
   * override to provide additional checks.
   */
  protected boolean isGracefulShutdownComplete() {
    return connection().numActiveStreams() == 0;
  }

  /**
   * Handler for a connection error. Sends a GO_AWAY frame to the remote endpoint. Once all streams
   * are closed, the connection is shut down.
   *
   * @param ctx the channel context
   * @param cause the exception that was caught
   * @param http2Ex the {@link Http2Exception} that is embedded in the causality chain. This may be
   *     {@code null} if it's an unknown exception.
   */
  protected void onConnectionError(
      ChannelHandlerContext ctx, Throwable cause, Http2Exception http2Ex) {
    if (http2Ex == null) {
      http2Ex = new Http2Exception(INTERNAL_ERROR, cause.getMessage(), cause);
    }
    goAway(ctx, http2Ex).addListener(new ClosingChannelFutureListener(ctx, ctx.newPromise()));
  }

  /**
   * Handler for a stream error. Sends a {@code RST_STREAM} frame to the remote endpoint and closes
   * the stream.
   *
   * @param ctx the channel context
   * @param cause the exception that was caught
   * @param http2Ex the {@link StreamException} that is embedded in the causality chain.
   */
  protected void onStreamError(
      ChannelHandlerContext ctx, Throwable cause, StreamException http2Ex) {
    resetStream(ctx, http2Ex.streamId(), http2Ex.error().code(), ctx.newPromise());
  }

  /** Returns the frame writer used by the encoder. */
  protected Http2FrameWriter frameWriter() {
    return encoder().frameWriter();
  }

  /**
   * Writes a {@code RST_STREAM} for the stream, unless the stream is unknown or a reset was already
   * sent (then {@code promise} is succeeded immediately). On a successful write the stream is
   * closed; on a failed write the failure is handled as a connection error.
   */
  @Override
  public ChannelFuture resetStream(
      final ChannelHandlerContext ctx, int streamId, long errorCode, final ChannelPromise promise) {
    final Http2Stream stream = connection().stream(streamId);
    if (stream == null || stream.isResetSent()) {
      // Don't write a RST_STREAM frame if we are not aware of the stream, or if we have already
      // written one.
      return promise.setSuccess();
    }

    ChannelFuture future = frameWriter().writeRstStream(ctx, streamId, errorCode, promise);

    // Synchronously set the resetSent flag to prevent any subsequent calls
    // from resulting in multiple reset frames being sent.
    stream.resetSent();

    future.addListener(
        new ChannelFutureListener() {
          @Override
          public void operationComplete(ChannelFuture future) throws Exception {
            if (future.isSuccess()) {
              closeStream(stream, promise);
            } else {
              // The connection will be closed and so no need to change the resetSent flag to false.
              onConnectionError(ctx, future.cause(), null);
            }
          }
        });

    return future;
  }

  /**
   * Records and writes a GOAWAY frame. {@code debugData} is released by this method once the write
   * has been processed, or immediately if this method fails.
   *
   * <p>Sending a second GOAWAY with a larger last-stream id than the first is rejected with a
   * {@code PROTOCOL_ERROR}, which is reported through the returned (failed) future.
   */
  @Override
  public ChannelFuture goAway(
      final ChannelHandlerContext ctx,
      final int lastStreamId,
      final long errorCode,
      final ByteBuf debugData,
      ChannelPromise promise) {
    try {
      final Http2Connection connection = connection();
      if (connection.goAwaySent() && lastStreamId > connection.remote().lastStreamKnownByPeer()) {
        throw connectionError(
            PROTOCOL_ERROR,
            "Last stream identifier must not increase between "
                + "sending multiple GOAWAY frames (was '%d', is '%d').",
            connection.remote().lastStreamKnownByPeer(),
            lastStreamId);
      }
      connection.goAwaySent(lastStreamId, errorCode, debugData);

      // Need to retain before we write the buffer because if we do it after the refCnt could
      // already be 0 and result in an IllegalRefCountException.
      debugData.retain();
      ChannelFuture future =
          frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);

      if (future.isDone()) {
        processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future);
      } else {
        future.addListener(
            new ChannelFutureListener() {
              @Override
              public void operationComplete(ChannelFuture future) throws Exception {
                processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future);
              }
            });
      }

      return future;
    } catch (Throwable cause) {
      // Make sure to catch Throwable because we are doing a retain() in this method.
      debugData.release();
      return promise.setFailure(cause);
    }
  }

  /**
   * Closes the connection if the graceful shutdown process has completed.
   *
   * @param future Represents the status that will be passed to the {@link #closeListener}.
   */
  private void checkCloseConnection(ChannelFuture future) {
    // If this connection is closing and the graceful shutdown has completed, close the connection
    // once this operation completes.
    if (closeListener != null && isGracefulShutdownComplete()) {
      ChannelFutureListener closeListener = Http2ConnectionHandler.this.closeListener;
      // This method could be called multiple times
      // and we don't want to notify the closeListener multiple times.
      Http2ConnectionHandler.this.closeListener = null;
      try {
        closeListener.operationComplete(future);
      } catch (Exception e) {
        throw new IllegalStateException("Close listener threw an unexpected exception", e);
      }
    }
  }

  /** Gets the initial settings to be sent to the remote endpoint. */
  private Http2Settings initialSettings() {
    return initialSettings != null ? initialSettings : decoder.localSettings();
  }

  /**
   * Close the remote endpoint with a {@code GO_AWAY} frame. Does <strong>not</strong> flush
   * immediately, this is the responsibility of the caller.
   *
   * @param cause the error to report, or {@code null} to send {@code NO_ERROR}.
   */
  private ChannelFuture goAway(ChannelHandlerContext ctx, Http2Exception cause) {
    long errorCode = cause != null ? cause.error().code() : NO_ERROR.code();
    ByteBuf debugData = Http2CodecUtil.toByteBuf(ctx, cause);
    int lastKnownStream = connection().remote().lastStreamCreated();
    return goAway(ctx, lastKnownStream, errorCode, debugData, ctx.newPromise());
  }

  /**
   * Returns the connection preface buffer a server expects to read from the client, or {@code
   * null} for a client connection (clients send the preface rather than read it).
   */
  private static ByteBuf clientPrefaceString(Http2Connection connection) {
    return connection.isServer() ? connectionPrefaceBuf() : null;
  }

  /**
   * Logs the outcome of a GOAWAY write and closes the channel if the write failed or carried an
   * error code. Always releases {@code debugData}.
   */
  private static void processGoAwayWriteResult(
      final ChannelHandlerContext ctx,
      final int lastStreamId,
      final long errorCode,
      final ByteBuf debugData,
      ChannelFuture future) {
    try {
      if (future.isSuccess()) {
        if (errorCode != NO_ERROR.code()) {
          if (logger.isDebugEnabled()) {
            logger.debug(
                format(
                    "Sent GOAWAY: lastStreamId '%d', errorCode '%d', "
                        + "debugData '%s'. Forcing shutdown of the connection.",
                    lastStreamId, errorCode, debugData.toString(UTF_8)),
                future.cause());
          }
          ctx.close();
        }
      } else {
        if (logger.isErrorEnabled()) {
          logger.error(
              format(
                  "Sending GOAWAY failed: lastStreamId '%d', errorCode '%d', "
                      + "debugData '%s'. Forcing shutdown of the connection.",
                  lastStreamId, errorCode, debugData.toString(UTF_8)),
              future.cause());
        }
        ctx.close();
      }
    } finally {
      // We're done with the debug data now.
      debugData.release();
    }
  }

  /** Closes the channel when the future completes. */
  private static final class ClosingChannelFutureListener implements ChannelFutureListener {
    private final ChannelHandlerContext ctx;
    private final ChannelPromise promise;

    ClosingChannelFutureListener(ChannelHandlerContext ctx, ChannelPromise promise) {
      this.ctx = ctx;
      this.promise = promise;
    }

    @Override
    public void operationComplete(ChannelFuture sentGoAwayFuture) throws Exception {
      ctx.close(promise);
    }
  }
}
@RunWith(Parameterized.class)
public class SocketSslClientRenegotiateTest extends AbstractSocketTest {

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(SocketSslClientRenegotiateTest.class);
  private static final File CERT_FILE;
  private static final File KEY_FILE;

  // One self-signed certificate/key pair is generated for the whole class; failure is fatal.
  static {
    final SelfSignedCertificate certificate;
    try {
      certificate = new SelfSignedCertificate();
    } catch (CertificateException cause) {
      throw new Error(cause);
    }
    CERT_FILE = certificate.certificate();
    KEY_FILE = certificate.privateKey();
  }

  /**
   * Builds the parameter sets. Every server context is paired with every client context, and each
   * pair is repeated 32 times. The only server context is an OpenSSL one that rejects
   * remote-initiated renegotiation, so no parameters are produced when OpenSSL is unavailable.
   */
  @Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}")
  public static Collection<Object[]> data() throws Exception {
    final int repetitionsPerPair = 32;

    List<SslContext> clientSides = new ArrayList<SslContext>();
    clientSides.add(new JdkSslClientContext(CERT_FILE));

    List<SslContext> serverSides = new ArrayList<SslContext>();
    if (!OpenSsl.isAvailable()) {
      logger.warn(
          "OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
    } else {
      OpenSslServerContext openSslServer = new OpenSslServerContext(CERT_FILE, KEY_FILE);
      openSslServer.setRejectRemoteInitiatedRenegotiation(true);
      serverSides.add(openSslServer);
    }

    List<Object[]> combinations = new ArrayList<Object[]>();
    for (SslContext server : serverSides) {
      for (SslContext client : clientSides) {
        int remaining = repetitionsPerPair;
        while (remaining-- > 0) {
          combinations.add(new Object[] {server, client});
        }
      }
    }
    return combinations;
  }

  private final SslContext serverCtx;
  private final SslContext clientCtx;

  // Declared before the handlers below, which capture them in their field initializers.
  private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
  private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();

  private volatile Channel clientChannel;
  private volatile Channel serverChannel;

  private volatile SslHandler clientSslHandler;
  private volatile SslHandler serverSslHandler;

  private final TestHandler clientHandler = new TestHandler(clientException);

  private final TestHandler serverHandler = new TestHandler(serverException);

  public SocketSslClientRenegotiateTest(SslContext serverCtx, SslContext clientCtx) {
    this.serverCtx = serverCtx;
    this.clientCtx = clientCtx;
  }

  /** Runs the scenario for every bootstrap combination; skipped when OpenSSL is unavailable. */
  @Test(timeout = 30000)
  public void testSslRenegotiationRejected() throws Throwable {
    Assume.assumeTrue(OpenSsl.isAvailable());
    run();
  }

  /**
   * The client completes a handshake and then asks to renegotiate. The server, configured to
   * reject remote-initiated renegotiation, must fail with a {@link DecoderException} whose cause
   * is an {@link SSLHandshakeException}. Any client-side exception fails the test.
   */
  public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb) throws Throwable {
    reset();

    sb.childHandler(
        new ChannelInitializer<Channel>() {
          @Override
          @SuppressWarnings("deprecation")
          public void initChannel(Channel ch) throws Exception {
            serverChannel = ch;
            serverSslHandler = serverCtx.newHandler(ch.alloc());
            ch.pipeline().addLast("ssl", serverSslHandler).addLast("handler", serverHandler);
          }
        });

    cb.handler(
        new ChannelInitializer<Channel>() {
          @Override
          @SuppressWarnings("deprecation")
          public void initChannel(Channel ch) throws Exception {
            clientChannel = ch;
            clientSslHandler = clientCtx.newHandler(ch.alloc());
            ch.pipeline().addLast("ssl", clientSslHandler).addLast("handler", clientHandler);
          }
        });

    Channel listener = sb.bind().sync().channel();
    cb.connect().sync();

    // The initial handshake must succeed before a renegotiation is requested.
    clientSslHandler.handshakeFuture().sync();

    String cipherSuite = "SSL_RSA_WITH_RC4_128_SHA";
    clientSslHandler.engine().setEnabledCipherSuites(new String[] {cipherSuite});
    clientSslHandler.renegotiate().await();

    serverChannel.close().awaitUninterruptibly();
    clientChannel.close().awaitUninterruptibly();
    listener.close().awaitUninterruptibly();

    Throwable serverCause = serverException.get();
    try {
      if (serverCause != null) {
        throw serverCause;
      }
      fail();
    } catch (DecoderException expected) {
      assertTrue(expected.getCause() instanceof SSLHandshakeException);
    }

    Throwable clientCause = clientException.get();
    if (clientCause != null) {
      throw clientCause;
    }
  }

  /** Clears all per-run state so each bootstrap combination starts fresh. */
  private void reset() {
    clientSslHandler = null;
    serverSslHandler = null;
    clientChannel = null;
    serverChannel = null;
    clientHandler.handshakeCounter = 0;
    serverHandler.handshakeCounter = 0;
    clientException.set(null);
    serverException.set(null);
  }

  /**
   * Records the first exception seen on its channel and checks the handshake events: the first
   * must be a success; on the client channel a later one must carry a ClosedChannelException.
   */
  @Sharable
  private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {

    protected final AtomicReference<Throwable> exception;
    private int handshakeCounter;

    TestHandler(AtomicReference<Throwable> exception) {
      this.exception = exception;
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
      ctx.flush();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
      // Only the first failure is kept.
      exception.compareAndSet(null, cause);
      ctx.close();
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
      if (!(evt instanceof SslHandshakeCompletionEvent)) {
        return;
      }
      SslHandshakeCompletionEvent completion = (SslHandshakeCompletionEvent) evt;

      if (handshakeCounter != 0) {
        // A channel without a parent is the client side here (server children have a parent).
        if (ctx.channel().parent() == null) {
          assertTrue(completion.cause() instanceof ClosedChannelException);
        }
        return;
      }

      handshakeCounter++;
      if (completion.cause() != null) {
        logger.warn("Handshake failed:", completion.cause());
      }
      assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
    }

    @Override
    public void messageReceived(ChannelHandlerContext ctx, ByteBuf in) throws Exception {}
  }
}
Пример #25
0
/**
 * {@link io.netty.channel.sctp.SctpChannel} implementation which use non-blocking mode and allows
 * to read / write {@link SctpMessage}s to the underlying {@link SctpChannel}.
 *
 * <p>Be aware that not all operations systems support SCTP. Please refer to the documentation of
 * your operation system, to understand what you need to do to use it. Also this feature is only
 * supported on Java 7+.
 */
public class NioSctpChannel extends AbstractNioMessageChannel
    implements io.netty.channel.sctp.SctpChannel {
  private static final ChannelMetadata METADATA = new ChannelMetadata(false);

  private static final InternalLogger logger =
      InternalLoggerFactory.getInstance(NioSctpChannel.class);

  private final SctpChannelConfig config;

  private final NotificationHandler<?> notificationHandler;

  private static SctpChannel newSctpChannel() {
    try {
      return SctpChannel.open();
    } catch (IOException e) {
      throw new ChannelException("Failed to open a sctp channel.", e);
    }
  }

  /** Create a new instance */
  public NioSctpChannel() {
    this(newSctpChannel());
  }

  /** Create a new instance using {@link SctpChannel} */
  public NioSctpChannel(SctpChannel sctpChannel) {
    this(null, sctpChannel);
  }

  /**
   * Create a new instance
   *
   * @param parent the {@link Channel} which is the parent of this {@link NioSctpChannel} or {@code
   *     null}.
   * @param sctpChannel the underlying {@link SctpChannel}
   */
  public NioSctpChannel(Channel parent, SctpChannel sctpChannel) {
    super(parent, sctpChannel, SelectionKey.OP_READ);
    try {
      sctpChannel.configureBlocking(false);
      config = new NioSctpChannelConfig(this, sctpChannel);
      notificationHandler = new SctpNotificationHandler(this);
    } catch (IOException e) {
      try {
        sctpChannel.close();
      } catch (IOException e2) {
        if (logger.isWarnEnabled()) {
          logger.warn("Failed to close a partially initialized sctp channel.", e2);
        }
      }

      throw new ChannelException("Failed to enter non-blocking mode.", e);
    }
  }

  @Override
  public InetSocketAddress localAddress() {
    return (InetSocketAddress) super.localAddress();
  }

  @Override
  public InetSocketAddress remoteAddress() {
    return (InetSocketAddress) super.remoteAddress();
  }

  @Override
  public SctpServerChannel parent() {
    return (SctpServerChannel) super.parent();
  }

  @Override
  public ChannelMetadata metadata() {
    return METADATA;
  }

  @Override
  public Association association() {
    try {
      return javaChannel().association();
    } catch (IOException ignored) {
      return null;
    }
  }

  @Override
  public Set<InetSocketAddress> allLocalAddresses() {
    try {
      final Set<SocketAddress> allLocalAddresses = javaChannel().getAllLocalAddresses();
      final Set<InetSocketAddress> addresses =
          new LinkedHashSet<InetSocketAddress>(allLocalAddresses.size());
      for (SocketAddress socketAddress : allLocalAddresses) {
        addresses.add((InetSocketAddress) socketAddress);
      }
      return addresses;
    } catch (Throwable ignored) {
      return Collections.emptySet();
    }
  }

  @Override
  public SctpChannelConfig config() {
    return config;
  }

  @Override
  public Set<InetSocketAddress> allRemoteAddresses() {
    try {
      final Set<SocketAddress> allLocalAddresses = javaChannel().getRemoteAddresses();
      final Set<InetSocketAddress> addresses =
          new HashSet<InetSocketAddress>(allLocalAddresses.size());
      for (SocketAddress socketAddress : allLocalAddresses) {
        addresses.add((InetSocketAddress) socketAddress);
      }
      return addresses;
    } catch (Throwable ignored) {
      return Collections.emptySet();
    }
  }

  @Override
  protected SctpChannel javaChannel() {
    return (SctpChannel) super.javaChannel();
  }

  @Override
  public boolean isActive() {
    SctpChannel ch = javaChannel();
    return ch.isOpen() && association() != null;
  }

  @Override
  protected SocketAddress localAddress0() {
    try {
      Iterator<SocketAddress> i = javaChannel().getAllLocalAddresses().iterator();
      if (i.hasNext()) {
        return i.next();
      }
    } catch (IOException e) {
      // ignore
    }
    return null;
  }

  @Override
  protected SocketAddress remoteAddress0() {
    try {
      Iterator<SocketAddress> i = javaChannel().getRemoteAddresses().iterator();
      if (i.hasNext()) {
        return i.next();
      }
    } catch (IOException e) {
      // ignore
    }
    return null;
  }

  @Override
  protected void doBind(SocketAddress localAddress) throws Exception {
    javaChannel().bind(localAddress);
  }

  @Override
  protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress)
      throws Exception {
    if (localAddress != null) {
      javaChannel().bind(localAddress);
    }

    boolean success = false;
    try {
      boolean connected = javaChannel().connect(remoteAddress);
      if (!connected) {
        selectionKey().interestOps(SelectionKey.OP_CONNECT);
      }
      success = true;
      return connected;
    } finally {
      if (!success) {
        doClose();
      }
    }
  }

  @Override
  protected void doFinishConnect() throws Exception {
    if (!javaChannel().finishConnect()) {
      throw new Error();
    }
  }

  @Override
  protected void doDisconnect() throws Exception {
    doClose();
  }

  @Override
  protected void doClose() throws Exception {
    javaChannel().close();
  }

  @Override
  protected int doReadMessages(List<Object> buf) throws Exception {
    SctpChannel ch = javaChannel();

    RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
    ByteBuf buffer = allocHandle.allocate(config().getAllocator());
    boolean free = true;
    try {
      ByteBuffer data = buffer.internalNioBuffer(buffer.writerIndex(), buffer.writableBytes());
      int pos = data.position();

      MessageInfo messageInfo = ch.receive(data, null, notificationHandler);
      if (messageInfo == null) {
        return 0;
      }
      buf.add(
          new SctpMessage(
              messageInfo, buffer.writerIndex(buffer.writerIndex() + data.position() - pos)));
      free = false;
      return 1;
    } catch (Throwable cause) {
      PlatformDependent.throwException(cause);
      return -1;
    } finally {
      int bytesRead = buffer.readableBytes();
      allocHandle.record(bytesRead);
      if (free) {
        buffer.release();
      }
    }
  }

  @Override
  protected boolean doWriteMessage(Object msg, ChannelOutboundBuffer in) throws Exception {
    SctpMessage packet = (SctpMessage) msg;
    ByteBuf data = packet.content();
    int dataLen = data.readableBytes();
    if (dataLen == 0) {
      return true;
    }

    ByteBufAllocator alloc = alloc();
    boolean needsCopy = data.nioBufferCount() != 1;
    if (!needsCopy) {
      if (!data.isDirect() && alloc.isDirectBufferPooled()) {
        needsCopy = true;
      }
    }
    ByteBuffer nioData;
    if (!needsCopy) {
      nioData = data.nioBuffer();
    } else {
      data = alloc.directBuffer(dataLen).writeBytes(data);
      nioData = data.nioBuffer();
    }
    final MessageInfo mi =
        MessageInfo.createOutgoing(association(), null, packet.streamIdentifier());
    mi.payloadProtocolID(packet.protocolIdentifier());
    mi.streamNumber(packet.streamIdentifier());
    mi.unordered(packet.isUnordered());

    final int writtenBytes = javaChannel().send(nioData, mi);
    return writtenBytes > 0;
  }

  @Override
  protected final Object filterOutboundMessage(Object msg) throws Exception {
    if (msg instanceof SctpMessage) {
      SctpMessage m = (SctpMessage) msg;
      ByteBuf buf = m.content();
      if (buf.isDirect() && buf.nioBufferCount() == 1) {
        return m;
      }

      return new SctpMessage(
          m.protocolIdentifier(), m.streamIdentifier(), m.isUnordered(), newDirectBuffer(m, buf));
    }

    throw new UnsupportedOperationException(
        "unsupported message type: "
            + StringUtil.simpleClassName(msg)
            + " (expected: "
            + StringUtil.simpleClassName(SctpMessage.class));
  }

  @Override
  public ChannelFuture bindAddress(InetAddress localAddress) {
    return bindAddress(localAddress, newPromise());
  }

  @Override
  public ChannelFuture bindAddress(final InetAddress localAddress, final ChannelPromise promise) {
    if (eventLoop().inEventLoop()) {
      try {
        javaChannel().bindAddress(localAddress);
        promise.setSuccess();
      } catch (Throwable t) {
        promise.setFailure(t);
      }
    } else {
      eventLoop()
          .execute(
              new Runnable() {
                @Override
                public void run() {
                  bindAddress(localAddress, promise);
                }
              });
    }
    return promise;
  }

  @Override
  public ChannelFuture unbindAddress(InetAddress localAddress) {
    return unbindAddress(localAddress, newPromise());
  }

  @Override
  public ChannelFuture unbindAddress(final InetAddress localAddress, final ChannelPromise promise) {
    if (eventLoop().inEventLoop()) {
      try {
        javaChannel().unbindAddress(localAddress);
        promise.setSuccess();
      } catch (Throwable t) {
        promise.setFailure(t);
      }
    } else {
      eventLoop()
          .execute(
              new Runnable() {
                @Override
                public void run() {
                  unbindAddress(localAddress, promise);
                }
              });
    }
    return promise;
  }

  private final class NioSctpChannelConfig extends DefaultSctpChannelConfig {
    private NioSctpChannelConfig(NioSctpChannel channel, SctpChannel javaChannel) {
      super(channel, javaChannel);
    }

    @Override
    protected void autoReadCleared() {
      setReadPending(false);
    }
  }
}