/**
 * Writes the header section of a STOMP frame to {@code output}.
 * <p>Every value of every header becomes one {@code name:value} line terminated by LF; names
 * and values are passed through {@link #getUtf8BytesEscapingIfNecessary}. SEND, MESSAGE and
 * ERROR frames additionally get a {@code content-length} line holding the payload size.
 *
 * @param headers the accessor holding the frame's command and headers
 * @param message the message whose payload length is reported in {@code content-length}
 * @param output the stream the header lines are appended to
 * @throws IOException if writing to {@code output} fails
 */
private void writeHeaders(
        StompHeaderAccessor headers, Message<byte[]> message, DataOutputStream output)
        throws IOException {

    Map<String, List<String>> stompHeaders = headers.toStompHeaderMap();

    boolean heartbeat = SimpMessageType.HEARTBEAT.equals(headers.getMessageType());
    if (heartbeat) {
        logger.trace("Encoded heartbeat");
    }
    else if (logger.isDebugEnabled()) {
        logger.debug("Encoded STOMP command=" + headers.getCommand() + " headers=" + stompHeaders);
    }

    for (Map.Entry<String, List<String>> header : stompHeaders.entrySet()) {
        // The encoded name is reused for each value of a multi-valued header.
        byte[] encodedName = getUtf8BytesEscapingIfNecessary(header.getKey(), headers);
        for (String rawValue : header.getValue()) {
            output.write(encodedName);
            output.write(COLON);
            output.write(getUtf8BytesEscapingIfNecessary(rawValue, headers));
            output.write(LF);
        }
    }

    StompCommand command = headers.getCommand();
    boolean carriesBody = command == StompCommand.SEND
            || command == StompCommand.MESSAGE
            || command == StompCommand.ERROR;
    if (carriesBody) {
        output.write("content-length:".getBytes(UTF8_CHARSET));
        byte[] body = message.getPayload();
        output.write(Integer.toString(body.length).getBytes(UTF8_CHARSET));
        output.write(LF);
    }
}
/**
 * Decodes the STOMP frames carried by a WebSocket text message and dispatches each frame to
 * {@code stompMessageHandler} according to its command.
 * <p>CONNECTED, MESSAGE, RECEIPT, ERROR and DISCONNECT frames are dispatched; frames with any
 * other command are only logged at debug level.
 *
 * @param session the WebSocket session the message arrived on
 * @param textMessage the text message whose payload holds one or more STOMP frames
 * @throws Exception if decoding or a handler callback fails
 */
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) throws Exception {
    ByteBuffer payload = ByteBuffer.wrap(textMessage.getPayload().getBytes(UTF_8));
    List<Message<byte[]>> messages = this.decoder.decode(payload);
    for (Message<byte[]> message : messages) {
        StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
        StompCommand command = headers.getCommand();
        if (StompCommand.CONNECTED.equals(command)) {
            this.stompMessageHandler.afterConnected(session, headers);
        }
        else if (StompCommand.MESSAGE.equals(command)) {
            this.stompMessageHandler.handleMessage(message);
        }
        else if (StompCommand.RECEIPT.equals(command)) {
            this.stompMessageHandler.handleReceipt(headers.getReceiptId());
        }
        else if (StompCommand.ERROR.equals(command)) {
            this.stompMessageHandler.handleError(message);
        }
        else if (StompCommand.DISCONNECT.equals(command)) {
            // This branch used to re-test ERROR, which made it unreachable.
            // DISCONNECT is the presumed intent for afterDisconnected(); confirm against the handler.
            this.stompMessageHandler.afterDisconnected();
        }
        else if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Unhandled message " + message);
        }
    }
}
/**
 * Encodes a header name or value as UTF-8, escaping it first via {@code escape(...)} unless
 * the frame is a CONNECT or CONNECTED frame. Under STOMP 1.2, those two frames are sent
 * without escaping.
 *
 * @param input the header name or value to encode
 * @param headers the accessor whose command decides whether escaping applies
 * @return the UTF-8 bytes of {@code input}, escaped if required
 */
