@Test
  public void handleMessageToClientConnectAck() {
    // The CONNECT frame originally sent by the client, with heartbeats and accepted versions.
    StompHeaderAccessor clientConnect = StompHeaderAccessor.create(StompCommand.CONNECT);
    clientConnect.setHeartbeat(10000, 10000);
    clientConnect.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1");
    Message<?> originalConnect =
        MessageBuilder.withPayload(new byte[0]).setHeaders(clientConnect).build();

    // A CONNECT_ACK that carries the original CONNECT frame as a header.
    SimpMessageHeaderAccessor ackAccessor =
        SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
    ackAccessor.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, originalConnect);
    Message<byte[]> ack = MessageBuilder.withPayload(new byte[0]).setHeaders(ackAccessor).build();

    this.protocolHandler.handleMessageToClient(this.session, ack);
    verifyNoMoreInteractions(this.channel);

    // Exactly one frame is written back to the client; decode it and check it is CONNECTED.
    assertEquals(1, this.session.getSentMessages().size());
    TextMessage sentFrame = (TextMessage) this.session.getSentMessages().get(0);
    byte[] sentBytes = sentFrame.getPayload().getBytes();
    Message<?> decoded = new StompDecoder().decode(ByteBuffer.wrap(sentBytes));
    StompHeaderAccessor connected = StompHeaderAccessor.wrap(decoded);

    assertEquals(StompCommand.CONNECTED, connected.getCommand());
    assertEquals("1.1", connected.getVersion());
    assertArrayEquals(new long[] {0, 0}, connected.getHeartbeat());
    assertEquals("joe", connected.getNativeHeader("user-name").get(0));
  }
  /**
   * Writes the STOMP header lines for {@code headers}. For SEND, MESSAGE and ERROR frames it
   * then writes a computed {@code content-length} line.
   *
   * <p>Each header value goes on its own {@code name:value} line. Names and values are escaped
   * unless the frame is CONNECT or CONNECTED.
   *
   * @param headers the accessor for the frame being encoded
   * @param message the message whose payload length is used for {@code content-length}
   * @param output the stream receiving the encoded header lines
   * @throws IOException if writing to {@code output} fails
   */
  private void writeHeaders(
      StompHeaderAccessor headers, Message<byte[]> message, DataOutputStream output)
      throws IOException {

    Map<String, List<String>> stompHeaders = headers.toStompHeaderMap();
    boolean heartbeat = SimpMessageType.HEARTBEAT.equals(headers.getMessageType());
    if (heartbeat) {
      logger.trace("Encoded heartbeat");
    } else if (logger.isDebugEnabled()) {
      logger.debug("Encoded STOMP command=" + headers.getCommand() + " headers=" + stompHeaders);
    }

    for (Entry<String, List<String>> header : stompHeaders.entrySet()) {
      byte[] nameBytes = getUtf8BytesEscapingIfNecessary(header.getKey(), headers);
      for (String headerValue : header.getValue()) {
        output.write(nameBytes);
        output.write(COLON);
        output.write(getUtf8BytesEscapingIfNecessary(headerValue, headers));
        output.write(LF);
      }
    }

    StompCommand command = headers.getCommand();
    boolean needsContentLength =
        command == StompCommand.SEND
            || command == StompCommand.MESSAGE
            || command == StompCommand.ERROR;
    if (needsContentLength) {
      String length = Integer.toString(message.getPayload().length);
      output.write("content-length:".getBytes(UTF8_CHARSET));
      output.write(length.getBytes(UTF8_CHARSET));
      output.write(LF);
    }
  }
 /**
  * Returns the UTF-8 bytes of a header name or value.
  *
  * <p>CONNECT and CONNECTED frames are written verbatim. Every other frame has the text passed
  * through {@code escape} first.
  *
  * @param input the header name or value
  * @param headers the accessor of the frame being encoded, used to read its command
  * @return the encoded bytes
  */
 private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) {
   StompCommand command = headers.getCommand();
   boolean connectFrame = (command == StompCommand.CONNECT || command == StompCommand.CONNECTED);
   String text = connectFrame ? input : escape(input);
   return text.getBytes(UTF8_CHARSET);
 }
  @Test
  public void handleMessageFromClient() {
    TextMessage connectFrame =
        StompTextMessageBuilder.create(StompCommand.CONNECT)
            .headers(
                "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000")
            .build();

    this.protocolHandler.handleMessageFromClient(this.session, connectFrame, this.channel);

    // The decoded CONNECT is sent to the channel once; capture it for inspection.
    verify(this.channel).send(this.messageCaptor.capture());
    Message<?> forwarded = this.messageCaptor.getValue();
    assertNotNull(forwarded);

    StompHeaderAccessor accessor = StompHeaderAccessor.wrap(forwarded);
    assertEquals(StompCommand.CONNECT, accessor.getCommand());
    assertEquals("s1", accessor.getSessionId());
    assertEquals("joe", accessor.getUser().getName());
    assertEquals("guest", accessor.getLogin());
    // The passcode is expected to be masked in the forwarded headers.
    assertEquals("PROTECTED", accessor.getPasscode());
    assertArrayEquals(new long[] {10000, 10000}, accessor.getHeartbeat());
    Set<String> expectedVersions = new HashSet<>(Arrays.asList("1.1", "1.0"));
    assertEquals(expectedVersions, accessor.getAcceptVersion());

    // Nothing is written back to the client.
    assertEquals(0, this.session.getSentMessages().size());
  }
    /**
     * Sends a STOMP CONNECT frame as soon as the WebSocket connection opens.
     *
     * <p>The frame accepts versions 1.1 and 1.2, disables heartbeats ({@code 0,0}), has an empty
     * body, and is encoded with {@code encoder} as a UTF-8 text message.
     *
     * @param session the newly opened WebSocket session
     * @throws Exception if sending the message fails
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
      StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
      connectHeaders.setAcceptVersion("1.1,1.2");
      connectHeaders.setHeartbeat(0, 0);
      Message<byte[]> connectFrame =
          MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build();

      byte[] encodedFrame = this.encoder.encode(connectFrame);
      session.sendMessage(new TextMessage(new String(encodedFrame, UTF_8)));
    }
 /**
  * Completes a tracker activity received on {@code /topic/activity} and broadcasts it to
  * {@code /topic/tracker}.
  *
  * <p>The activity is stamped with the principal's name, the STOMP session id, the client IP
  * address stored in the session attributes under {@code IP_ADDRESS}, and the current time.
  *
  * @param activityDTO the activity sent by the client; modified in place
  * @param stompHeaderAccessor headers of the incoming STOMP message
  * @param principal the authenticated user; assumed non-null (TODO confirm for anonymous users)
  * @return the same {@code activityDTO}, now completed
  */
 @SubscribeMapping("/topic/activity")
 @SendTo("/topic/tracker")
 public ActivityDTO sendActivity(
     @Payload ActivityDTO activityDTO,
     StompHeaderAccessor stompHeaderAccessor,
     Principal principal) {
   activityDTO.setUserLogin(principal.getName());
   activityDTO.setSessionId(stompHeaderAccessor.getSessionId());
   // Session attributes, or the IP entry within them, may be absent; record no IP instead of
   // failing with a NullPointerException.
   Object ipAddress =
       stompHeaderAccessor.getSessionAttributes() == null
           ? null
           : stompHeaderAccessor.getSessionAttributes().get(IP_ADDRESS);
   activityDTO.setIpAddress(ipAddress == null ? null : ipAddress.toString());
   activityDTO.setTime(dateTimeFormatter.print(System.currentTimeMillis()));
   log.debug("Sending user tracking data {}", activityDTO);
   return activityDTO;
 }
    /**
     * Decodes the STOMP frames contained in {@code textMessage} and dispatches each one to
     * {@code stompMessageHandler} according to its command.
     *
     * <p>CONNECTED, MESSAGE, RECEIPT, ERROR and DISCONNECT frames are dispatched. Any other frame
     * is logged at debug level.
     *
     * @param session the WebSocket session the message arrived on
     * @param textMessage the raw text message, converted to bytes as UTF-8 before decoding
     * @throws Exception if decoding or a handler callback fails
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage textMessage)
        throws Exception {

      ByteBuffer payload = ByteBuffer.wrap(textMessage.getPayload().getBytes(UTF_8));
      List<Message<byte[]>> messages = this.decoder.decode(payload);

      for (Message<byte[]> message : messages) {
        StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
        StompCommand command = headers.getCommand();

        // equals() on the constant keeps the checks null-safe for frames without a command.
        if (StompCommand.CONNECTED.equals(command)) {
          this.stompMessageHandler.afterConnected(session, headers);
        } else if (StompCommand.MESSAGE.equals(command)) {
          this.stompMessageHandler.handleMessage(message);
        } else if (StompCommand.RECEIPT.equals(command)) {
          this.stompMessageHandler.handleReceipt(headers.getReceiptId());
        } else if (StompCommand.ERROR.equals(command)) {
          this.stompMessageHandler.handleError(message);
        } else if (StompCommand.DISCONNECT.equals(command)) {
          // Previously this branch re-tested ERROR, which made it unreachable.
          this.stompMessageHandler.afterDisconnected();
        } else {
          LOGGER.debug("Unhandled message " + message);
        }
      }
    }
  /**
   * Serializes a STOMP frame, given as message headers plus a raw body, into a {@code byte[]}.
   *
   * <p>A heartbeat message is written as the bare heartbeat payload. Any other message is written
   * as the command line, the header lines, an empty line, the body, and a terminating NUL octet.
   *
   * @param headers the message headers; must not be {@code null}. Unless the message is a
   *     heartbeat, they must carry a STOMP command
   * @param payload the frame body; must not be {@code null}
   * @return the encoded frame
   * @throws StompConversionException if writing the frame fails with an {@link IOException}
   */
  public byte[] encode(Map<String, Object> headers, byte[] payload) {
    Assert.notNull(headers, "'headers' is required");
    Assert.notNull(payload, "'payload' is required");
    // Sized to the body plus 128 bytes of headroom for the command and header lines.
    ByteArrayOutputStream frameBuffer = new ByteArrayOutputStream(128 + payload.length);
    DataOutputStream frame = new DataOutputStream(frameBuffer);
    try {
      SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
      if (!SimpMessageType.HEARTBEAT.equals(messageType)) {
        StompCommand command = StompHeaderAccessor.getCommand(headers);
        Assert.notNull(command, "Missing STOMP command: " + headers);
        frame.write(command.toString().getBytes(StompDecoder.UTF8_CHARSET));
        frame.write(LF);
        writeHeaders(command, headers, payload, frame);
        frame.write(LF);
        writeBody(payload, frame);
        frame.write((byte) 0);
      } else {
        logger.trace("Encoded heartbeat");
        frame.write(StompDecoder.HEARTBEAT_PAYLOAD);
      }
    } catch (IOException e) {
      throw new StompConversionException(
          "Failed to encode STOMP frame, headers=" + headers + ".", e);
    }
    return frameBuffer.toByteArray();
  }
    /**
     * Encodes a STOMP message into a WebSocket frame suited to the given session type.
     *
     * <p>A binary frame is used only when all of these hold: the payload is non-empty, the
     * session is not a SockJS session, and the content type is compatible with
     * {@code application/octet-stream}. Otherwise a text frame is used.
     *
     * @param message the message to encode; must carry a {@code StompHeaderAccessor}
     * @param sessionType the concrete session class the frame will be sent on
     * @return a {@code BinaryMessage} or {@code TextMessage} containing the encoded frame
     */
    public WebSocketMessage<?> encode(
        Message<byte[]> message, Class<? extends WebSocketSession> sessionType) {
      StompHeaderAccessor stompAccessor =
          MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
      Assert.notNull(stompAccessor, "No StompHeaderAccessor available");
      byte[] body = message.getPayload();
      byte[] frameBytes = ENCODER.encode(stompAccessor.getMessageHeaders(), body);

      // Binary frames are never used for empty bodies or SockJS sessions.
      if (body.length == 0 || SockJsSession.class.isAssignableFrom(sessionType)) {
        return new TextMessage(frameBytes);
      }
      if (MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompAccessor.getContentType())) {
        return new BinaryMessage(frameBytes);
      }
      return new TextMessage(frameBytes);
    }
  /**
   * Reads the body of the current frame, which must end with a NUL octet.
   *
   * <p>If a non-negative {@code content-length} header is present, exactly that many bytes are
   * read and the byte after them must be the NUL terminator. Otherwise bytes are read up to the
   * first NUL. An unparseable {@code content-length} is logged and ignored.
   *
   * @param buffer the buffer positioned at the start of the body
   * @param headerAccessor the headers already parsed for this frame
   * @return the payload, or {@code null} if the buffer does not yet contain the whole body and
   *     its terminator
   * @throws StompConversionException if the byte after a {@code content-length} body is not NUL
   */
  private byte[] readPayload(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {

    Integer contentLength;
    try {
      contentLength = headerAccessor.getContentLength();
    } catch (NumberFormatException ex) {
      // Log the offending raw value (quoted), then fall back to scanning for the NUL terminator.
      logger.warn(
          "Ignoring invalid content-length: '"
              + headerAccessor.getFirstNativeHeader("content-length")
              + "'");
      contentLength = null;
    }

    if (contentLength != null && contentLength >= 0) {
      // The body must be fully available plus one more byte for the terminating NUL.
      if (buffer.remaining() > contentLength) {
        byte[] payload = new byte[contentLength];
        buffer.get(payload);
        if (buffer.get() != 0) {
          throw new StompConversionException("Frame must be terminated with a null octet");
        }
        return payload;
      }
      return null;
    }

    ByteArrayOutputStream payload = new ByteArrayOutputStream(256);
    while (buffer.remaining() > 0) {
      byte b = buffer.get();
      if (b == 0) {
        return payload.toByteArray();
      }
      payload.write(b);
    }
    // No terminator yet: the frame is incomplete.
    return null;
  }
 /**
  * Reads header lines into {@code headerAccessor} until an empty line or the end of the buffer.
  *
  * <p>Each line is split at its first colon, and the name and value are unescaped. The STOMP
  * header grammar allows the value to be empty. Malformed lines and invalid MIME types are
  * rejected only when more data follows in the buffer. At the end of the buffer they are treated
  * as a truncated, incomplete frame and skipped.
  *
  * @param buffer the buffer positioned at the first header line
  * @param headerAccessor receives the parsed native headers
  * @throws StompConversionException if a line that is followed by more data has no header name
  * @throws InvalidMimeTypeException if adding a header fails this way while more data follows
  */
 private void readHeaders(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
   while (true) {
     ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
     while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
       headerStream.write(buffer.get());
     }
     if (headerStream.size() == 0) {
       // An empty line (or an exhausted buffer) ends the header section.
       break;
     }
     String header = new String(headerStream.toByteArray(), UTF8_CHARSET);
     int colonIndex = header.indexOf(':');
     if (colonIndex <= 0) {
       // No colon, or an empty name. Only an error if the line is not a truncated tail.
       if (buffer.remaining() > 0) {
         throw new StompConversionException(
             "Illegal header: '" + header + "'. A header must be of the form <name>:[<value>].");
       }
     } else {
       String headerName = unescape(header.substring(0, colonIndex));
       // May be empty ("name:"), which the STOMP grammar permits.
       String headerValue = unescape(header.substring(colonIndex + 1));
       try {
         headerAccessor.addNativeHeader(headerName, headerValue);
       } catch (InvalidMimeTypeException ex) {
         if (buffer.remaining() > 0) {
           throw ex;
         }
       }
     }
   }
 }
 /**
  * Logs an exception raised while handling a STOMP frame for {@code session}.
  *
  * <p>The frame is rebuilt as a {@code Message<byte[]>} from {@code command}, {@code headers} and
  * {@code payload} so that the log entry shows the offending message.
  *
  * @param session the STOMP session the frame belongs to
  * @param command the frame's command
  * @param headers the frame's headers
  * @param payload the frame's body
  * @param exception the exception that was raised
  */
 @Override
 public void handleException(
     StompSession session,
     StompCommand command,
     StompHeaders headers,
     byte[] payload,
     Throwable exception) {
   StompHeaderAccessor accessor = StompHeaderAccessor.create(command, headers);
   Message<byte[]> failedMessage =
       MessageBuilder.createMessage(payload, accessor.getMessageHeaders());
   String description =
       "The exception for session [" + session + "] on message [" + failedMessage + "]";
   logger.error(description, exception);
 }
  @Test
  public void handleMessageToClientConnected() {
    UserSessionRegistry sessionRegistry = new DefaultUserSessionRegistry();
    this.protocolHandler.setUserSessionRegistry(sessionRegistry);

    StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
    Message<byte[]> connected =
        MessageBuilder.withPayload(new byte[0]).setHeaders(connectedHeaders).build();
    this.protocolHandler.handleMessageToClient(this.session, connected);

    // One CONNECTED frame is sent, with the user-name header added.
    assertEquals(1, this.session.getSentMessages().size());
    WebSocketMessage<?> sent = this.session.getSentMessages().get(0);
    String expectedFrame = "CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000";
    assertEquals(expectedFrame, sent.getPayload());

    // The session is also registered for that user.
    assertEquals(Collections.singleton("s1"), sessionRegistry.getSessionIds("joe"));
  }
