/**
 * Writes the header lines of a STOMP frame (everything between the command line and the blank
 * line that precedes the body).
 *
 * <p>Every value of every header is written as {@code key:value} followed by a line feed; keys
 * and values go through {@code getUtf8BytesEscapingIfNecessary}, which leaves CONNECT and
 * CONNECTED frames unescaped. For SEND, MESSAGE and ERROR frames a {@code content-length} header
 * computed from the actual payload is appended, and any {@code content-length} already present
 * in the header map is skipped so the frame never carries two (possibly conflicting) values.
 *
 * @param headers the accessor supplying the command and the STOMP headers to write
 * @param message the message whose payload length is used for {@code content-length}
 * @param output the stream the header lines are written to
 * @throws IOException if writing to {@code output} fails
 */
private void writeHeaders(
      StompHeaderAccessor headers, Message<byte[]> message, DataOutputStream output)
      throws IOException {

    StompCommand command = headers.getCommand();
    Map<String, List<String>> stompHeaders = headers.toStompHeaderMap();
    if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) {
      if (logger.isTraceEnabled()) {
        logger.trace("Encoded heartbeat");
      }
    } else if (logger.isDebugEnabled()) {
      logger.debug("Encoded STOMP command=" + command + " headers=" + stompHeaders);
    }

    // For these commands the content-length is always derived from the real payload below.
    boolean writeContentLength =
        (command == StompCommand.SEND)
            || (command == StompCommand.MESSAGE)
            || (command == StompCommand.ERROR);

    for (Entry<String, List<String>> entry : stompHeaders.entrySet()) {
      if (writeContentLength && "content-length".equals(entry.getKey())) {
        // A pre-existing value (e.g. left over from decoding) would duplicate, and possibly
        // contradict, the authoritative header written after this loop.
        continue;
      }
      byte[] key = getUtf8BytesEscapingIfNecessary(entry.getKey(), headers);
      for (String value : entry.getValue()) {
        output.write(key);
        output.write(COLON);
        output.write(getUtf8BytesEscapingIfNecessary(value, headers));
        output.write(LF);
      }
    }
    if (writeContentLength) {
      output.write("content-length:".getBytes(UTF8_CHARSET));
      output.write(Integer.toString(message.getPayload().length).getBytes(UTF8_CHARSET));
      output.write(LF);
    }
  }
    /**
     * Decodes the STOMP frames contained in a WebSocket text message and dispatches each one to
     * {@code stompMessageHandler} according to its command.
     *
     * <p>CONNECTED, MESSAGE, RECEIPT, ERROR and DISCONNECT frames are routed to the matching
     * handler callback; anything else is logged at debug level and otherwise ignored.
     *
     * @param session the WebSocket session the message arrived on
     * @param textMessage the text message whose payload holds one or more STOMP frames
     * @throws Exception if decoding or a handler callback fails
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage textMessage)
        throws Exception {

      ByteBuffer payload = ByteBuffer.wrap(textMessage.getPayload().getBytes(UTF_8));
      List<Message<byte[]>> messages = this.decoder.decode(payload);

      for (Message<byte[]> message : messages) {
        StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
        // Read once. It may be null (presumably for heartbeat frames), hence the
        // constant-first equals() comparisons instead of a switch.
        StompCommand command = headers.getCommand();

        if (StompCommand.CONNECTED.equals(command)) {
          this.stompMessageHandler.afterConnected(session, headers);
        } else if (StompCommand.MESSAGE.equals(command)) {
          this.stompMessageHandler.handleMessage(message);
        } else if (StompCommand.RECEIPT.equals(command)) {
          this.stompMessageHandler.handleReceipt(headers.getReceiptId());
        } else if (StompCommand.ERROR.equals(command)) {
          this.stompMessageHandler.handleError(message);
        } else if (StompCommand.DISCONNECT.equals(command)) {
          // This branch previously re-tested ERROR, which made it unreachable.
          this.stompMessageHandler.afterDisconnected();
        } else if (LOGGER.isDebugEnabled()) {
          LOGGER.debug("Unhandled message " + message);
        }
      }
    }
 /**
  * Encodes a header key or value to bytes using {@code UTF8_CHARSET}.
  *
  * <p>The text is passed through {@code escape} first, except for CONNECT and CONNECTED frames,
  * whose headers are encoded as-is.
  *
  * @param input the header key or value to encode
  * @param headers the accessor whose command decides whether escaping applies
  * @return the encoded bytes
  */
 private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) {
   StompCommand frameCommand = headers.getCommand();
   boolean leaveUnescaped =
       (frameCommand == StompCommand.CONNECT) || (frameCommand == StompCommand.CONNECTED);
   String text = leaveUnescaped ? input : escape(input);
   return text.getBytes(UTF8_CHARSET);
 }
  /**
   * Encodes the given payload and headers into a {@code byte[]}.
   *
   * <p>A heartbeat message is encoded as {@code StompDecoder.HEARTBEAT_PAYLOAD} only. Any other
   * message is written as a full STOMP frame: the command line, the header lines, a blank line,
   * the body and a terminating NUL byte.
   *
   * @param headers the message headers; must contain a STOMP command unless this is a heartbeat
   * @param payload the payload
   * @return the encoded message
   * @throws IllegalArgumentException if {@code headers} or {@code payload} is null, or if a
   *     non-heartbeat message has no STOMP command
   * @throws StompConversionException if writing the frame fails with an {@link IOException}
   */
  public byte[] encode(Map<String, Object> headers, byte[] payload) {
    Assert.notNull(headers, "'headers' is required");
    Assert.notNull(payload, "'payload' is required");
    try {
      // Presumably 128 bytes of headroom for the command and header lines on top of the body.
      ByteArrayOutputStream baos = new ByteArrayOutputStream(128 + payload.length);
      DataOutputStream output = new DataOutputStream(baos);

      if (SimpMessageType.HEARTBEAT.equals(SimpMessageHeaderAccessor.getMessageType(headers))) {
        if (logger.isTraceEnabled()) {
          logger.trace("Encoded heartbeat");
        }
        output.write(StompDecoder.HEARTBEAT_PAYLOAD);
      } else {
        StompCommand command = StompHeaderAccessor.getCommand(headers);
        if (command == null) {
          // Only stringify the (potentially large) header map on failure, not on every frame.
          throw new IllegalArgumentException("Missing STOMP command: " + headers);
        }
        output.write(command.toString().getBytes(StompDecoder.UTF8_CHARSET));
        output.write(LF);
        writeHeaders(command, headers, payload, output);
        output.write(LF);
        writeBody(payload, output);
        output.write((byte) 0);
      }

      return baos.toByteArray();
    } catch (IOException e) {
      throw new StompConversionException(
          "Failed to encode STOMP frame, headers=" + headers + ".", e);
    }
  }
  /**
   * A CONNECT frame received from the client should be forwarded once to the channel as a
   * CONNECT message carrying the session id, user, login, heartbeat and accept-version values,
   * with the passcode reported as "PROTECTED", and nothing should be sent back on the session.
   */
  @Test
  public void handleMessageFromClient() {
    TextMessage connectFrame =
        StompTextMessageBuilder.create(StompCommand.CONNECT)
            .headers(
                "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000")
            .build();
    this.protocolHandler.handleMessageFromClient(this.session, connectFrame, this.channel);

    verify(this.channel).send(this.messageCaptor.capture());
    Message<?> forwarded = this.messageCaptor.getValue();
    assertNotNull(forwarded);

    StompHeaderAccessor accessor = StompHeaderAccessor.wrap(forwarded);
    assertEquals(StompCommand.CONNECT, accessor.getCommand());
    assertEquals("s1", accessor.getSessionId());
    assertEquals("joe", accessor.getUser().getName());
    assertEquals("guest", accessor.getLogin());
    assertEquals("PROTECTED", accessor.getPasscode());
    assertArrayEquals(new long[] {10000, 10000}, accessor.getHeartbeat());
    assertEquals(new HashSet<>(Arrays.asList("1.1", "1.0")), accessor.getAcceptVersion());

    // The handler must not reply directly to the client for an inbound CONNECT.
    assertEquals(0, this.session.getSentMessages().size());
  }
  /**
   * A CONNECT_ACK sent towards the client should produce a single CONNECTED frame on the session
   * (version 1.1, heartbeat 0,0, user-name "joe") without touching the channel.
   */
  @Test
  public void handleMessageToClientConnectAck() {

    StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
    connectHeaders.setHeartbeat(10000, 10000);
    connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1");
    Message<?> connectMessage =
        MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build();

    SimpMessageHeaderAccessor connectAckHeaders =
        SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
    connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage);
    Message<byte[]> connectAckMessage =
        MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build();

    this.protocolHandler.handleMessageToClient(this.session, connectAckMessage);

    verifyNoMoreInteractions(this.channel);

    // Check CONNECTED reply

    assertEquals(1, this.session.getSentMessages().size());
    TextMessage textMessage = (TextMessage) this.session.getSentMessages().get(0);
    // Decode with an explicit UTF-8 charset rather than the platform default so the test does
    // not depend on the JVM's file.encoding.
    byte[] frameBytes =
        textMessage.getPayload().getBytes(java.nio.charset.StandardCharsets.UTF_8);
    Message<?> message = new StompDecoder().decode(ByteBuffer.wrap(frameBytes));
    StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message);

    assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand());
    assertEquals("1.1", replyHeaders.getVersion());
    assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat());
    assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0));
  }
 /**
  * Writes the frame's command line: the command name encoded with {@code UTF8_CHARSET},
  * followed by a line feed.
  *
  * @param headers the accessor supplying the command
  * @param output the stream to write to
  * @throws IOException if writing to {@code output} fails
  */
 private void writeCommand(StompHeaderAccessor headers, DataOutputStream output)
     throws IOException {
   String commandName = headers.getCommand().toString();
   byte[] commandBytes = commandName.getBytes(UTF8_CHARSET);
   output.write(commandBytes);
   output.write(LF);
 }