private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) {
    StompCommand command = headers.getCommand();
    boolean connectFrame = (command == StompCommand.CONNECT || command == StompCommand.CONNECTED);
    String text = (connectFrame ? input : escape(input));
    return text.getBytes(UTF8_CHARSET);
}
/**
 * Encodes the given payload and headers into a {@code byte[]}.
 * <p>A heartbeat message is encoded as {@code StompDecoder.HEARTBEAT_PAYLOAD} alone. Any other
 * message is encoded as the command line, the header lines, an empty line, the body and a
 * terminating {@code 0} byte.
 *
 * @param headers the headers
 * @param payload the payload
 * @return the encoded message
 * @throws IllegalArgumentException if a non-heartbeat message carries no STOMP command
 * @throws StompConversionException if writing the frame raises an {@link IOException}
 */
public byte[] encode(Map<String, Object> headers, byte[] payload) {
    Assert.notNull(headers, "'headers' is required");
    Assert.notNull(payload, "'payload' is required");
    try {
        // Initial capacity guess: the body plus 128 bytes for command and headers;
        // the stream grows if that is not enough.
        ByteArrayOutputStream baos = new ByteArrayOutputStream(128 + payload.length);
        DataOutputStream output = new DataOutputStream(baos);
        if (SimpMessageType.HEARTBEAT.equals(SimpMessageHeaderAccessor.getMessageType(headers))) {
            logger.trace("Encoded heartbeat");
            output.write(StompDecoder.HEARTBEAT_PAYLOAD);
        }
        else {
            StompCommand command = StompHeaderAccessor.getCommand(headers);
            if (command == null) {
                // Explicit check instead of Assert.notNull(command, "..." + headers), so the
                // headers map is only turned into a String when the command is actually missing.
                throw new IllegalArgumentException("Missing STOMP command: " + headers);
            }
            output.write(command.toString().getBytes(StompDecoder.UTF8_CHARSET));
            output.write(LF);
            writeHeaders(command, headers, payload, output);
            output.write(LF);
            writeBody(payload, output);
            output.write((byte) 0);
        }
        return baos.toByteArray();
    }
    catch (IOException e) {
        throw new StompConversionException(
                "Failed to encode STOMP frame, headers=" + headers + ".", e);
    }
}
/**
 * Sends a CONNECT text frame from the client. Verifies that exactly one message is forwarded
 * to the channel with the parsed STOMP headers, and that nothing is sent back on the session.
 */
@Test
public void handleMessageFromClient() {
    TextMessage connectFrame = StompTextMessageBuilder.create(StompCommand.CONNECT)
            .headers(
                    "login:guest",
                    "passcode:guest",
                    "accept-version:1.1,1.0",
                    "heart-beat:10000,10000")
            .build();

    this.protocolHandler.handleMessageFromClient(this.session, connectFrame, this.channel);

    verify(this.channel).send(this.messageCaptor.capture());
    Message<?> forwarded = this.messageCaptor.getValue();
    assertNotNull(forwarded);

    StompHeaderAccessor accessor = StompHeaderAccessor.wrap(forwarded);
    assertEquals(StompCommand.CONNECT, accessor.getCommand());
    assertEquals("s1", accessor.getSessionId());
    assertEquals("joe", accessor.getUser().getName());
    assertEquals("guest", accessor.getLogin());
    assertEquals("PROTECTED", accessor.getPasscode());
    assertArrayEquals(new long[] {10000, 10000}, accessor.getHeartbeat());

    Set<String> expectedVersions = new HashSet<>(Arrays.asList("1.1", "1.0"));
    assertEquals(expectedVersions, accessor.getAcceptVersion());

    int repliesSent = this.session.getSentMessages().size();
    assertEquals(0, repliesSent);
}
@Test public void handleMessageToClientConnectAck() { StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT); connectHeaders.setHeartbeat(10000, 10000); connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1"); Message<?> connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build(); SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage); Message<byte[]> connectAckMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build(); this.protocolHandler.handleMessageToClient(this.session, connectAckMessage); verifyNoMoreInteractions(this.channel); // Check CONNECTED reply assertEquals(1, this.session.getSentMessages().size()); TextMessage textMessage = (TextMessage) this.session.getSentMessages().get(0); Message<?> message = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes())); StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message); assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand()); assertEquals("1.1", replyHeaders.getVersion()); assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat()); assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0)); }
/**
 * Writes the command line of a STOMP frame: the command name encoded with
 * {@code UTF8_CHARSET}, followed by {@code LF}.
 *
 * @param headers the accessor holding the frame's command (must be non-null)
 * @param output the stream the command line is appended to
 * @throws IOException if writing to {@code output} fails
 */
