Example #1
0
  /**
   * Handles a fully buffered close frame: echoes the close payload back to the peer and then
   * closes the local session, both on the endpoint executor.
   *
   * <p>A close payload, when present, begins with a two-byte status code that may be followed by a
   * UTF-8 reason. A payload holding only the status code (exactly two bytes) is still reported
   * with that code and an empty reason; a shorter payload closes the session without a reason.
   *
   * @param channel the channel the close frame was received on
   * @param message the buffered close frame; its pooled data is freed once the task has run
   */
  @Override
  protected void onFullCloseMessage(
      final WebSocketChannel channel, final BufferedBinaryMessage message) {
    final Pooled<ByteBuffer[]> pooled = message.getData();
    final ByteBuffer singleBuffer = toBuffer(pooled.getResource());
    // Independent position/limit so echoing the payload does not disturb the parsing below.
    final ByteBuffer toSend = singleBuffer.duplicate();

    session
        .getContainer()
        .invokeEndpointMethod(
            executor,
            new Runnable() {
              @Override
              public void run() {
                // Everything runs inside the try so the pooled data is released even when the
                // echo or the close fails.
                try {
                  WebSockets.sendClose(toSend, channel, null);
                  if (singleBuffer.remaining() >= 2) {
                    final int code = singleBuffer.getShort();
                    // The reason is optional; only decode it when bytes follow the status code.
                    final String reason =
                        singleBuffer.hasRemaining()
                            ? new UTF8Output(singleBuffer).extract()
                            : "";
                    session.close(
                        new javax.websocket.CloseReason(
                            javax.websocket.CloseReason.CloseCodes.getCloseCode(code), reason));
                  } else {
                    session.close();
                  }
                } catch (IOException e) {
                  invokeOnError(e);
                } finally {
                  pooled.free();
                }
              }
            });
  }
 /**
  * Rolls over to a new pooled buffer when the current one has no space left: the full buffer is
  * flipped, appended to {@code data}, and replaced by a buffer allocated from the websocket
  * channel's pool. Does nothing while the current buffer still has room.
  *
  * @param channel the frame channel whose websocket channel supplies the buffer pool
  */
 private void dealWithFullBuffer(StreamSourceFrameChannel channel) {
   if (current.getResource().hasRemaining()) {
     // Still room to read into; keep using the current buffer.
     return;
   }
   current.getResource().flip();
   data.add(current);
   current = channel.getWebSocketChannel().getBufferPool().allocate();
 }
        /**
         * Allocates a {@code MESSAGE_DATA} buffer, writes the flags byte and restricts the
         * buffer's limit before handing it to the caller.
         *
         * @param firstBuffer {@code true} to write {@code MSG_FLAG_NEW} as the flags byte,
         *     {@code false} to write zero
         * @return the prepared pooled buffer
         * @throws IOException declared by the signature; presumably raised by allocation
         */
        public Pooled<ByteBuffer> getBuffer(boolean firstBuffer) throws IOException {
          final Pooled<ByteBuffer> result = allocate(Protocol.MESSAGE_DATA);
          final ByteBuffer buf = result.getResource();

          // Keep the trailing 4 bytes free for the transmit data
          buf.limit(buf.limit() - 4);

          // flags byte
          if (firstBuffer) {
            buf.put(Protocol.MSG_FLAG_NEW);
          } else {
            buf.put((byte) 0);
          }

          // header size plus window size
          final int headerAndWindow = maximumWindow + 8;
          if (buf.remaining() > headerAndWindow) {
            // never try to write more than the maximum window size
            buf.limit(headerAndWindow);
          }
          return result;
        }
 /**
  * Sends a {@code SERVICE_ERROR} reply for the given channel carrying the refusal reason.
  *
  * @param channelId the channel identifier written into the reply
  * @param reason the refusal text, encoded with {@code Protocol.UTF_8}
  */
 private void refuseService(final int channelId, final String reason) {
   final Pooled<ByteBuffer> reply = connection.allocate();
   boolean handedOff = false;
   try {
     final ByteBuffer out = reply.getResource();
     out.clear();
     out.put(Protocol.SERVICE_ERROR)
         .putInt(channelId)
         .put(reason.getBytes(Protocol.UTF_8))
         .flip();
     // send takes ownership of the buffer, so from here on it must not be freed locally
     handedOff = true;
     connection.send(reply);
   } finally {
     if (!handedOff) {
       reply.free();
     }
   }
 }
 /**
  * Creates the listener and records the size of buffers handed out by {@code pool}.
  *
  * @param masterPortForwardConnection the connection stored for later use
  * @param urlPath the URL path stored for later use
  * @param targetPort the target port stored for later use
  * @param requestId the shared request counter stored for later use
  * @param pool the buffer pool; wrapped in an {@code XnioByteBufferPool} and probed once for its
  *     buffer size
  * @param undertowOptions the option map stored for later use
  */
 public PortForwardOpenListener(
     ClientConnection masterPortForwardConnection,
     final String urlPath,
     final int targetPort,
     final AtomicInteger requestId,
     final Pool<ByteBuffer> pool,
     final OptionMap undertowOptions) {
   this.masterPortForwardConnection = masterPortForwardConnection;
   this.requestId = requestId;
   this.urlPath = urlPath;
   this.targetPort = targetPort;
   this.undertowOptions = undertowOptions;
   this.bufferPool = new XnioByteBufferPool(pool);

   // Allocate one buffer only to learn how many bytes a pooled buffer offers, then return it.
   final Pooled<ByteBuffer> probe = pool.allocate();
   this.bufferSize = probe.getResource().remaining();
   probe.free();
 }
 /**
  * Reads from the frame channel into the current pooled buffer, waiting for readability whenever
  * a read yields no bytes.
  *
  * <p>Returns when the channel reports end of data (marking the message complete) or, when whole
  * messages are not being buffered, as soon as the current buffer is full.
  *
  * @param channel the frame channel to read from
  * @throws IOException if reading or waiting on the channel fails
  */
 public void readBlocking(StreamSourceFrameChannel channel) throws IOException {
   if (current == null) {
     current = channel.getWebSocketChannel().getBufferPool().allocate();
   }
   while (true) {
     final int count = channel.read(current.getResource());
     if (count == -1) {
       complete = true;
       return;
     }
     if (count == 0) {
       channel.awaitReadable();
     }
     checkMaxSize(channel, count);
     if (bufferFullMessage) {
       // Buffering the whole message: swap in a fresh buffer when this one is full.
       dealWithFullBuffer(channel);
       continue;
     }
     if (!current.getResource().hasRemaining()) {
       return;
     }
   }
 }
 /**
  * Sends a capabilities request advertising version 0 and installs a {@code Capabilities} read
  * listener for the reply.
  *
  * @param serverName the server name passed on to the {@code Capabilities} listener
  */
 void sendCapRequest(final String serverName) {
   client.trace("Client sending capabilities request");
   final Pooled<ByteBuffer> request = connection.allocate();
   boolean sent = false;
   try {
     // Build the request message body
     final ByteBuffer body = request.getResource();
     body.put(Protocol.CAPABILITIES);
     ProtocolUtils.writeByte(body, Protocol.CAP_VERSION, 0);
     body.flip();
     connection.setReadListener(new Capabilities(serverName));
     connection.send(request);
     sent = true;
   } finally {
     // Only release the buffer here when it was never successfully sent.
     if (!sent) {
       request.free();
     }
   }
 }
 /**
  * Sends a capabilities request describing this client (protocol version, optional endpoint
  * name, message-close support, version string and channel limits) and installs a
  * {@code Capabilities} read listener for the reply.
  *
  * @param remoteServerName the remote server name passed on to the {@code Capabilities} listener
  */
 void sendCapRequest(final String remoteServerName) {
   client.trace("Client sending capabilities request");
   final Pooled<ByteBuffer> request = connection.allocate();
   boolean sent = false;
   try {
     // Build the request message body
     final ByteBuffer body = request.getResource();
     body.put(Protocol.CAPABILITIES);
     ProtocolUtils.writeByte(body, Protocol.CAP_VERSION, Protocol.VERSION);

     // The endpoint name is only advertised when the local endpoint has one.
     final String localEndpointName = connectionProviderContext.getEndpoint().getName();
     if (localEndpointName != null) {
       ProtocolUtils.writeString(body, Protocol.CAP_ENDPOINT_NAME, localEndpointName);
     }

     ProtocolUtils.writeEmpty(body, Protocol.CAP_MESSAGE_CLOSE);
     ProtocolUtils.writeString(body, Protocol.CAP_VERSION_STRING, Version.getVersionString());
     ProtocolUtils.writeInt(
         body,
         Protocol.CAP_CHANNELS_IN,
         optionMap.get(
             RemotingOptions.MAX_INBOUND_CHANNELS, RemotingOptions.DEFAULT_MAX_INBOUND_CHANNELS));
     ProtocolUtils.writeInt(
         body,
         Protocol.CAP_CHANNELS_OUT,
         optionMap.get(
             RemotingOptions.MAX_OUTBOUND_CHANNELS,
             RemotingOptions.DEFAULT_MAX_OUTBOUND_CHANNELS));
     body.flip();

     connection.setReadListener(new Capabilities(remoteServerName, uri), true);
     connection.send(request);
     sent = true;
   } finally {
     // Only release the buffer here when it was never successfully sent.
     if (!sent) {
       request.free();
     }
   }
 }
  /**
   * Reads the whole request body into pooled buffers, pushes those buffers back onto the exchange
   * and then renegotiates via {@code renegotiateNoRequest}.
   *
   * <p>The number of buffers is capped by {@code MAX_BUFFERED_REQUEST_SIZE} (default 16384)
   * divided by the pool's buffer size, rounded up.
   *
   * @param exchange the exchange whose request is buffered
   * @param newAuthMode the client auth mode passed to {@code renegotiateNoRequest}
   * @throws SSLPeerUnverifiedException if buffering is disabled (max size not positive) or the
   *     request needs more buffers than allowed
   * @throws IOException if reading the request fails
   */
  public void renegotiateBufferRequest(HttpServerExchange exchange, SslClientAuthMode newAuthMode)
      throws IOException {
    final int maxSize =
        exchange
            .getConnection()
            .getUndertowOptions()
            .get(UndertowOptions.MAX_BUFFERED_REQUEST_SIZE, 16384);
    if (maxSize <= 0) {
      throw new SSLPeerUnverifiedException("");
    }

    // The request has to be read first. If no request channel existed yet, the one obtained here
    // is cleared from the exchange again on the way out.
    StreamSourceChannel requestChannel = Connectors.getExistingRequestChannel(exchange);
    final boolean requestResetRequired = requestChannel == null;
    if (requestResetRequired) {
      requestChannel = exchange.getRequestChannel();
    }

    Pooled<ByteBuffer> tail = exchange.getConnection().getBufferPool().allocate();
    final int bufferSize = tail.getResource().remaining();
    final int allowedBuffers = (maxSize + bufferSize - 1) / bufferSize;
    final Pooled<ByteBuffer>[] buffers = new Pooled[allowedBuffers];
    int usedBuffers = 0;
    buffers[usedBuffers++] = tail;
    boolean handedOff = false; // once true the buffers are no longer freed here
    try {
      int res;
      do {
        final ByteBuffer target = tail.getResource();
        res = Channels.readBlocking(requestChannel, target);
        if (!target.hasRemaining()) {
          if (usedBuffers == allowedBuffers) {
            throw new SSLPeerUnverifiedException("");
          }
          // Current buffer is full: make it readable and continue into a fresh one.
          target.flip();
          tail = exchange.getConnection().getBufferPool().allocate();
          buffers[usedBuffers++] = tail;
        }
      } while (res != -1);
      handedOff = true;
      tail.getResource().flip();
      Connectors.ungetRequestBytes(exchange, buffers);
      renegotiateNoRequest(exchange, newAuthMode);
    } finally {
      if (!handedOff) {
        for (Pooled<ByteBuffer> used : buffers) {
          if (used != null) {
            used.free();
          }
        }
      }
      if (requestResetRequired) {
        exchange.requestChannel = null;
      }
    }
  }