Example #8
0
  /**
   * Decode a single STOMP frame from the given {@code buffer} into a {@link Message}.
   *
   * <p>Leading EOLs are skipped and the buffer position is marked before the command is read.
   * An empty command is decoded as a heartbeat message. If nothing follows the command, or
   * {@code readPayload} returns {@code null} (logged as an incomplete frame), the buffer is reset
   * to the mark, any native headers read so far are copied into {@code headers} when it is
   * non-null, and {@code null} is returned.
   *
   * @param buffer the buffer to read from
   * @param headers optional map that receives the native headers of an incomplete frame; may be
   *     {@code null}
   * @return the decoded message, or {@code null} when no complete frame could be read
   * @throws StompConversionException if the frame has a non-empty body that its command does not
   *     allow
   */
  private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, String> headers) {

    Message<byte[]> decodedMessage = null;
    skipLeadingEol(buffer);
    buffer.mark();

    String command = readCommand(buffer);
    if (command.length() > 0) {

      StompHeaderAccessor headerAccessor = null;
      byte[] payload = null;

      if (buffer.remaining() > 0) {
        StompCommand stompCommand = StompCommand.valueOf(command);
        headerAccessor = StompHeaderAccessor.create(stompCommand);
        initHeaders(headerAccessor);

        readHeaders(buffer, headerAccessor);
        payload = readPayload(buffer, headerAccessor);
      }

      if (payload != null) {
        // headerAccessor is non-null here: payload is only assigned in the branch above.
        if ((payload.length > 0) && (!headerAccessor.getCommand().isBodyAllowed())) {
          // Report this frame's own headers; the 'headers' parameter is only the optional
          // output map for incomplete frames and is usually null at this point.
          throw new StompConversionException(
              headerAccessor.getCommand()
                  + " shouldn't have a payload: length="
                  + payload.length
                  + ", headers="
                  + headerAccessor.getHeader(NativeMessageHeaderAccessor.NATIVE_HEADERS));
        }
        headerAccessor.updateSimpMessageHeadersFromStompHeaders();
        headerAccessor.setLeaveMutable(true);
        decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
        if (logger.isDebugEnabled()) {
          logger.debug("Decoded " + decodedMessage);
        }
      } else {
        if (logger.isTraceEnabled()) {
          logger.trace("Received incomplete frame. Resetting buffer.");
        }
        // Hand back the headers parsed so far so the caller can inspect the partial frame.
        if (headers != null && headerAccessor != null) {
          String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
          @SuppressWarnings("unchecked")
          MultiValueMap<String, String> map =
              (MultiValueMap<String, String>) headerAccessor.getHeader(name);
          if (map != null) {
            headers.putAll(map);
          }
        }
        // Rewind to the mark so the whole frame can be re-read once more data arrives.
        buffer.reset();
      }
    } else {
      if (logger.isTraceEnabled()) {
        logger.trace("Decoded heartbeat");
      }
      StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
      initHeaders(headerAccessor);
      headerAccessor.setLeaveMutable(true);
      decodedMessage =
          MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
    }
    return decodedMessage;
  }