// Example #14
// 0
  /**
   * Records diagram-history sessions when STOMP clients subscribe to or unsubscribe from a
   * destination, and broadcasts the resulting {@code HistoryModel}s to the diagram's history topic.
   *
   * <p>On subscribe, the diagram id is parsed from the last path segment of the subscription
   * destination. Destinations whose last segment is not a number are ignored. On unsubscribe,
   * every session returned by {@code HistoryService.updateSession} is published to its own
   * diagram's topic. Other event types are ignored.
   *
   * @param event the application event
   */
  @Override
  public void onApplicationEvent(ApplicationEvent event) {
    if (event instanceof SessionSubscribeEvent) {
      recordSubscription((SessionSubscribeEvent) event);
    } else if (event instanceof SessionUnsubscribeEvent) {
      recordUnsubscription((SessionUnsubscribeEvent) event);
    }
  }

  /** Opens a history session for the subscribed diagram and publishes it. */
  private void recordSubscription(SessionSubscribeEvent subscribeEvent) {
    StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(subscribeEvent.getMessage());
    String destination = headerAccessor.getDestination();
    if (destination == null || subscribeEvent.getUser() == null) {
      // Nothing can be recorded without a destination and a user name.
      return;
    }
    Long diagramId;
    try {
      diagramId = Long.parseLong(destination.substring(destination.lastIndexOf('/') + 1));
    } catch (NumberFormatException ex) {
      // The last segment is not numeric, so this is not a diagram destination; ignore it.
      return;
    }
    String userName = subscribeEvent.getUser().getName();
    HistorySession session =
        new HistoryService().insertSession(userName, diagramId, headerAccessor.getSubscriptionId());
    publishHistory("/topic/diagram/" + diagramId + "/history", userName, session);
  }

  /** Closes the history sessions tied to the cancelled subscription and publishes each one. */
  private void recordUnsubscription(SessionUnsubscribeEvent unsubscribeEvent) {
    if (unsubscribeEvent.getUser() == null) {
      return;
    }
    StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(unsubscribeEvent.getMessage());
    String userName = unsubscribeEvent.getUser().getName();
    List<HistorySession> sessions =
        new HistoryService().updateSession(userName, headerAccessor.getSubscriptionId());
    for (HistorySession session : sessions) {
      publishHistory(
          "/topic/diagram/" + session.getDiagram().getDiagramId() + "/history", userName, session);
    }
  }

  /** Builds the {@code HistoryModel} for {@code session} and sends it to {@code destination}. */
  private void publishHistory(String destination, String userName, HistorySession session) {
    HistoryModel model =
        new HistoryModel(
            userName,
            session.getDiagram().getName(),
            session.getTimeStart(),
            session.getTimeFinish());
    template.convertAndSend(destination, model);
  }
  /**
   * Encodes the given STOMP {@code message} into a {@code byte[]}.
   *
   * <p>A heartbeat is written as its raw payload. Any other message is written as the command,
   * the headers, an empty line, the body and a terminating NUL octet.
   *
   * @param message the message to encode
   * @return the encoded frame
   * @throws StompConversionException if writing the frame fails with an {@link IOException}
   */
  public byte[] encode(Message<byte[]> message) {
    ByteArrayOutputStream frameBuffer = new ByteArrayOutputStream();
    DataOutputStream frame = new DataOutputStream(frameBuffer);
    try {
      StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);
      if (!isHeartbeat(accessor)) {
        writeCommand(accessor, frame);
        writeHeaders(accessor, message, frame);
        frame.write(LF);
        writeBody(message, frame);
        frame.write((byte) 0);
      } else {
        frame.write(message.getPayload());
      }
    } catch (IOException e) {
      throw new StompConversionException("Failed to encode STOMP frame", e);
    }
    return frameBuffer.toByteArray();
  }
  /**
   * Writes the header lines of a STOMP frame from the native headers in {@code headers}, then a
   * {@code content-length} line when the command requires one.
   *
   * <p>Header names and values are escaped except in CONNECT and CONNECTED frames. The
   * {@code passcode} header is written with the value of
   * {@code StompHeaderAccessor.getPasscode(headers)} when that is non-null. Otherwise its native
   * value(s) are written unchanged.
   *
   * @param command the frame's command; must not be {@code null}
   * @param headers the message headers
   * @param payload the frame body, whose length is written as {@code content-length}
   * @param output the stream receiving the header lines
   * @throws IOException if writing to {@code output} fails
   */
  private void writeHeaders(
      StompCommand command, Map<String, Object> headers, byte[] payload, DataOutputStream output)
      throws IOException {

    @SuppressWarnings("unchecked")
    Map<String, List<String>> nativeHeaders =
        (Map<String, List<String>>) headers.get(NativeMessageHeaderAccessor.NATIVE_HEADERS);

    if (logger.isTraceEnabled()) {
      logger.trace("Encoding STOMP " + command + ", headers=" + nativeHeaders + ".");
    }

    if (nativeHeaders != null) {
      boolean shouldEscape = (command != StompCommand.CONNECT && command != StompCommand.CONNECTED);
      for (Entry<String, List<String>> entry : nativeHeaders.entrySet()) {
        byte[] key = encodeHeaderString(entry.getKey(), shouldEscape);
        List<String> values = entry.getValue();
        if (StompHeaderAccessor.STOMP_PASSCODE_HEADER.equals(entry.getKey())) {
          // The native passcode header may hold a masked value (e.g. "PROTECTED"). Use the real
          // passcode when available; otherwise keep the native value(s) rather than write null.
          String passcode = StompHeaderAccessor.getPasscode(headers);
          if (passcode != null) {
            values = Arrays.asList(passcode);
          }
        }
        for (String value : values) {
          output.write(key);
          output.write(COLON);
          output.write(encodeHeaderString(value, shouldEscape));
          output.write(LF);
        }
      }
    }

    // Written even when there are no native headers, so that commands requiring a body length
    // always declare it.
    if (command.requiresContentLength()) {
      int contentLength = payload.length;
      output.write("content-length:".getBytes(StompDecoder.UTF8_CHARSET));
      output.write(Integer.toString(contentLength).getBytes(StompDecoder.UTF8_CHARSET));
      output.write(LF);
    }
  }
  /**
   * Decodes a single STOMP frame from the given {@code buffer} into a {@link Message}.
   *
   * <p>Leading end-of-line bytes are skipped first. An empty command denotes a heartbeat, which is
   * returned as a heartbeat message. The frame is incomplete if nothing follows the command line
   * or the payload cannot be fully read. In that case the buffer is reset to the position just
   * after the skipped EOLs, any headers parsed so far are copied into {@code headers} (when
   * non-null), and {@code null} is returned.
   *
   * @param buffer the buffer to read from; its mark is overwritten
   * @param headers optional map that receives the native headers of an incomplete frame
   * @return the decoded message, or {@code null} if the frame is incomplete
   * @throws StompConversionException if a frame whose command does not allow a body carries one
   */
  private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, String> headers) {

    Message<byte[]> decodedMessage = null;
    skipLeadingEol(buffer);
    // Remember where this frame starts so an incomplete frame can be re-read later.
    buffer.mark();

    String command = readCommand(buffer);
    if (command.length() > 0) {

      StompHeaderAccessor headerAccessor = null;
      byte[] payload = null;

      // Headers and payload can only be read if something follows the command line.
      if (buffer.remaining() > 0) {
        StompCommand stompCommand = StompCommand.valueOf(command);
        headerAccessor = StompHeaderAccessor.create(stompCommand);
        initHeaders(headerAccessor);

        readHeaders(buffer, headerAccessor);
        payload = readPayload(buffer, headerAccessor);
      }

      if (payload != null) {
        if ((payload.length > 0) && (!headerAccessor.getCommand().isBodyAllowed())) {
          // Report this frame's parsed headers. The "headers" argument only receives headers of
          // incomplete frames and may be null.
          throw new StompConversionException(
              headerAccessor.getCommand()
                  + " shouldn't have a payload: length="
                  + payload.length
                  + ", headers="
                  + headerAccessor.toNativeHeaderMap());
        }
        headerAccessor.updateSimpMessageHeadersFromStompHeaders();
        headerAccessor.setLeaveMutable(true);
        decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
        if (logger.isDebugEnabled()) {
          logger.debug("Decoded " + decodedMessage);
        }
      } else {
        if (logger.isTraceEnabled()) {
          logger.trace("Received incomplete frame. Resetting buffer.");
        }
        // Copy any native headers parsed so far into the caller-supplied map.
        if (headers != null && headerAccessor != null) {
          String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
          @SuppressWarnings("unchecked")
          MultiValueMap<String, String> map =
              (MultiValueMap<String, String>) headerAccessor.getHeader(name);
          if (map != null) {
            headers.putAll(map);
          }
        }
        buffer.reset();
      }
    } else {
      if (logger.isTraceEnabled()) {
        logger.trace("Decoded heartbeat");
      }
      StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
      initHeaders(headerAccessor);
      headerAccessor.setLeaveMutable(true);
      decodedMessage =
          MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
    }
    return decodedMessage;
  }
 /**
  * Writes the frame's command name in UTF-8, followed by an LF.
  *
  * @param headers the accessor whose command is written; its command must not be {@code null}
  * @param output the stream receiving the command line
  * @throws IOException if writing to {@code output} fails
  */
 private void writeCommand(StompHeaderAccessor headers, DataOutputStream output)
     throws IOException {
   String commandName = headers.getCommand().toString();
   byte[] commandBytes = commandName.getBytes(UTF8_CHARSET);
   output.write(commandBytes);
   output.write(LF);
 }
 /**
  * Tells whether the accessor describes a heartbeat message.
  *
  * @param headers the accessor to inspect
  * @return {@code true} if its message type is {@code SimpMessageType.HEARTBEAT}
  */
 private boolean isHeartbeat(StompHeaderAccessor headers) {
   SimpMessageType type = headers.getMessageType();
   return type == SimpMessageType.HEARTBEAT;
 }