Example #10
0
 /**
  * Sends one buffer of an outbound message over the remote connection's channel, blocking until
  * the message window permits it.
  *
  * <p>The pooled buffer is always freed before returning; when {@code eof} is set this message is
  * also released from its channel.
  *
  * @param pooledBuffer the buffer to send
  * @param eof {@code true} if this is the final buffer of the message
  * @throws InterruptedIOException if the thread is interrupted while waiting for the window
  * @throws IOException if the blocking send fails
  */
 public void accept(final Pooled<ByteBuffer> pooledBuffer, final boolean eof)
     throws IOException {
   // Index of the flags byte inside the message header.
   final int flagsAt = 7;
   try {
     final ByteBuffer buffer = pooledBuffer.getResource();
     final ConnectedMessageChannel messageChannel =
         channel.getRemoteConnection().getChannel();
     if (eof) {
       // EOF flag (sync close)
       final byte withEof = (byte) (buffer.get(flagsAt) | Protocol.MSG_FLAG_EOF);
       buffer.put(flagsAt, withEof);
       log.tracef("Sending message (with EOF) (%s) to %s", buffer, messageChannel);
     }
     if (cancelled) {
       final byte withCancel = (byte) (buffer.get(flagsAt) | Protocol.MSG_FLAG_CANCELLED);
       buffer.put(flagsAt, withCancel);
       // discard everything in the buffer after the header
       buffer.limit(8);
       log.trace("Message includes cancel flag");
     }
     synchronized (OutboundMessage.this) {
       final int pending = buffer.remaining();
       window -= pending;
       while (window < pending) {
         log.trace("Message window is closed, waiting");
         try {
           OutboundMessage.this.wait();
         } catch (InterruptedException e) {
           // Restore the interrupt status before reporting the failure as an I/O error.
           Thread.currentThread().interrupt();
           throw new InterruptedIOException("Interrupted on write");
         }
       }
       log.trace("Message window is open, proceeding with send");
     }
     Channels.sendBlocking(messageChannel, buffer);
   } finally {
     pooledBuffer.free();
     if (eof) {
       channel.free(OutboundMessage.this);
     }
   }
 }
  /**
   * Hands the buffered data over to the caller as one pooled array of buffers and resets this
   * object's buffering state.
   *
   * <p>The buffer currently being filled is flipped and included as the last element. With
   * nothing buffered at all, an empty array is returned.
   *
   * @return the buffered data, in the order it was read
   */
  public Pooled<ByteBuffer[]> getData() {
    if (current == null) {
      return new ImmediatePooled<ByteBuffer[]>(new ByteBuffer[0]);
    }

    // Detach the in-progress buffer and make it readable.
    final Pooled<ByteBuffer> last = current;
    last.getResource().flip();
    current = null;

    if (data.isEmpty()) {
      // Only a single buffer was ever used.
      return new PooledByteBufferArray(
          Collections.<Pooled<ByteBuffer>>singletonList(last),
          new ByteBuffer[] {last.getResource()});
    }

    data.add(last);
    final List<Pooled<ByteBuffer>> collected = data;
    data = new ArrayList<Pooled<ByteBuffer>>();

    final ByteBuffer[] views = new ByteBuffer[collected.size()];
    int index = 0;
    for (Pooled<ByteBuffer> part : collected) {
      views[index++] = part.getResource();
    }
    return new PooledByteBufferArray(collected, views);
  }
Example #12
0
  /**
   * Delivers a fully buffered pong frame to the registered PONG message handler on the endpoint
   * executor.
   *
   * <p>The message's pooled data is always released: after the handler has run when one is
   * registered, or immediately when there is no handler to receive the pong.
   *
   * @param webSocketChannel the channel the pong was received on
   * @param bufferedBinaryMessage the buffered pong payload
   */
  @Override
  protected void onFullPongMessage(
      WebSocketChannel webSocketChannel, BufferedBinaryMessage bufferedBinaryMessage) {
    final HandlerWrapper handler = getHandler(FrameType.PONG);
    if (handler == null) {
      // Nobody is interested in the pong; release its buffers instead of leaking them.
      bufferedBinaryMessage.getData().free();
      return;
    }

    final Pooled<ByteBuffer[]> pooled = bufferedBinaryMessage.getData();
    final PongMessage message = DefaultPongMessage.create(toBuffer(pooled.getResource()));

    session
        .getContainer()
        .invokeEndpointMethod(
            executor,
            new Runnable() {
              @Override
              public void run() {
                try {
                  ((MessageHandler.Whole) handler.getHandler()).onMessage(message);
                } finally {
                  pooled.free();
                }
              }
            });
  }
Example #13
0
 /**
  * Truncates writes on the underlying channel.
  *
  * <p>When none of the {@code MASK_STATE} bits are set, the call is delegated to {@code next} and
  * any held pooled buffer is released. Otherwise the state is moved to shutdown/body and the
  * truncation is reported as an error.
  *
  * @throws TruncatedResponseException if any {@code MASK_STATE} bit is set
  * @throws IOException if the delegate fails to truncate
  */
 public void truncateWrites() throws IOException {
   log.trace("close");
   final int previous = this.state;
   if (!allAreClear(previous, MASK_STATE)) {
     this.state = (previous & ~MASK_STATE) | FLAG_SHUTDOWN | STATE_BODY;
     throw new TruncatedResponseException();
   }
   try {
     next.truncateWrites();
   } finally {
     // Release the buffer whether or not the delegate succeeded.
     if (pooledBuffer != null) {
       pooledBuffer.free();
       pooledBuffer = null;
     }
   }
 }
Example #14
0
  /**
   * Initiate a low-copy transfer between two stream channels. The pool should be a direct buffer
   * pool for best performance.
   *
   * <p>The method first copies as much as it can synchronously. If the source reaches end of
   * stream with nothing left to flush, the transfer is completed immediately; if a read or write
   * fails, the matching exception handler is invoked and the method returns. Otherwise a
   * {@code TransferListener} is installed on both channels to continue the transfer.
   *
   * @param source the source channel
   * @param sink the target channel
   * @param sourceListener the source listener to set and call when the transfer is complete, or
   *     {@code null} to clear the listener at that time
   * @param sinkListener the target listener to set and call when the transfer is complete, or
   *     {@code null} to clear the listener at that time
   * @param readExceptionHandler the read exception handler to call if an error occurs during a read
   *     operation
   * @param writeExceptionHandler the write exception handler to call if an error occurs during a
   *     write operation
   * @param pool the pool from which the transfer buffer should be allocated; must not be
   *     {@code null}
   */
  public static <I extends StreamSourceChannel, O extends StreamSinkChannel> void initiateTransfer(
      final I source,
      final O sink,
      final ChannelListener<? super I> sourceListener,
      final ChannelListener<? super O> sinkListener,
      final ChannelExceptionHandler<? super I> readExceptionHandler,
      final ChannelExceptionHandler<? super O> writeExceptionHandler,
      Pool<ByteBuffer> pool) {
    if (pool == null) {
      throw UndertowMessages.MESSAGES.argumentCannotBeNull("pool");
    }
    final Pooled<ByteBuffer> allocated = pool.allocate();
    // Stays true unless the buffer is handed over to the listener below.
    boolean free = true;
    try {
      final ByteBuffer buffer = allocated.getResource();
      long read;
      for (; ; ) {
        try {
          read = source.read(buffer);
          // Switch to drain mode so the bytes just read can be written out.
          buffer.flip();
        } catch (IOException e) {
          ChannelListeners.invokeChannelExceptionHandler(source, readExceptionHandler, e);
          return;
        }
        // Nothing available right now and nothing buffered: continue asynchronously.
        if (read == 0 && !buffer.hasRemaining()) {
          break;
        }
        // End of stream with nothing left to flush: the transfer is finished.
        if (read == -1 && !buffer.hasRemaining()) {
          done(source, sink, sourceListener, sinkListener);
          return;
        }
        while (buffer.hasRemaining()) {
          final int res;
          try {
            res = sink.write(buffer);
          } catch (IOException e) {
            ChannelListeners.invokeChannelExceptionHandler(sink, writeExceptionHandler, e);
            return;
          }
          // The sink accepted nothing; stop writing for now.
          if (res == 0) {
            break;
          }
        }
        // Unwritten bytes remain, so the rest must be completed by the listener.
        if (buffer.hasRemaining()) {
          break;
        }
        // Fully drained: reset the buffer for the next read.
        buffer.clear();
      }
      // The listener only receives the buffer when it still holds unwritten bytes; otherwise the
      // buffer is returned to the pool by the finally block.
      Pooled<ByteBuffer> current = null;
      if (buffer.hasRemaining()) {
        current = allocated;
        free = false;
      }

      // The last argument tells the listener whether the final read above returned -1.
      final TransferListener<I, O> listener =
          new TransferListener<I, O>(
              pool,
              current,
              source,
              sink,
              sourceListener,
              sinkListener,
              writeExceptionHandler,
              readExceptionHandler,
              read == -1);
      sink.getWriteSetter().set(listener);
      source.getReadSetter().set(listener);
      // we resume both reads and writes, as we want to keep trying to fill the buffer
      if (current == null || buffer.capacity() != buffer.remaining()) {
        // we don't resume if the buffer is 100% full
        source.resumeReads();
      }
      if (current != null) {
        // we don't resume writes if we have nothing to write
        sink.resumeWrites();
      }
    } finally {
      if (free) {
        allocated.free();
      }
    }
  }