private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException {
    String commandName = headers.getCommand().toString();
    byte[] commandBytes = commandName.getBytes(UTF8_CHARSET);
    output.write(commandBytes);
    output.write(LF);
}
/**
 * Decode a single STOMP frame from the given {@code buffer} into a {@link Message}.
 * <p>An empty command line is decoded as a heartbeat. If the buffer holds only part of a
 * frame, {@code null} is returned and the buffer is reset to where the frame started, so it
 * can be decoded again once more data has arrived.
 *
 * @param buffer the buffer to read from
 * @param headers if non-null, receives the native headers read so far when the frame is
 * incomplete; this method does not write to it otherwise
 * @return the decoded message, or {@code null} if the frame is incomplete
 * @throws StompConversionException if the frame has a body but its command does not allow one
 */
private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, String> headers) {
    Message<byte[]> decodedMessage = null;
    skipLeadingEol(buffer);
    // Remember where this frame starts so an incomplete frame can be rewound.
    buffer.mark();

    String command = readCommand(buffer);
    if (command.length() > 0) {
        StompHeaderAccessor headerAccessor = null;
        byte[] payload = null;
        // Headers and body can only be read if something follows the command line.
        if (buffer.remaining() > 0) {
            StompCommand stompCommand = StompCommand.valueOf(command);
            headerAccessor = StompHeaderAccessor.create(stompCommand);
            initHeaders(headerAccessor);
            readHeaders(buffer, headerAccessor);
            payload = readPayload(buffer, headerAccessor);
        }
        if (payload != null) {
            if ((payload.length > 0) && (!headerAccessor.getCommand().isBodyAllowed())) {
                // Report this frame's decoded headers. The 'headers' argument is only
                // populated for incomplete frames, so it says nothing about this one.
                throw new StompConversionException(
                        headerAccessor.getCommand() + " shouldn't have a payload: length=" +
                        payload.length + ", headers=" + headerAccessor.toNativeHeaderMap());
            }
            headerAccessor.updateSimpMessageHeadersFromStompHeaders();
            headerAccessor.setLeaveMutable(true);
            decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
            if (logger.isDebugEnabled()) {
                logger.debug("Decoded " + decodedMessage);
            }
        }
        else {
            // Incomplete frame: hand back what was parsed so far (if asked) and rewind.
            if (logger.isTraceEnabled()) {
                logger.trace("Received incomplete frame. Resetting buffer.");
            }
            if (headers != null && headerAccessor != null) {
                String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
                @SuppressWarnings("unchecked")
                MultiValueMap<String, String> map = (MultiValueMap<String, String>) headerAccessor.getHeader(name);
                if (map != null) {
                    headers.putAll(map);
                }
            }
            buffer.reset();
        }
    }
    else {
        if (logger.isTraceEnabled()) {
            logger.trace("Decoded heartbeat");
        }
        StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
        initHeaders(headerAccessor);
        headerAccessor.setLeaveMutable(true);
        decodedMessage = MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
    }
    return decodedMessage;
}