Example #15
0
    /**
     * Continues the transfer when either the source becomes readable or the sink becomes
     * writable.
     *
     * <p>Alternates between draining the buffer into the sink and refilling it from the source
     * until one side can make no progress, then resumes or suspends reads and writes according to
     * what the buffer holds. On an I/O error the buffer is freed, the transfer is marked done and
     * the matching exception handler is invoked.
     *
     * @param channel the channel that triggered the event
     */
    public void handleEvent(final Channel channel) {
      if (done) {
        // The transfer has already ended: just silence the direction that fired.
        if (channel instanceof StreamSinkChannel) {
          ((StreamSinkChannel) channel).suspendWrites();
        } else if (channel instanceof StreamSourceChannel) {
          ((StreamSourceChannel) channel).suspendReads();
        }
        return;
      }
      boolean noWrite = false;
      if (pooledBuffer == null) {
        // A freshly allocated buffer holds nothing to write, so start with a read.
        pooledBuffer = pool.allocate();
        noWrite = true;
      } else if (channel instanceof StreamSourceChannel) {
        noWrite = true; // attempt a read first, as this is a read notification
        // Move unwritten bytes to the front and switch the buffer to fill mode.
        pooledBuffer.getResource().compact();
      }

      final ByteBuffer buffer = pooledBuffer.getResource();
      try {
        long read;

        for (; ; ) {
          boolean writeFailed = false;
          // always attempt to write first if we have the buffer
          if (!noWrite) {
            while (buffer.hasRemaining()) {
              final int res;
              try {
                res = sink.write(buffer);
              } catch (IOException e) {
                pooledBuffer.free();
                pooledBuffer = null;
                done = true;
                ChannelListeners.invokeChannelExceptionHandler(sink, writeExceptionHandler, e);
                return;
              }
              // The sink accepted nothing; remember that so the loop can stop below.
              if (res == 0) {
                writeFailed = true;
                break;
              }
            }
            // Source exhausted and everything flushed: the transfer is complete.
            if (sourceDone && !buffer.hasRemaining()) {
              done = true;
              done(source, sink, sourceListener, sinkListener);
              return;
            }
            // Switch to fill mode, keeping any bytes that could not be written.
            buffer.compact();
          }
          noWrite = false;

          if (buffer.hasRemaining() && !sourceDone) {
            try {
              read = source.read(buffer);
              // Back to drain mode for the next write attempt.
              buffer.flip();
            } catch (IOException e) {
              pooledBuffer.free();
              pooledBuffer = null;
              done = true;
              ChannelListeners.invokeChannelExceptionHandler(source, readExceptionHandler, e);
              return;
            }
            if (read == 0) {
              break;
            } else if (read == -1) {
              sourceDone = true;
              if (!buffer.hasRemaining()) {
                done = true;
                done(source, sink, sourceListener, sinkListener);
                return;
              }
            }
          } else {
            // No read was attempted (buffer full or source finished): return to drain mode.
            buffer.flip();
            if (writeFailed) {
              break;
            }
          }
        }
        // suspend writes if there is nothing to write
        if (!buffer.hasRemaining()) {
          sink.suspendWrites();
        } else if (!sink.isWriteResumed()) {
          sink.resumeWrites();
        }
        // suspend reads if there is nothing to read
        if (buffer.remaining() == buffer.capacity()) {
          source.suspendReads();
        } else if (!source.isReadResumed()) {
          source.resumeReads();
        }
      } finally {
        // Do not hold on to a pooled buffer while it is empty.
        if (pooledBuffer != null && !buffer.hasRemaining()) {
          pooledBuffer.free();
          pooledBuffer = null;
        }
      }
    }
 /**
  * Handles a message received while waiting for the server's STARTTLS response.
  *
  * <p>Keep-alive and close messages are passed to the connection. A STARTTLS message starts the
  * SSL handshake on the first {@code SslChannel} found by unwrapping the channel, followed by a
  * new capabilities request. Any other message type, or a malformed message, is reported as
  * invalid.
  *
  * @param channel the channel that became readable
  */
 public void handleEvent(final ConnectedMessageChannel channel) {
   final Pooled<ByteBuffer> pooledReceiveBuffer = connection.allocate();
   try {
     final ByteBuffer in = pooledReceiveBuffer.getResource();
     synchronized (connection.getLock()) {
       final int count;
       try {
         count = channel.receive(in);
       } catch (IOException e) {
         connection.handleException(e);
         return;
       }
       if (count == -1) {
         connection.handleException(client.abruptClose(connection));
         return;
       }
       if (count == 0) {
         return;
       }
     }
     client.tracef("Received %s", in);
     in.flip();
     final byte msgType = in.get();

     if (msgType == Protocol.CONNECTION_ALIVE) {
       client.trace("Client received connection alive");
       connection.sendAliveResponse();
     } else if (msgType == Protocol.CONNECTION_ALIVE_ACK) {
       client.trace("Client received connection alive ack");
     } else if (msgType == Protocol.CONNECTION_CLOSE) {
       client.trace("Client received connection close request");
       connection.handlePreAuthCloseRequest();
     } else if (msgType == Protocol.STARTTLS) {
       client.trace("Client received STARTTLS response");
       // Peel off wrapper channels until the SSL-capable one is found.
       Channel candidate = channel;
       while (!(candidate instanceof SslChannel)) {
         if (!(candidate instanceof WrappedChannel)) {
           // this should never happen
           connection.handleException(
               new IOException("Client starting STARTTLS but channel doesn't support SSL"));
           return;
         }
         candidate = ((WrappedChannel<?>) candidate).getChannel();
       }
       try {
         ((SslChannel) candidate).startHandshake();
       } catch (IOException e) {
         connection.handleException(e, false);
         return;
       }
       sendCapRequest(remoteServerName);
     } else {
       client.unknownProtocolId(msgType);
       connection.handleException(client.invalidMessage(connection));
     }
   } catch (BufferUnderflowException | BufferOverflowException e) {
     connection.handleException(client.invalidMessage(connection));
   } finally {
     pooledReceiveBuffer.free();
   }
 }
 /**
  * Handles a message received during the authentication phase.
  *
  * <p>Authentication challenges and completions are evaluated on the connection's executor. The
  * challenge bytes are copied out of the receive buffer before the task is submitted, because
  * the receive buffer is returned to the pool as soon as this method exits. Reads are suspended
  * before the task is submitted so the task's {@code resumeReads()} cannot run ahead of the
  * suspension.
  *
  * @param channel the channel that became readable
  */
 public void handleEvent(final ConnectedMessageChannel channel) {
   final Pooled<ByteBuffer> pooledBuffer = connection.allocate();
   try {
     final ByteBuffer buffer = pooledBuffer.getResource();
     final int res;
     try {
       res = channel.receive(buffer);
     } catch (IOException e) {
       connection.handleException(e);
       return;
     }
     if (res == 0) {
       return;
     }
     if (res == -1) {
       connection.handleException(client.abruptClose(connection));
       return;
     }
     buffer.flip();
     final byte msgType = buffer.get();
     switch (msgType) {
       case Protocol.CONNECTION_ALIVE:
         {
           client.trace("Client received connection alive");
           return;
         }
       case Protocol.CONNECTION_CLOSE:
         {
           client.trace("Client received connection close request");
           connection.handleIncomingCloseRequest();
           return;
         }
       case Protocol.AUTH_CHALLENGE:
         {
           client.trace("Client received authentication challenge");
           // Copy the payload now: the pooled buffer is freed when this method returns, which
           // may happen before the task below runs.
           final byte[] challenge = Buffers.take(buffer, buffer.remaining());
           // Suspend first so the task's resumeReads() always comes after the suspension.
           connection.getChannel().suspendReads();
           connection
               .getExecutor()
               .execute(
                   new Runnable() {
                     public void run() {
                       if (saslClient.isComplete()) {
                         connection.handleException(
                             new SaslException("Received extra auth message after completion"));
                         return;
                       }
                       final byte[] response;
                       try {
                         response = saslClient.evaluateChallenge(challenge);
                       } catch (Exception e) {
                         client.tracef("Client authentication failed: %s", e);
                         failedMechs.add(saslClient.getMechanismName());
                         sendCapRequest(serverName);
                         return;
                       }
                       client.trace("Client sending authentication response");
                       final Pooled<ByteBuffer> pooled = connection.allocate();
                       boolean ok = false;
                       try {
                         final ByteBuffer sendBuffer = pooled.getResource();
                         sendBuffer.put(Protocol.AUTH_RESPONSE);
                         sendBuffer.put(response);
                         sendBuffer.flip();
                         connection.send(pooled);
                         ok = true;
                       } finally {
                         // Return the buffer to the pool if it never made it out.
                         if (!ok) {
                           pooled.free();
                         }
                       }
                       connection.getChannel().resumeReads();
                     }
                   });
           return;
         }
       case Protocol.AUTH_COMPLETE:
         {
           client.trace("Client received authentication complete");
           // Copy the payload now: the pooled buffer is freed when this method returns, which
           // may happen before the task below runs.
           final byte[] challenge = Buffers.take(buffer, buffer.remaining());
           // Suspend first so the task's resumeReads() always comes after the suspension.
           connection.getChannel().suspendReads();
           connection
               .getExecutor()
               .execute(
                   new Runnable() {
                     public void run() {
                       final boolean clientComplete = saslClient.isComplete();
                       if (!clientComplete) {
                         try {
                           final byte[] response = saslClient.evaluateChallenge(challenge);
                           if (response != null && response.length > 0) {
                             connection.handleException(
                                 new SaslException(
                                     "Received extra auth message after completion"));
                             return;
                           }
                           if (!saslClient.isComplete()) {
                             connection.handleException(
                                 new SaslException(
                                     "Client not complete after processing auth complete message"));
                             return;
                           }
                         } catch (SaslException e) {
                           // todo log message
                           failedMechs.add(saslClient.getMechanismName());
                           sendCapRequest(serverName);
                           return;
                         }
                       }
                       // auth complete.
                       final ConnectionHandlerFactory connectionHandlerFactory =
                           new ConnectionHandlerFactory() {
                             public ConnectionHandler createInstance(
                                 final ConnectionHandlerContext connectionContext) {
                               // this happens immediately.
                               final RemoteConnectionHandler connectionHandler =
                                   new RemoteConnectionHandler(connectionContext, connection);
                               connection.setReadListener(
                                   new RemoteReadListener(connectionHandler, connection));
                               return connectionHandler;
                             }
                           };
                       connection.getResult().setResult(connectionHandlerFactory);
                       connection.getChannel().resumeReads();
                     }
                   });
           return;
         }
       case Protocol.AUTH_REJECTED:
         {
           client.trace("Client received authentication rejected");
           failedMechs.add(saslClient.getMechanismName());
           sendCapRequest(serverName);
           return;
         }
       default:
         {
           client.unknownProtocolId(msgType);
           connection.handleException(client.invalidMessage(connection));
           return;
         }
     }
   } finally {
     pooledBuffer.free();
   }
 }
 /**
  * Handles a message received while waiting for the server's STARTTLS response.
  *
  * <p>Keep-alive and close messages are passed to the connection. A STARTTLS message starts the
  * SSL handshake on the channel and then sends a new capabilities request. If the channel does
  * not support SSL, that is reported through the connection's exception handling rather than
  * failing with a {@code ClassCastException}. Any other message type, or a malformed message, is
  * reported as invalid.
  *
  * @param channel the channel that became readable
  */
 public void handleEvent(final ConnectedMessageChannel channel) {
   final Pooled<ByteBuffer> pooledReceiveBuffer = connection.allocate();
   try {
     final ByteBuffer receiveBuffer = pooledReceiveBuffer.getResource();
     int res = 0;
     try {
       res = channel.receive(receiveBuffer);
     } catch (IOException e) {
       connection.handleException(e);
       return;
     }
     if (res == -1) {
       connection.handleException(client.abruptClose(connection));
       return;
     }
     if (res == 0) {
       return;
     }
     client.tracef("Received %s", receiveBuffer);
     receiveBuffer.flip();
     final byte msgType = receiveBuffer.get();
     switch (msgType) {
       case Protocol.CONNECTION_ALIVE:
         {
           client.trace("Client received connection alive");
           return;
         }
       case Protocol.CONNECTION_CLOSE:
         {
           client.trace("Client received connection close request");
           connection.handleIncomingCloseRequest();
           return;
         }
       case Protocol.STARTTLS:
         {
           client.trace("Client received STARTTLS response");
           if (!(channel instanceof SslChannel)) {
             // this should never happen
             connection.handleException(
                 new IOException("Client starting STARTTLS but channel doesn't support SSL"));
             return;
           }
           try {
             ((SslChannel) channel).startHandshake();
           } catch (IOException e) {
             connection.handleException(e, false);
             return;
           }
           sendCapRequest(serverName);
           return;
         }
       default:
         {
           client.unknownProtocolId(msgType);
           connection.handleException(client.invalidMessage(connection));
           return;
         }
     }
   } catch (BufferUnderflowException e) {
     connection.handleException(client.invalidMessage(connection));
     return;
   } catch (BufferOverflowException e) {
     connection.handleException(client.invalidMessage(connection));
     return;
   } finally {
     pooledReceiveBuffer.free();
   }
 }
    public void handleEvent(final ConnectedMessageChannel channel) {
      final Pooled<ByteBuffer> pooledReceiveBuffer = connection.allocate();
      try {
        final ByteBuffer receiveBuffer = pooledReceiveBuffer.getResource();
        int res = 0;
        try {
          res = channel.receive(receiveBuffer);
        } catch (IOException e) {
          connection.handleException(e);
          return;
        }
        if (res == -1) {
          connection.handleException(client.abruptClose(connection));
          return;
        }
        if (res == 0) {
          return;
        }
        receiveBuffer.flip();
        boolean starttls = false;
        final Set<String> saslMechs = new LinkedHashSet<String>();
        final byte msgType = receiveBuffer.get();
        switch (msgType) {
          case Protocol.CONNECTION_ALIVE:
            {
              client.trace("Client received connection alive");
              return;
            }
          case Protocol.CONNECTION_CLOSE:
            {
              client.trace("Client received connection close request");
              connection.handleIncomingCloseRequest();
              return;
            }
          case Protocol.CAPABILITIES:
            {
              client.trace("Client received capabilities response");
              while (receiveBuffer.hasRemaining()) {
                final byte type = receiveBuffer.get();
                final int len = receiveBuffer.get() & 0xff;
                final ByteBuffer data = Buffers.slice(receiveBuffer, len);
                switch (type) {
                  case Protocol.CAP_VERSION:
                    {
                      final byte version = data.get();
                      client.tracef(
                          "Client received capability: version %d",
                          Integer.valueOf(version & 0xff));
                      // We only support version zero, so knowing the other side's version is not
                      // useful presently
                      break;
                    }
                  case Protocol.CAP_SASL_MECH:
                    {
                      final String mechName = Buffers.getModifiedUtf8(data);
                      client.tracef("Client received capability: SASL mechanism %s", mechName);
                      if (!failedMechs.contains(mechName)
                          && !disallowedMechs.contains(mechName)
                          && (allowedMechs == null || allowedMechs.contains(mechName))) {
                        client.tracef("SASL mechanism %s added to allowed set", mechName);
                        saslMechs.add(mechName);
                      }
                      break;
                    }
                  case Protocol.CAP_STARTTLS:
                    {
                      client.trace("Client received capability: STARTTLS");
                      starttls = true;
                      break;
                    }
                  default:
                    {
                      client.tracef(
                          "Client received unknown capability %02x", Integer.valueOf(type & 0xff));
                      // unknown, skip it for forward compatibility.
                      break;
                    }
                }
              }
              if (starttls) {
                // only initiate starttls if not forbidden by config
                if (optionMap.get(Options.SSL_STARTTLS, true)) {
                  // Prepare the request message body
                  final Pooled<ByteBuffer> pooledSendBuffer = connection.allocate();
                  final ByteBuffer sendBuffer = pooledSendBuffer.getResource();
                  sendBuffer.put(Protocol.STARTTLS);
                  sendBuffer.flip();
                  connection.setReadListener(new StartTls(serverName));
                  connection.send(pooledSendBuffer);
                  // all set
                  return;
                }
              }

              if (saslMechs.isEmpty()) {
                connection.handleException(
                    new SaslException("No more authentication mechanisms to try"));
                return;
              }
              // OK now send our authentication request
              final OptionMap optionMap = connection.getOptionMap();
              final String userName = optionMap.get(RemotingOptions.AUTHORIZE_ID);
              final Map<String, ?> propertyMap =
                  SaslUtils.createPropertyMap(
                      optionMap, Channels.getOption(channel, Options.SECURE, false));
              final SaslClient saslClient;
              try {
                saslClient =
                    AccessController.doPrivileged(
                        new PrivilegedExceptionAction<SaslClient>() {
                          public SaslClient run() throws SaslException {
                            return Sasl.createSaslClient(
                                saslMechs.toArray(new String[saslMechs.size()]),
                                userName,
                                "remote",
                                serverName,
                                propertyMap,
                                callbackHandler);
                          }
                        },
                        accessControlContext);
              } catch (PrivilegedActionException e) {
                final SaslException se = (SaslException) e.getCause();
                connection.handleException(se);
                return;
              }
              final String mechanismName = saslClient.getMechanismName();
              client.tracef("Client initiating authentication using mechanism %s", mechanismName);
              // Prepare the request message body
              final Pooled<ByteBuffer> pooledSendBuffer = connection.allocate();
              final ByteBuffer sendBuffer = pooledSendBuffer.getResource();
              sendBuffer.put(Protocol.AUTH_REQUEST);
              Buffers.putModifiedUtf8(sendBuffer, mechanismName);
              sendBuffer.flip();
              connection.send(pooledSendBuffer);
              connection.setReadListener(new Authentication(saslClient, serverName));
              return;
            }
          default:
            {
              client.unknownProtocolId(msgType);
              connection.handleException(client.invalidMessage(connection));
              return;
            }
        }
      } catch (BufferUnderflowException e) {
        connection.handleException(client.invalidMessage(connection));
        return;
      } catch (BufferOverflowException e) {
        connection.handleException(client.invalidMessage(connection));
        return;
      } finally {
        pooledReceiveBuffer.free();
      }
    }
 /**
  * Handles a read event while the client is waiting for the server's greeting.
  *
  * <p>A single message is received into a pooled buffer and dispatched on its first byte:
  *
  * <ul>
  *   <li>{@code CONNECTION_ALIVE} is traced and otherwise ignored;
  *   <li>{@code CONNECTION_CLOSE} is forwarded to {@code connection.handleIncomingCloseRequest()};
  *   <li>{@code GREETING} is parsed as a sequence of (type byte, unsigned length byte, data)
  *       entries, after which a capabilities request is sent via {@code sendCapRequest};
  *   <li>any other message type is reported as an invalid message on the connection.
  * </ul>
  *
  * <p>The server name starts out as the peer address host name and is replaced if the greeting
  * carries a {@code GRT_SERVER_NAME} entry. The pooled buffer is released on every exit path.
  *
  * @param channel the message channel that became readable
  */
 public void handleEvent(final ConnectedMessageChannel channel) {
   final Pooled<ByteBuffer> pooledInput = connection.allocate();
   try {
     final ByteBuffer input = pooledInput.getResource();
     final int received;
     try {
       received = channel.receive(input);
     } catch (IOException e) {
       connection.handleException(e);
       return;
     }
     if (received == 0) {
       // nothing to read yet
       return;
     }
     if (received == -1) {
       // end of stream before the greeting completed
       connection.handleException(client.abruptClose(connection));
       return;
     }
     client.tracef("Received %s", input);
     input.flip();
     // default to the peer's host name; a greeting entry may override it below
     String serverName = channel.getPeerAddress(InetSocketAddress.class).getHostName();
     final byte msgType = input.get();
     switch (msgType) {
       case Protocol.CONNECTION_ALIVE:
         client.trace("Client received connection alive");
         return;
       case Protocol.CONNECTION_CLOSE:
         client.trace("Client received connection close request");
         connection.handleIncomingCloseRequest();
         return;
       case Protocol.GREETING:
         client.trace("Client received greeting");
         while (input.hasRemaining()) {
           final byte entryType = input.get();
           final int entryLength = input.get() & 0xff;
           final ByteBuffer entryData = Buffers.slice(input, entryLength);
           if (entryType == Protocol.GRT_SERVER_NAME) {
             serverName = Buffers.getModifiedUtf8(entryData);
             client.tracef("Client received server name: %s", serverName);
           } else {
             // unrecognized entries are skipped for forward compatibility
             client.tracef(
                 "Client received unknown greeting message %02x",
                 Integer.valueOf(entryType & 0xff));
           }
         }
         sendCapRequest(serverName);
         return;
       default:
         client.unknownProtocolId(msgType);
         connection.handleException(client.invalidMessage(connection));
         return;
     }
   } catch (BufferOverflowException e) {
     // malformed message: treat as a protocol violation
     connection.handleException(client.invalidMessage(connection));
     return;
   } catch (BufferUnderflowException e) {
     // truncated message: treat as a protocol violation
     connection.handleException(client.invalidMessage(connection));
     return;
   } finally {
     pooledInput.free();
   }
 }
Example #21
0
  /**
   * Reads and parses HTTP request data from the channel until the parser state is complete, then
   * dispatches the exchange to the connection's root handler.
   *
   * <p>If the connection holds leftover bytes from a previous read ({@code getExtraBytes()}), that
   * buffer is parsed first instead of reading from the channel. Any bytes still unparsed after a
   * parser pass are stored back on the connection, in which case the pooled buffer is not freed
   * here.
   *
   * <p>The method returns early when no data is available (re-registering itself as the read
   * listener if reads are not resumed), on end of stream (shutting down both directions), on a
   * read error, or when the accumulated byte count exceeds {@code maxRequestSize}.
   *
   * @param channel the source channel that became readable
   */
  public void handleEvent(final StreamSourceChannel channel) {
    Pooled<ByteBuffer> leftover = connection.getExtraBytes();

    final Pooled<ByteBuffer> pooled;
    if (leftover != null) {
      pooled = leftover;
    } else {
      pooled = connection.getBufferPool().allocate();
    }
    final ByteBuffer buffer = pooled.getResource();
    // cleared once ownership of the buffer is handed back to the connection
    boolean releaseBuffer = true;

    try {
      for (; ; ) {
        final int count;
        if (leftover != null) {
          // data carried over from an earlier read; nothing to fetch from the channel
          count = buffer.remaining();
        } else {
          buffer.clear();
          try {
            count = channel.read(buffer);
          } catch (IOException e) {
            UndertowLogger.REQUEST_IO_LOGGER.debug("Error reading request", e);
            IoUtils.safeClose(connection);
            return;
          }
        }

        if (count == 0) {
          // nothing available: make sure we get called again when there is
          if (!channel.isReadResumed()) {
            channel.getReadSetter().set(this);
            channel.resumeReads();
          }
          return;
        }
        if (count == -1) {
          // end of stream: stop reading and flush out whatever response is pending
          try {
            channel.suspendReads();
            channel.shutdownReads();
            final StreamSinkChannel sink = this.connection.getChannel().getSinkChannel();
            sink.shutdownWrites();
            // flush() returns false if a response is still queued ahead of this one, so a
            // listener is installed to finish the flush later
            if (!sink.flush()) {
              sink.getWriteSetter().set(ChannelListeners.flushingChannelListener(null, null));
              sink.resumeWrites();
            }
          } catch (IOException e) {
            UndertowLogger.REQUEST_IO_LOGGER.debug("Error reading request", e);
            // the shutdown failed; nothing left to salvage, so close the channel
            IoUtils.safeClose(channel);
          }
          return;
        }

        if (leftover == null) {
          buffer.flip();
        } else {
          leftover = null;
          connection.setExtraBytes(null);
        }
        parser.handle(buffer, state, httpServerExchange);
        if (buffer.hasRemaining()) {
          // unparsed bytes remain: keep the buffer on the connection for the next consumer
          releaseBuffer = false;
          connection.setExtraBytes(pooled);
        }
        read += count;
        if (read > maxRequestSize) {
          UndertowLogger.REQUEST_LOGGER.requestHeaderWasTooLarge(
              connection.getPeerAddress(), maxRequestSize);
          IoUtils.safeClose(connection);
          return;
        }
        if (state.isComplete()) {
          break;
        }
      }

      // we remove ourselves as the read listener from the channel;
      // if the http handler doesn't set any then reads will suspend, which is the right thing to do
      channel.getReadSetter().set(null);
      channel.suspendReads();

      final HttpServerExchange exchange = this.httpServerExchange;
      exchange.putAttachment(UndertowOptions.ATTACHMENT_KEY, connection.getUndertowOptions());
      final String scheme;
      if (connection.getSslSession() == null) {
        scheme = "http";
      } else {
        scheme = "https";
      }
      exchange.setRequestScheme(scheme);
      this.httpServerExchange = null;
      HttpTransferEncoding.setupRequest(exchange);
      final boolean inIoThread = Thread.currentThread() instanceof XnioExecutor;
      HttpHandlers.executeRootHandler(connection.getRootHandler(), exchange, inIoThread);
    } catch (Exception e) {
      sendBadRequestAndClose(connection.getChannel(), e);
      return;
    } finally {
      if (releaseBuffer) {
        pooled.free();
      }
    }
  }
  /**
   * Receives and dispatches protocol messages from the channel until no more are ready.
   *
   * <p>Each iteration receives one message into a pooled buffer, optionally unwraps it through the
   * connection's SASL wrapper, reads the first byte as the protocol ID, and dispatches on it.
   * Cases that {@code return} end processing of this event; cases that {@code break} fall out of
   * the switch, the buffer is cleared, and the loop goes on to receive the next message.
   *
   * <p>Channel IDs are read from the wire as an int XOR-ed with {@code 0x80000000}.
   *
   * <p>A {@link BufferUnderflowException} while parsing is logged and the loop continues. An
   * {@link IOException} is reported to the connection and followed by {@code
   * handler.handleConnectionClose()}. The pooled buffer held in {@code pooled} when the method
   * exits is always freed.
   *
   * @param channel the message channel that became readable
   */
  public void handleEvent(final ConnectedMessageChannel channel) {
    int res;
    SaslWrapper saslWrapper = connection.getSaslWrapper();
    try {
      Pooled<ByteBuffer> pooled = connection.allocate();
      ByteBuffer buffer = pooled.getResource();
      try {
        for (; ; )
          try {
            res = channel.receive(buffer);
            if (res == -1) {
              log.trace("Received connection end-of-stream");
              try {
                channel.shutdownReads();
              } finally {
                // notify the handler even if shutting down reads throws
                handler.handleConnectionClose();
              }
              return;
            } else if (res == 0) {
              log.trace("No message ready; returning");
              return;
            }
            buffer.flip();
            if (saslWrapper != null) {
              // unwrap back into the same buffer: the duplicate keeps the received
              // position/limit as the source while the original is reset as the target
              final ByteBuffer source = buffer.duplicate();
              buffer.clear();
              saslWrapper.unwrap(buffer, source);
              buffer.flip();
            }
            final byte protoId = buffer.get();
            try {
              switch (protoId) {
                case Protocol.CONNECTION_ALIVE:
                  {
                    log.trace("Received connection alive");
                    connection.sendAliveResponse();
                    return;
                  }
                case Protocol.CONNECTION_ALIVE_ACK:
                  {
                    log.trace("Received connection alive ack");
                    return;
                  }
                case Protocol.CONNECTION_CLOSE:
                  {
                    log.trace("Received connection close request");
                    handler.receiveCloseRequest();
                    return;
                  }
                case Protocol.CHANNEL_OPEN_REQUEST:
                  {
                    log.trace("Received channel open request");
                    int channelId = buffer.getInt() ^ 0x80000000;
                    // start from the most permissive limits; they are narrowed below by the
                    // peer's request and then by the registered service's options
                    int inboundWindow = Integer.MAX_VALUE;
                    int inboundMessages = 0xffff;
                    int outboundWindow = Integer.MAX_VALUE;
                    int outboundMessages = 0xffff;
                    long inboundMessageSize = Long.MAX_VALUE;
                    long outboundMessageSize = Long.MAX_VALUE;
                    // parse out request
                    int b;
                    String serviceType = null;
                    OUT:
                    for (; ; ) {
                      b = buffer.get() & 0xff;
                      switch (b) {
                        case Protocol.O_END:
                          break OUT;
                        case Protocol.O_SERVICE_NAME:
                          {
                            serviceType = ProtocolUtils.readString(buffer);
                            break;
                          }
                          // the peer's "inbound" limits bound our outbound values, and its
                          // "outbound" limits bound our inbound values
                        case Protocol.O_MAX_INBOUND_MSG_WINDOW_SIZE:
                          {
                            outboundWindow =
                                Math.min(outboundWindow, ProtocolUtils.readInt(buffer));
                            break;
                          }
                        case Protocol.O_MAX_INBOUND_MSG_COUNT:
                          {
                            outboundMessages =
                                Math.min(outboundMessages, ProtocolUtils.readUnsignedShort(buffer));
                            break;
                          }
                        case Protocol.O_MAX_OUTBOUND_MSG_WINDOW_SIZE:
                          {
                            inboundWindow = Math.min(inboundWindow, ProtocolUtils.readInt(buffer));
                            break;
                          }
                        case Protocol.O_MAX_OUTBOUND_MSG_COUNT:
                          {
                            inboundMessages =
                                Math.min(inboundMessages, ProtocolUtils.readUnsignedShort(buffer));
                            break;
                          }
                        case Protocol.O_MAX_INBOUND_MSG_SIZE:
                          {
                            outboundMessageSize =
                                Math.min(outboundMessageSize, ProtocolUtils.readLong(buffer));
                            break;
                          }
                        case Protocol.O_MAX_OUTBOUND_MSG_SIZE:
                          {
                            inboundMessageSize =
                                Math.min(inboundMessageSize, ProtocolUtils.readLong(buffer));
                            break;
                          }
                        default:
                          {
                            // unknown parameter: the next byte is its length, skip that many bytes
                            Buffers.skip(buffer, buffer.get() & 0xff);
                            break;
                          }
                      }
                    }
                    if ((channelId & 0x80000000) != 0) {
                      // invalid channel ID, original should have had MSB=1 and thus the complement
                      // should be MSB=0
                      refuseService(channelId, "Invalid channel ID");
                      break;
                    }

                    if (serviceType == null) {
                      // invalid service reply
                      refuseService(channelId, "Missing service name");
                      break;
                    }

                    final RegisteredService registeredService =
                        handler.getConnectionContext().getRegisteredService(serviceType);
                    if (registeredService == null) {
                      refuseService(channelId, "Unknown service name");
                      break;
                    }
                    // narrow the requested limits further by the service's own configuration
                    final OptionMap serviceOptionMap = registeredService.getOptionMap();
                    outboundWindow =
                        Math.min(
                            outboundWindow,
                            serviceOptionMap.get(
                                RemotingOptions.TRANSMIT_WINDOW_SIZE,
                                Protocol.DEFAULT_WINDOW_SIZE));
                    outboundMessages =
                        Math.min(
                            outboundMessages,
                            serviceOptionMap.get(
                                RemotingOptions.MAX_OUTBOUND_MESSAGES,
                                Protocol.DEFAULT_MESSAGE_COUNT));
                    inboundWindow =
                        Math.min(
                            inboundWindow,
                            serviceOptionMap.get(
                                RemotingOptions.RECEIVE_WINDOW_SIZE, Protocol.DEFAULT_WINDOW_SIZE));
                    inboundMessages =
                        Math.min(
                            inboundMessages,
                            serviceOptionMap.get(
                                RemotingOptions.MAX_INBOUND_MESSAGES,
                                Protocol.DEFAULT_MESSAGE_COUNT));
                    outboundMessageSize =
                        Math.min(
                            outboundMessageSize,
                            serviceOptionMap.get(
                                RemotingOptions.MAX_OUTBOUND_MESSAGE_SIZE, Long.MAX_VALUE));
                    inboundMessageSize =
                        Math.min(
                            inboundMessageSize,
                            serviceOptionMap.get(
                                RemotingOptions.MAX_INBOUND_MESSAGE_SIZE, Long.MAX_VALUE));

                    final OpenListener openListener = registeredService.getOpenListener();
                    if (!handler.handleInboundChannelOpen()) {
                      // refuse
                      refuseService(channelId, "Channel refused");
                      break;
                    }
                    // ok1 stays false until the open has fully succeeded; the finally block
                    // below uses it to undo handleInboundChannelOpen() on any failure path
                    boolean ok1 = false;
                    try {
                      // construct the channel
                      RemoteConnectionChannel connectionChannel =
                          new RemoteConnectionChannel(
                              handler,
                              connection,
                              channelId,
                              outboundWindow,
                              inboundWindow,
                              outboundMessages,
                              inboundMessages,
                              outboundMessageSize,
                              inboundMessageSize);
                      RemoteConnectionChannel existing = handler.addChannel(connectionChannel);
                      if (existing != null) {
                        log.tracef("Encountered open request for duplicate %s", existing);
                        // the channel already exists, which means the remote side "forgot" about it
                        // or we somehow missed the close message.
                        // the only safe thing to do is to terminate the existing channel.
                        try {
                          refuseService(channelId, "Duplicate ID");
                        } finally {
                          existing.handleRemoteClose();
                        }
                        break;
                      }

                      // construct reply
                      Pooled<ByteBuffer> pooledReply = connection.allocate();
                      // ok2 marks the point after which the reply buffer is handed to send()
                      boolean ok2 = false;
                      try {
                        ByteBuffer replyBuffer = pooledReply.getResource();
                        replyBuffer.clear();
                        replyBuffer.put(Protocol.CHANNEL_OPEN_ACK);
                        replyBuffer.putInt(channelId);
                        ProtocolUtils.writeInt(
                            replyBuffer, Protocol.O_MAX_INBOUND_MSG_WINDOW_SIZE, inboundWindow);
                        ProtocolUtils.writeShort(
                            replyBuffer, Protocol.O_MAX_INBOUND_MSG_COUNT, inboundMessages);
                        // message size limits are only sent when they differ from "unlimited"
                        if (inboundMessageSize != Long.MAX_VALUE) {
                          ProtocolUtils.writeLong(
                              replyBuffer, Protocol.O_MAX_INBOUND_MSG_SIZE, inboundMessageSize);
                        }
                        ProtocolUtils.writeInt(
                            replyBuffer, Protocol.O_MAX_OUTBOUND_MSG_WINDOW_SIZE, outboundWindow);
                        ProtocolUtils.writeShort(
                            replyBuffer, Protocol.O_MAX_OUTBOUND_MSG_COUNT, outboundMessages);
                        if (outboundMessageSize != Long.MAX_VALUE) {
                          ProtocolUtils.writeLong(
                              replyBuffer, Protocol.O_MAX_OUTBOUND_MSG_SIZE, outboundMessageSize);
                        }
                        // terminating zero byte ends the parameter list
                        replyBuffer.put((byte) 0);
                        replyBuffer.flip();
                        ok2 = true;
                        // send takes ownership of the buffer
                        connection.send(pooledReply);
                      } finally {
                        if (!ok2) pooledReply.free();
                      }

                      ok1 = true;

                      // Call the service open listener
                      connection
                          .getExecutor()
                          .execute(SpiUtils.getServiceOpenTask(connectionChannel, openListener));
                      break;
                    } finally {
                      // the inbound channel wasn't open so don't leak the ref count
                      if (!ok1) handler.handleInboundChannelClosed();
                    }
                  }
                case Protocol.MESSAGE_DATA:
                  {
                    log.trace("Received message data");
                    int channelId = buffer.getInt() ^ 0x80000000;
                    RemoteConnectionChannel connectionChannel = handler.getChannel(channelId);
                    if (connectionChannel == null) {
                      // ignore the data
                      log.tracef("Ignoring message data for expired channel");
                      break;
                    }
                    // NOTE(review): presumably handleMessageData takes ownership of the pooled
                    // buffer, which is why a replacement is allocated below - confirm
                    connectionChannel.handleMessageData(pooled);
                    // need a new buffer now
                    pooled = connection.allocate();
                    buffer = pooled.getResource();
                    break;
                  }
                case Protocol.MESSAGE_WINDOW_OPEN:
                  {
                    log.trace("Received message window open");
                    int channelId = buffer.getInt() ^ 0x80000000;
                    RemoteConnectionChannel connectionChannel = handler.getChannel(channelId);
                    if (connectionChannel == null) {
                      // ignore
                      log.tracef("Ignoring window open for expired channel");
                      break;
                    }
                    // NOTE(review): unlike MESSAGE_DATA, the same pooled buffer keeps being used
                    // after this call; presumably the callee does not retain it - confirm
                    connectionChannel.handleWindowOpen(pooled);
                    break;
                  }
                case Protocol.MESSAGE_CLOSE:
                  {
                    log.trace("Received message async close");
                    int channelId = buffer.getInt() ^ 0x80000000;
                    RemoteConnectionChannel connectionChannel = handler.getChannel(channelId);
                    if (connectionChannel == null) {
                      // no such channel; nothing to close
                      break;
                    }
                    connectionChannel.handleAsyncClose(pooled);
                    break;
                  }
                case Protocol.CHANNEL_CLOSED:
                  {
                    log.trace("Received channel closed");
                    int channelId = buffer.getInt() ^ 0x80000000;
                    RemoteConnectionChannel connectionChannel = handler.getChannel(channelId);
                    if (connectionChannel == null) {
                      // no such channel; nothing to close
                      break;
                    }
                    connectionChannel.handleRemoteClose();
                    break;
                  }
                case Protocol.CHANNEL_SHUTDOWN_WRITE:
                  {
                    log.trace("Received channel shutdown write");
                    int channelId = buffer.getInt() ^ 0x80000000;
                    RemoteConnectionChannel connectionChannel = handler.getChannel(channelId);
                    if (connectionChannel == null) {
                      // no such channel; nothing to shut down
                      break;
                    }
                    connectionChannel.handleIncomingWriteShutdown();
                    break;
                  }
                case Protocol.CHANNEL_OPEN_ACK:
                  {
                    log.trace("Received channel open ack");
                    int channelId = buffer.getInt() ^ 0x80000000;
                    if ((channelId & 0x80000000) == 0) {
                      // invalid
                      break;
                    }
                    PendingChannel pendingChannel = handler.removePendingChannel(channelId);
                    if (pendingChannel == null) {
                      // invalid
                      break;
                    }
                    // start from the limits we requested and narrow them by the peer's reply
                    int outboundWindow = pendingChannel.getOutboundWindowSize();
                    int inboundWindow = pendingChannel.getInboundWindowSize();
                    int outboundMessageCount = pendingChannel.getOutboundMessageCount();
                    int inboundMessageCount = pendingChannel.getInboundMessageCount();
                    long outboundMessageSize = pendingChannel.getOutboundMessageSize();
                    long inboundMessageSize = pendingChannel.getInboundMessageSize();
                    OUT:
                    for (; ; ) {
                      switch (buffer.get() & 0xff) {
                        case Protocol.O_MAX_INBOUND_MSG_WINDOW_SIZE:
                          {
                            outboundWindow =
                                Math.min(outboundWindow, ProtocolUtils.readInt(buffer));
                            break;
                          }
                        case Protocol.O_MAX_INBOUND_MSG_COUNT:
                          {
                            outboundMessageCount =
                                Math.min(
                                    outboundMessageCount, ProtocolUtils.readUnsignedShort(buffer));
                            break;
                          }
                        case Protocol.O_MAX_OUTBOUND_MSG_WINDOW_SIZE:
                          {
                            inboundWindow = Math.min(inboundWindow, ProtocolUtils.readInt(buffer));
                            break;
                          }
                        case Protocol.O_MAX_OUTBOUND_MSG_COUNT:
                          {
                            inboundMessageCount =
                                Math.min(
                                    inboundMessageCount, ProtocolUtils.readUnsignedShort(buffer));
                            break;
                          }
                        case Protocol.O_MAX_INBOUND_MSG_SIZE:
                          {
                            outboundMessageSize =
                                Math.min(outboundMessageSize, ProtocolUtils.readLong(buffer));
                            break;
                          }
                        case Protocol.O_MAX_OUTBOUND_MSG_SIZE:
                          {
                            inboundMessageSize =
                                Math.min(inboundMessageSize, ProtocolUtils.readLong(buffer));
                            break;
                          }
                        case Protocol.O_END:
                          {
                            break OUT;
                          }
                        default:
                          {
                            // ignore unknown parameter
                            Buffers.skip(buffer, buffer.get() & 0xff);
                            break;
                          }
                      }
                    }
                    RemoteConnectionChannel newChannel =
                        new RemoteConnectionChannel(
                            handler,
                            connection,
                            channelId,
                            outboundWindow,
                            inboundWindow,
                            outboundMessageCount,
                            inboundMessageCount,
                            outboundMessageSize,
                            inboundMessageSize);
                    handler.putChannel(newChannel);
                    // complete the pending open with the newly registered channel
                    pendingChannel.getResult().setResult(newChannel);
                    break;
                  }
                case Protocol.SERVICE_ERROR:
                  {
                    log.trace("Received service error");
                    int channelId = buffer.getInt() ^ 0x80000000;
                    PendingChannel pendingChannel = handler.removePendingChannel(channelId);
                    if (pendingChannel == null) {
                      // invalid
                      break;
                    }
                    // the rest of the message is the UTF-8 encoded failure reason
                    String reason = new String(Buffers.take(buffer), Protocol.UTF_8);
                    pendingChannel.getResult().setException(new IOException(reason));
                    break;
                  }
                default:
                  {
                    log.unknownProtocolId(protoId);
                    break;
                  }
              }
            } catch (BufferUnderflowException e) {
              // message body was shorter than its protocol ID requires
              log.bufferUnderflow(protoId);
            }
          } catch (BufferUnderflowException e) {
            // underflow outside the dispatch switch, e.g. no protocol ID byte to read
            log.bufferUnderflowRaw();
          } finally {
            // reset the current buffer for the next receive
            buffer.clear();
          }
      } finally {
        pooled.free();
      }
    } catch (IOException e) {
      connection.handleException(e);
      handler.handleConnectionClose();
    }
  }
    public void handleEvent(final ConnectedMessageChannel channel) {
      final Pooled<ByteBuffer> pooledBuffer = connection.allocate();
      boolean free = true;
      try {
        final ByteBuffer buffer = pooledBuffer.getResource();
        synchronized (connection.getLock()) {
          final int res;
          try {
            res = channel.receive(buffer);
          } catch (IOException e) {
            connection.handleException(e);
            saslDispose(saslClient);
            return;
          }
          if (res == 0) {
            return;
          }
          if (res == -1) {
            connection.handleException(client.abruptClose(connection));
            saslDispose(saslClient);
            return;
          }
        }
        buffer.flip();
        final byte msgType = buffer.get();
        switch (msgType) {
          case Protocol.CONNECTION_ALIVE:
            {
              client.trace("Client received connection alive");
              connection.sendAliveResponse();
              return;
            }
          case Protocol.CONNECTION_ALIVE_ACK:
            {
              client.trace("Client received connection alive ack");
              return;
            }
          case Protocol.CONNECTION_CLOSE:
            {
              client.trace("Client received connection close request");
              connection.handlePreAuthCloseRequest();
              saslDispose(saslClient);
              return;
            }
          case Protocol.AUTH_CHALLENGE:
            {
              client.trace("Client received authentication challenge");
              channel.suspendReads();
              connection
                  .getExecutor()
                  .execute(
                      () -> {
                        try {
                          final boolean clientComplete = saslClient.isComplete();
                          if (clientComplete) {
                            connection.handleException(
                                new SaslException("Received extra auth message after completion"));
                            return;
                          }
                          final byte[] response;
                          final byte[] challenge = Buffers.take(buffer, buffer.remaining());
                          try {
                            response = saslClient.evaluateChallenge(challenge);
                          } catch (Throwable e) {
                            final String mechanismName = saslClient.getMechanismName();
                            client.debugf(
                                "Client authentication failed for mechanism %s: %s",
                                mechanismName, e);
                            failedMechs.put(mechanismName, e.toString());
                            saslDispose(saslClient);
                            sendCapRequest(serverName);
                            return;
                          }
                          client.trace("Client sending authentication response");
                          final Pooled<ByteBuffer> pooled = connection.allocate();
                          boolean ok = false;
                          try {
                            final ByteBuffer sendBuffer = pooled.getResource();
                            sendBuffer.put(Protocol.AUTH_RESPONSE);
                            sendBuffer.put(response);
                            sendBuffer.flip();
                            connection.send(pooled);
                            ok = true;
                            channel.resumeReads();
                          } finally {
                            if (!ok) pooled.free();
                          }
                          return;
                        } finally {
                          pooledBuffer.free();
                        }
                      });
              free = false;
              return;
            }
          case Protocol.AUTH_COMPLETE:
            {
              client.trace("Client received authentication complete");
              channel.suspendReads();
              connection
                  .getExecutor()
                  .execute(
                      () -> {
                        try {
                          final boolean clientComplete = saslClient.isComplete();
                          final byte[] challenge = Buffers.take(buffer, buffer.remaining());
                          if (!clientComplete)
                            try {
                              final byte[] response = saslClient.evaluateChallenge(challenge);
                              if (response != null && response.length > 0) {
                                connection.handleException(
                                    new SaslException(
                                        "Received extra auth message after completion"));
                                saslDispose(saslClient);
                                return;
                              }
                              if (!saslClient.isComplete()) {
                                connection.handleException(
                                    new SaslException(
                                        "Client not complete after processing auth complete message"));
                                saslDispose(saslClient);
                                return;
                              }
                            } catch (Throwable e) {
                              final String mechanismName = saslClient.getMechanismName();
                              client.debugf(
                                  "Client authentication failed for mechanism %s: %s",
                                  mechanismName, e);
                              failedMechs.put(mechanismName, e.toString());
                              saslDispose(saslClient);
                              sendCapRequest(serverName);
                              return;
                            }
                          final Object qop = saslClient.getNegotiatedProperty(Sasl.QOP);
                          if ("auth-int".equals(qop) || "auth-conf".equals(qop)) {
                            connection.setSaslWrapper(SaslWrapper.create(saslClient));
                          }
                          // auth complete.
                          final ConnectionHandlerFactory connectionHandlerFactory =
                              connectionContext -> {

                                // this happens immediately.
                                final RemoteConnectionHandler connectionHandler =
                                    new RemoteConnectionHandler(
                                        connectionContext,
                                        connection,
                                        maxInboundChannels,
                                        maxOutboundChannels,
                                        remoteEndpointName,
                                        behavior);
                                connection.setReadListener(
                                    new RemoteReadListener(connectionHandler, connection), false);
                                connection
                                    .getRemoteConnectionProvider()
                                    .addConnectionHandler(connectionHandler);
                                return connectionHandler;
                              };
                          connection.getResult().setResult(connectionHandlerFactory);
                          channel.resumeReads();
                          return;
                        } finally {
                          pooledBuffer.free();
                        }
                      });
              free = false;
              return;
            }
          case Protocol.AUTH_REJECTED:
            {
              final String mechanismName = saslClient.getMechanismName();
              client.debugf(
                  "Client received authentication rejected for mechanism %s", mechanismName);
              failedMechs.put(mechanismName, "Server rejected authentication");
              saslDispose(saslClient);
              sendCapRequest(serverName);
              return;
            }
          default:
            {
              client.unknownProtocolId(msgType);
              connection.handleException(client.invalidMessage(connection));
              saslDispose(saslClient);
              return;
            }
        }
      } finally {
        if (free) pooledBuffer.free();
      }
    }
 /** Frees every pooled buffer held in {@code buffers}. */
 @Override
 public void done() {
   // Wildcard instead of the raw Pooled type: free() does not depend on the resource type.
   for (Pooled<?> pooledBuffer : buffers) {
     pooledBuffer.free();
   }
 }
 /**
  * Frees {@code buffer} when a failure is reported; the exception itself is not used.
  *
  * @param e the reported I/O failure (ignored by this implementation)
  */
 @Override
 public void failed(IOException e) {
   buffer.free();
 }
 /** Calls {@code discard()} on each pooled buffer in {@code pooled}, in iteration order. */
 @Override
 public void discard() {
   for (final Pooled<ByteBuffer> entry : pooled) {
     entry.discard();
   }
 }
  /**
   * Reads from {@code channel} into the current pooled buffer, allocating a buffer from the
   * channel's buffer pool whenever {@code current} is {@code null}.
   *
   * <p>When a read returns {@code -1} the message is flagged complete and {@code
   * callback.complete} is invoked. When a read returns {@code 0} a read listener that repeats the
   * same loop is installed on the channel, reads are resumed, and this method returns. Any {@link
   * IOException} is reported through {@code callback.onError} instead of being thrown.
   *
   * @param channel the frame channel to read from
   * @param callback receives the completion or error notification
   */
  public void read(
      final StreamSourceFrameChannel channel,
      final WebSocketCallback<BufferedBinaryMessage> callback) {
    try {
      while (true) {
        if (current == null) {
          current = channel.getWebSocketChannel().getBufferPool().allocate();
        }
        final int bytesRead = channel.read(current.getResource());

        if (bytesRead == -1) {
          // Read reported -1: flag the message complete and hand it to the callback.
          this.complete = true;
          callback.complete(channel.getWebSocketChannel(), this);
          return;
        }

        if (bytesRead == 0) {
          // Nothing was read: continue from a read listener, which runs the same loop.
          channel
              .getReadSetter()
              .set(
                  new ChannelListener<StreamSourceFrameChannel>() {
                    @Override
                    public void handleEvent(StreamSourceFrameChannel source) {
                      try {
                        while (true) {
                          if (current == null) {
                            current = source.getWebSocketChannel().getBufferPool().allocate();
                          }
                          final int count = source.read(current.getResource());

                          if (count == -1) {
                            complete = true;
                            source.suspendReads();
                            callback.complete(
                                source.getWebSocketChannel(), BufferedBinaryMessage.this);
                            return;
                          }
                          if (count == 0) {
                            // Wait for the next read event.
                            return;
                          }

                          checkMaxSize(source, count);
                          if (bufferFullMessage) {
                            dealWithFullBuffer(source);
                          } else if (current.getResource().hasRemaining()) {
                            handleNewFrame(source, callback);
                          } else {
                            callback.complete(
                                source.getWebSocketChannel(), BufferedBinaryMessage.this);
                          }
                        }
                      } catch (IOException e) {
                        source.suspendReads();
                        callback.onError(
                            source.getWebSocketChannel(), BufferedBinaryMessage.this, e);
                      }
                    }
                  });
          channel.resumeReads();
          return;
        }

        checkMaxSize(channel, bytesRead);
        if (bufferFullMessage) {
          dealWithFullBuffer(channel);
        } else if (current.getResource().hasRemaining()) {
          handleNewFrame(channel, callback);
        } else {
          callback.complete(channel.getWebSocketChannel(), BufferedBinaryMessage.this);
        }
      }
    } catch (IOException e) {
      callback.onError(channel.getWebSocketChannel(), this, e);
    }
  }
 /** Calls {@code free()} on each pooled buffer in {@code pooled}, in iteration order. */
 @Override
 public void free() {
   for (final Pooled<ByteBuffer> entry : pooled) {
     entry.free();
   }
 }
Example #29
0
 /**
  * Handles writing out the header data. It can also take a byte buffer of user data, to enable
  * both user data and headers to be written out in a single operation, which has a noticeable
  * performance impact.
  *
  * <p>It is up to the caller to note the current position of this buffer before and after they
  * call this method, and use this to figure out how many bytes (if any) have been written.
  *
  * <p>The method is a resumable state machine. Header bytes are staged in a pooled buffer and
  * pushed out with {@code next.write}; whenever a write reports zero bytes, the method returns the
  * state it should be re-entered with. Most cases deliberately fall through to the next case.
  *
  * @param state the state to start or resume from; {@code STATE_START} allocates a new pooled
  *     buffer and writes the request line
  * @param userData optional data written together with the final header bytes in one gather
  *     write; may be {@code null}, in which case only the header buffer is written
  * @return the state to pass to the next call, or {@code STATE_BODY} once the headers have been
  *     fully written and the pooled buffer has been freed
  * @throws java.io.IOException propagated from the underlying {@code next.write} calls
  * @throws IllegalStateException if {@code state} is not one of the states handled here
  */
 private int processWrite(int state, final ByteBuffer userData) throws IOException {
   if (state == STATE_START) {
     pooledBuffer = pool.allocate();
   }
   // Work on local copies of the resumable fields; they are written back to the fields on
   // (most of) the continuation returns below.
   ClientRequest request = this.request;
   ByteBuffer buffer = pooledBuffer.getResource();
   Iterator<HttpString> nameIterator = this.nameIterator;
   Iterator<String> valueIterator = this.valueIterator;
   int charIndex = this.charIndex;
   int length;
   String string = this.string;
   HttpString headerName = this.headerName;
   int res;
   // BUFFER IS FLIPPED COMING IN
   if (state != STATE_START && buffer.hasRemaining()) {
     log.trace("Flushing remaining buffer");
     do {
       res = next.write(buffer);
       if (res == 0) {
         // Still cannot make progress: stay in the same state.
         return state;
       }
     } while (buffer.hasRemaining());
   }
   buffer.clear();
   // BUFFER IS NOW EMPTY FOR FILLING
   for (; ; ) {
     switch (state) {
       case STATE_BODY:
         {
           // shouldn't be possible, but might as well do the right thing anyway
           return state;
         }
       case STATE_START:
         {
           log.trace("Starting request");
           // we assume that our buffer has enough space for the initial request line plus one more
           // CR+LF
           assert buffer.remaining() >= 50;
           // Request line: method SP path SP protocol CRLF
           request.getMethod().appendTo(buffer);
           buffer.put((byte) ' ');
           string = request.getPath();
           length = string.length();
           // Each char is narrowed to a single byte here.
           for (charIndex = 0; charIndex < length; charIndex++) {
             buffer.put((byte) string.charAt(charIndex));
           }
           buffer.put((byte) ' ');
           request.getProtocol().appendTo(buffer);
           buffer.put((byte) '\r').put((byte) '\n');
           HeaderMap headers = request.getRequestHeaders();
           nameIterator = headers.getHeaderNames().iterator();
           if (!nameIterator.hasNext()) {
             // No headers at all: terminate the header section immediately and flush.
             log.trace("No request headers");
             buffer.put((byte) '\r').put((byte) '\n');
             buffer.flip();
             while (buffer.hasRemaining()) {
               res = next.write(buffer);
               if (res == 0) {
                 log.trace("Continuation");
                 return STATE_BUF_FLUSH;
               }
             }
             pooledBuffer.free();
             pooledBuffer = null;
             log.trace("Body");
             return STATE_BODY;
           }
           headerName = nameIterator.next();
           charIndex = 0;
           // fall thru
         }
       case STATE_HDR_NAME:
         {
           // Copy the header name byte by byte, flushing whenever the buffer fills up.
           log.tracef("Processing header '%s'", headerName);
           length = headerName.length();
           while (charIndex < length) {
             if (buffer.hasRemaining()) {
               buffer.put(headerName.byteAt(charIndex++));
             } else {
               log.trace("Buffer flush");
               buffer.flip();
               do {
                 res = next.write(buffer);
                 if (res == 0) {
                   // Persist the local progress so the next call can resume mid-name.
                   this.string = string;
                   this.headerName = headerName;
                   this.charIndex = charIndex;
                   this.valueIterator = valueIterator;
                   this.nameIterator = nameIterator;
                   log.trace("Continuation");
                   return STATE_HDR_NAME;
                 }
               } while (buffer.hasRemaining());
               buffer.clear();
             }
           }
           // fall thru
         }
       case STATE_HDR_D:
         {
           // Delimiter: the ':' after the header name.
           if (!buffer.hasRemaining()) {
             buffer.flip();
             do {
               res = next.write(buffer);
               if (res == 0) {
                 log.trace("Continuation");
                 this.string = string;
                 this.headerName = headerName;
                 this.charIndex = charIndex;
                 this.valueIterator = valueIterator;
                 this.nameIterator = nameIterator;
                 return STATE_HDR_D;
               }
             } while (buffer.hasRemaining());
             buffer.clear();
           }
           buffer.put((byte) ':');
           // fall thru
         }
       case STATE_HDR_DS:
         {
           // Delimiter space: the ' ' after the ':', then select the next value to write.
           if (!buffer.hasRemaining()) {
             buffer.flip();
             do {
               res = next.write(buffer);
               if (res == 0) {
                 log.trace("Continuation");
                 this.string = string;
                 this.headerName = headerName;
                 this.charIndex = charIndex;
                 this.valueIterator = valueIterator;
                 this.nameIterator = nameIterator;
                 return STATE_HDR_DS;
               }
             } while (buffer.hasRemaining());
             buffer.clear();
           }
           buffer.put((byte) ' ');
           // A null valueIterator marks the first value of a new header name.
           if (valueIterator == null) {
             valueIterator = request.getRequestHeaders().get(headerName).iterator();
           }
           assert valueIterator.hasNext();
           string = valueIterator.next();
           charIndex = 0;
           // fall thru
         }
       case STATE_HDR_VAL:
         {
           // Copy the header value; each char is narrowed to a single byte.
           log.tracef("Processing header value '%s'", string);
           length = string.length();
           while (charIndex < length) {
             if (buffer.hasRemaining()) {
               buffer.put((byte) string.charAt(charIndex++));
             } else {
               buffer.flip();
               do {
                 res = next.write(buffer);
                 if (res == 0) {
                   this.string = string;
                   this.headerName = headerName;
                   this.charIndex = charIndex;
                   this.valueIterator = valueIterator;
                   this.nameIterator = nameIterator;
                   log.trace("Continuation");
                   return STATE_HDR_VAL;
                 }
               } while (buffer.hasRemaining());
               buffer.clear();
             }
           }
           charIndex = 0;
           // Last value of this header: write CRLF inline and either move on to the next header
           // name or finish the header section.
           // NOTE(review): the continuation returns inside this branch do not copy the locals
           // (nameIterator, valueIterator, headerName, string, charIndex) back to the fields,
           // unlike the returns in the cases above - confirm that resuming from the EOL/FINAL
           // states sees the intended values.
           if (!valueIterator.hasNext()) {
             if (!buffer.hasRemaining()) {
               buffer.flip();
               do {
                 res = next.write(buffer);
                 if (res == 0) {
                   log.trace("Continuation");
                   return STATE_HDR_EOL_CR;
                 }
               } while (buffer.hasRemaining());
               buffer.clear();
             }
             buffer.put((byte) 13); // CR
             if (!buffer.hasRemaining()) {
               buffer.flip();
               do {
                 res = next.write(buffer);
                 if (res == 0) {
                   log.trace("Continuation");
                   return STATE_HDR_EOL_LF;
                 }
               } while (buffer.hasRemaining());
               buffer.clear();
             }
             buffer.put((byte) 10); // LF
             if (nameIterator.hasNext()) {
               // More headers: restart at the name state for the next header.
               headerName = nameIterator.next();
               valueIterator = null;
               state = STATE_HDR_NAME;
               break;
             } else {
               // No more headers: write the blank line that ends the header section.
               if (!buffer.hasRemaining()) {
                 buffer.flip();
                 do {
                   res = next.write(buffer);
                   if (res == 0) {
                     log.trace("Continuation");
                     return STATE_HDR_FINAL_CR;
                   }
                 } while (buffer.hasRemaining());
                 buffer.clear();
               }
               buffer.put((byte) 13); // CR
               if (!buffer.hasRemaining()) {
                 buffer.flip();
                 do {
                   res = next.write(buffer);
                   if (res == 0) {
                     log.trace("Continuation");
                     return STATE_HDR_FINAL_LF;
                   }
                 } while (buffer.hasRemaining());
                 buffer.clear();
               }
               buffer.put((byte) 10); // LF
               this.nameIterator = null;
               this.valueIterator = null;
               this.string = null;
               buffer.flip();
               // for performance reasons we use a gather write if there is user data
               if (userData == null) {
                 do {
                   res = next.write(buffer);
                   if (res == 0) {
                     log.trace("Continuation");
                     return STATE_BUF_FLUSH;
                   }
                 } while (buffer.hasRemaining());
               } else {
                 ByteBuffer[] b = {buffer, userData};
                 do {
                   long r = next.write(b, 0, b.length);
                   // Only the header buffer has to be drained here; the loop ends as soon as
                   // it is empty, regardless of how much user data was written.
                   if (r == 0 && buffer.hasRemaining()) {
                     log.trace("Continuation");
                     return STATE_BUF_FLUSH;
                   }
                 } while (buffer.hasRemaining());
               }
               pooledBuffer.free();
               pooledBuffer = null;
               log.trace("Body");
               return STATE_BODY;
             }
             // not reached
           }
           // Reached only when the current header has further values.
           // fall thru
         }
         // Clean-up states
       case STATE_HDR_EOL_CR:
         {
           if (!buffer.hasRemaining()) {
             buffer.flip();
             do {
               res = next.write(buffer);
               if (res == 0) {
                 log.trace("Continuation");
                 return STATE_HDR_EOL_CR;
               }
             } while (buffer.hasRemaining());
             buffer.clear();
           }
           buffer.put((byte) 13); // CR
           // fall thru
         }
       case STATE_HDR_EOL_LF:
         {
           if (!buffer.hasRemaining()) {
             buffer.flip();
             do {
               res = next.write(buffer);
               if (res == 0) {
                 log.trace("Continuation");
                 return STATE_HDR_EOL_LF;
               }
             } while (buffer.hasRemaining());
             buffer.clear();
           }
           buffer.put((byte) 10); // LF
           if (valueIterator.hasNext()) {
             // Another value remains for the current header: go back to the name state so the
             // same header name is written again, followed by the next value.
             state = STATE_HDR_NAME;
             break;
           } else if (nameIterator.hasNext()) {
             headerName = nameIterator.next();
             valueIterator = null;
             state = STATE_HDR_NAME;
             break;
           }
           // fall thru
         }
       case STATE_HDR_FINAL_CR:
         {
           if (!buffer.hasRemaining()) {
             buffer.flip();
             do {
               res = next.write(buffer);
               if (res == 0) {
                 log.trace("Continuation");
                 return STATE_HDR_FINAL_CR;
               }
             } while (buffer.hasRemaining());
             buffer.clear();
           }
           buffer.put((byte) 13); // CR
           // fall thru
         }
       case STATE_HDR_FINAL_LF:
         {
           if (!buffer.hasRemaining()) {
             buffer.flip();
             do {
               res = next.write(buffer);
               if (res == 0) {
                 log.trace("Continuation");
                 return STATE_HDR_FINAL_LF;
               }
             } while (buffer.hasRemaining());
             buffer.clear();
           }
           buffer.put((byte) 10); // LF
           // Header section finished: drop the iteration state and flush what is left.
           this.nameIterator = null;
           this.valueIterator = null;
           this.string = null;
           buffer.flip();
           // for performance reasons we use a gather write if there is user data
           if (userData == null) {
             do {
               res = next.write(buffer);
               if (res == 0) {
                 log.trace("Continuation");
                 return STATE_BUF_FLUSH;
               }
             } while (buffer.hasRemaining());
           } else {
             ByteBuffer[] b = {buffer, userData};
             do {
               long r = next.write(b, 0, b.length);
               if (r == 0 && buffer.hasRemaining()) {
                 log.trace("Continuation");
                 return STATE_BUF_FLUSH;
               }
             } while (buffer.hasRemaining());
           }
           // fall thru
         }
       case STATE_BUF_FLUSH:
         {
           // buffer was successfully flushed above
           // (either by the preceding case, or by the flush at the top of this method when
           // re-entered with STATE_BUF_FLUSH)
           pooledBuffer.free();
           pooledBuffer = null;
           return STATE_BODY;
         }
       default:
         {
           throw new IllegalStateException();
         }
     }
   }
 }
 /** Frees {@code buffer}. */
 @Override
 public void done() {
   buffer.free();
 }