@Test public void handleMessageToClientConnectAck() { StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT); connectHeaders.setHeartbeat(10000, 10000); connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1"); Message<?> connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build(); SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage); Message<byte[]> connectAckMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build(); this.protocolHandler.handleMessageToClient(this.session, connectAckMessage); verifyNoMoreInteractions(this.channel); // Check CONNECTED reply assertEquals(1, this.session.getSentMessages().size()); TextMessage textMessage = (TextMessage) this.session.getSentMessages().get(0); Message<?> message = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes())); StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message); assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand()); assertEquals("1.1", replyHeaders.getVersion()); assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat()); assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0)); }
/**
 * Writes the STOMP header lines of a frame as {@code name:value} pairs, one line per value.
 * For SEND, MESSAGE and ERROR frames, a {@code content-length} line holding the payload size
 * follows. Names and values are encoded through {@code getUtf8BytesEscapingIfNecessary}.
 *
 * @param headers the accessor describing the frame
 * @param message the message whose payload length is reported
 * @param output the stream to write to
 * @throws IOException if writing to {@code output} fails
 */
private void writeHeaders(
        StompHeaderAccessor headers, Message<byte[]> message, DataOutputStream output)
        throws IOException {
    Map<String, List<String>> stompHeaders = headers.toStompHeaderMap();

    boolean heartbeat = SimpMessageType.HEARTBEAT.equals(headers.getMessageType());
    if (heartbeat) {
        logger.trace("Encoded heartbeat");
    }
    else if (logger.isDebugEnabled()) {
        logger.debug("Encoded STOMP command=" + headers.getCommand() + " headers=" + stompHeaders);
    }

    for (Entry<String, List<String>> header : stompHeaders.entrySet()) {
        // The encoded name is reused for every value of a multi-valued header.
        byte[] name = getUtf8BytesEscapingIfNecessary(header.getKey(), headers);
        for (String value : header.getValue()) {
            output.write(name);
            output.write(COLON);
            output.write(getUtf8BytesEscapingIfNecessary(value, headers));
            output.write(LF);
        }
    }

    StompCommand command = headers.getCommand();
    boolean needsContentLength = command == StompCommand.SEND
            || command == StompCommand.MESSAGE
            || command == StompCommand.ERROR;
    if (needsContentLength) {
        String length = Integer.toString(message.getPayload().length);
        output.write("content-length:".getBytes(UTF8_CHARSET));
        output.write(length.getBytes(UTF8_CHARSET));
        output.write(LF);
    }
}
/**
 * Encodes a header name or value as UTF-8 bytes. The text is passed through {@code escape}
 * first, except for CONNECT and CONNECTED frames, whose header text is written unescaped.
 *
 * @param input the header text
 * @param headers the accessor of the frame being written (its command decides escaping)
 * @return the UTF-8 bytes to write
 */
private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) {
    StompCommand command = headers.getCommand();
    boolean writeRaw = (command == StompCommand.CONNECT || command == StompCommand.CONNECTED);
    String text = writeRaw ? input : escape(input);
    return text.getBytes(UTF8_CHARSET);
}
/**
 * Verifies that an inbound CONNECT text frame is decoded and sent to the channel with the
 * session id, user, login, passcode (expected to read "PROTECTED"), heart-beat and
 * accept-version populated, and that nothing is written back to the session.
 */
@Test
public void handleMessageFromClient() {
    TextMessage connectFrame = StompTextMessageBuilder.create(StompCommand.CONNECT)
            .headers(
                    "login:guest",
                    "passcode:guest",
                    "accept-version:1.1,1.0",
                    "heart-beat:10000,10000")
            .build();

    this.protocolHandler.handleMessageFromClient(this.session, connectFrame, this.channel);

    verify(this.channel).send(this.messageCaptor.capture());
    Message<?> forwarded = this.messageCaptor.getValue();
    assertNotNull(forwarded);

    StompHeaderAccessor accessor = StompHeaderAccessor.wrap(forwarded);
    assertEquals(StompCommand.CONNECT, accessor.getCommand());
    assertEquals("s1", accessor.getSessionId());
    assertEquals("joe", accessor.getUser().getName());
    assertEquals("guest", accessor.getLogin());
    // The passcode is expected to be masked on the forwarded message.
    assertEquals("PROTECTED", accessor.getPasscode());
    assertArrayEquals(new long[] {10000, 10000}, accessor.getHeartbeat());
    assertEquals(new HashSet<>(Arrays.asList("1.1", "1.0")), accessor.getAcceptVersion());

    // No reply frame is sent to the client for CONNECT at this stage.
    assertEquals(0, this.session.getSentMessages().size());
}
/**
 * Once the WebSocket is open, sends a STOMP CONNECT frame that accepts versions 1.1 and 1.2
 * and requests a heart-beat of 0,0. The frame is encoded and sent as a UTF-8 text message.
 */
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
    StompHeaderAccessor connect = StompHeaderAccessor.create(StompCommand.CONNECT);
    connect.setAcceptVersion("1.1,1.2");
    connect.setHeartbeat(0, 0);

    Message<byte[]> frame = MessageBuilder.withPayload(new byte[0]).setHeaders(connect).build();
    byte[] encoded = this.encoder.encode(frame);
    session.sendMessage(new TextMessage(new String(encoded, UTF_8)));
}
/**
 * Enriches a tracking {@link ActivityDTO} received on {@code /topic/activity} and broadcasts it
 * to {@code /topic/tracker}.
 *
 * <p>The user login is taken from the STOMP {@code principal}. It falls back to
 * {@code SecurityUtils.getCurrentLogin()} only when no principal is available. The IP address
 * is read from the {@code IP_ADDRESS} session attribute and is left {@code null} when the
 * session attributes or that attribute are absent.
 *
 * @param activityDTO the activity payload sent by the client
 * @param stompHeaderAccessor accessor for the STOMP headers of the subscription
 * @param principal the authenticated user, possibly {@code null}
 * @return the enriched activity, sent to {@code /topic/tracker}
 */
@SubscribeMapping("/topic/activity")
@SendTo("/topic/tracker")
public ActivityDTO sendActivity(
        @Payload ActivityDTO activityDTO,
        StompHeaderAccessor stompHeaderAccessor,
        Principal principal) {
    // Set the login once; the principal takes precedence when present.
    activityDTO.setUserLogin(
            principal != null ? principal.getName() : SecurityUtils.getCurrentLogin());
    activityDTO.setSessionId(stompHeaderAccessor.getSessionId());

    // Guard against a missing attribute map or attribute instead of throwing an NPE.
    Object ipAddress = (stompHeaderAccessor.getSessionAttributes() != null)
            ? stompHeaderAccessor.getSessionAttributes().get(IP_ADDRESS)
            : null;
    activityDTO.setIpAddress(ipAddress != null ? ipAddress.toString() : null);

    activityDTO.setTime(dateTimeFormatter.print(System.currentTimeMillis()));
    log.debug("Sending user tracking data {}", activityDTO);
    return activityDTO;
}
/**
 * Decodes the STOMP frames in a text WebSocket message and dispatches each frame to the
 * {@code stompMessageHandler} callback that matches its command:
 * CONNECTED, MESSAGE, RECEIPT, ERROR or DISCONNECT. Any other frame is logged at debug level.
 *
 * @param session the WebSocket session the message arrived on
 * @param textMessage the raw text message, encoded to bytes as UTF-8 before decoding
 * @throws Exception if a callback fails
 */
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) throws Exception {
    ByteBuffer payload = ByteBuffer.wrap(textMessage.getPayload().getBytes(UTF_8));
    List<Message<byte[]>> messages = this.decoder.decode(payload);
    for (Message<byte[]> message : messages) {
        StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
        // The command may be null (presumably for heartbeats), so compare against the
        // constant rather than switching on it.
        StompCommand command = headers.getCommand();
        if (StompCommand.CONNECTED.equals(command)) {
            this.stompMessageHandler.afterConnected(session, headers);
        }
        else if (StompCommand.MESSAGE.equals(command)) {
            this.stompMessageHandler.handleMessage(message);
        }
        else if (StompCommand.RECEIPT.equals(command)) {
            this.stompMessageHandler.handleReceipt(headers.getReceiptId());
        }
        else if (StompCommand.ERROR.equals(command)) {
            this.stompMessageHandler.handleError(message);
        }
        else if (StompCommand.DISCONNECT.equals(command)) {
            // Previously a duplicate ERROR check, which made this branch unreachable.
            this.stompMessageHandler.afterDisconnected();
        }
        else {
            LOGGER.debug("Unhandled message " + message);
        }
    }
}
/**
 * Serializes a STOMP frame, described by its message headers and body, into bytes.
 *
 * <p>A heartbeat message is written as {@code StompDecoder.HEARTBEAT_PAYLOAD} only. Any other
 * message is written as the command line, the header lines, a blank line, the body and a
 * terminating zero byte.
 *
 * @param headers the message headers; must not be {@code null}
 * @param payload the frame body; must not be {@code null}
 * @return the encoded frame
 * @throws StompConversionException if writing the frame fails
 */
public byte[] encode(Map<String, Object> headers, byte[] payload) {
    Assert.notNull(headers, "'headers' is required");
    Assert.notNull(payload, "'payload' is required");

    ByteArrayOutputStream frame = new ByteArrayOutputStream(128 + payload.length);
    DataOutputStream out = new DataOutputStream(frame);
    try {
        boolean heartbeat =
                SimpMessageType.HEARTBEAT.equals(SimpMessageHeaderAccessor.getMessageType(headers));
        if (heartbeat) {
            logger.trace("Encoded heartbeat");
            out.write(StompDecoder.HEARTBEAT_PAYLOAD);
            return frame.toByteArray();
        }

        StompCommand command = StompHeaderAccessor.getCommand(headers);
        Assert.notNull(command, "Missing STOMP command: " + headers);
        out.write(command.toString().getBytes(StompDecoder.UTF8_CHARSET));
        out.write(LF);
        writeHeaders(command, headers, payload, out);
        out.write(LF);
        writeBody(payload, out);
        out.write((byte) 0);
        return frame.toByteArray();
    }
    catch (IOException ex) {
        throw new StompConversionException(
                "Failed to encode STOMP frame, headers=" + headers + ".", ex);
    }
}
/**
 * Encodes a STOMP message into a WebSocket message.
 *
 * <p>A {@link BinaryMessage} is produced only when the payload is non-empty, the session type
 * is not a {@code SockJsSession}, and the content type is compatible with
 * {@code application/octet-stream}. Otherwise a {@link TextMessage} is produced.
 *
 * @param message the message to encode; must carry a {@code StompHeaderAccessor}
 * @param sessionType the class of the target WebSocket session
 * @return the encoded WebSocket message
 */
public WebSocketMessage<?> encode(
        Message<byte[]> message, Class<? extends WebSocketSession> sessionType) {
    StompHeaderAccessor accessor =
            MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
    Assert.notNull(accessor, "No StompHeaderAccessor available");

    byte[] payload = message.getPayload();
    byte[] frame = ENCODER.encode(accessor.getMessageHeaders(), payload);

    // Empty payloads and SockJS sessions always use text frames.
    if (payload.length == 0 || SockJsSession.class.isAssignableFrom(sessionType)) {
        return new TextMessage(frame);
    }
    if (MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(accessor.getContentType())) {
        return new BinaryMessage(frame);
    }
    return new TextMessage(frame);
}
/**
 * Reads the body of a STOMP frame from {@code buffer}.
 *
 * <p>With a valid, non-negative {@code content-length}, exactly that many bytes are read and
 * the next byte must be the terminating zero byte. Otherwise (no header, a negative value, or an
 * unparseable value, which is logged and ignored) bytes are read up to the first zero byte.
 *
 * @param buffer the buffer positioned at the start of the body
 * @param headerAccessor the headers already read for this frame
 * @return the body bytes, or {@code null} if the buffer does not yet hold the full body and
 *         its terminator
 * @throws StompConversionException if a content-length body is not followed by a zero byte
 */
private byte[] readPayload(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
    Integer contentLength;
    try {
        contentLength = headerAccessor.getContentLength();
    }
    catch (NumberFormatException ex) {
        // Log the offending header value itself (not the whole accessor), with the closing quote.
        logger.warn("Ignoring invalid content-length: '"
                + headerAccessor.getFirstNativeHeader("content-length") + "'");
        contentLength = null;
    }

    if (contentLength != null && contentLength >= 0) {
        // The body plus one extra byte for the terminating zero byte must be available.
        if (buffer.remaining() > contentLength) {
            byte[] payload = new byte[contentLength];
            buffer.get(payload);
            if (buffer.get() != 0) {
                throw new StompConversionException("Frame must be terminated with a null octet");
            }
            return payload;
        }
        return null;
    }

    // No usable content-length: the body extends to the first zero byte.
    ByteArrayOutputStream payload = new ByteArrayOutputStream(256);
    while (buffer.remaining() > 0) {
        byte b = buffer.get();
        if (b == 0) {
            return payload.toByteArray();
        }
        payload.write(b);
    }
    return null;
}
/**
 * Reads {@code name:value} header lines from {@code buffer} into {@code headerAccessor} until
 * an empty line (or the end of the buffer) is reached. Names and values are unescaped.
 *
 * <p>Empty header values ({@code name:}) are accepted, as the STOMP 1.2 grammar permits. A line
 * without a colon, or with an empty name, is rejected only when more data follows it. If the
 * buffer ends on that line, the frame may simply be incomplete and the line is skipped.
 *
 * @param buffer the buffer positioned at the first header line
 * @param headerAccessor the accessor that receives the decoded native headers
 * @throws StompConversionException if a complete header line is malformed
 */
private void readHeaders(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
    while (true) {
        ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
        while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
            headerStream.write(buffer.get());
        }
        if (headerStream.size() == 0) {
            // Blank line (or nothing left): end of the header section.
            break;
        }

        String header = new String(headerStream.toByteArray(), UTF8_CHARSET);
        int colonIndex = header.indexOf(':');
        if (colonIndex <= 0) {
            if (buffer.remaining() > 0) {
                throw new StompConversionException("Illegal header: '" + header
                        + "'. A header must be of the form <name>:[<value>]");
            }
            continue;
        }

        String headerName = unescape(header.substring(0, colonIndex));
        String headerValue = unescape(header.substring(colonIndex + 1));
        try {
            headerAccessor.addNativeHeader(headerName, headerValue);
        }
        catch (InvalidMimeTypeException ex) {
            // A truncated content-type at the end of the buffer is tolerated; the frame is
            // presumably incomplete and will be re-read later.
            if (buffer.remaining() > 0) {
                throw ex;
            }
        }
    }
}
/**
 * Logs, at error level, an exception raised while handling a STOMP frame. The command,
 * headers and payload are rebuilt into a {@code Message} so the log line shows the frame.
 */
@Override
public void handleException(
        StompSession session,
        StompCommand command,
        StompHeaders headers,
        byte[] payload,
        Throwable exception) {
    StompHeaderAccessor accessor = StompHeaderAccessor.create(command, headers);
    Message<byte[]> failedMessage =
            MessageBuilder.createMessage(payload, accessor.getMessageHeaders());
    String description =
            "The exception for session [" + session + "] on message [" + failedMessage + "]";
    logger.error(description, exception);
}
/**
 * Verifies that a CONNECTED message sent to the client is written as a CONNECTED frame with a
 * {@code user-name:joe} header, and that session "s1" is registered for user "joe" in the
 * configured {@link UserSessionRegistry}.
 */
@Test
public void handleMessageToClientConnected() {
    UserSessionRegistry sessionRegistry = new DefaultUserSessionRegistry();
    this.protocolHandler.setUserSessionRegistry(sessionRegistry);

    StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
    Message<byte[]> connected =
            MessageBuilder.withPayload(new byte[0]).setHeaders(connectedHeaders).build();
    this.protocolHandler.handleMessageToClient(this.session, connected);

    assertEquals(1, this.session.getSentMessages().size());
    WebSocketMessage<?> sentFrame = this.session.getSentMessages().get(0);
    String expectedFrame = "CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000";
    assertEquals(expectedFrame, sentFrame.getPayload());
    assertEquals(Collections.singleton("s1"), sessionRegistry.getSessionIds("joe"));
}
/**
 * Tracks diagram-history sessions from STOMP subscription events.
 *
 * <p>A {@link SessionSubscribeEvent} whose destination ends in a numeric diagram id opens a
 * history session and broadcasts it to {@code /topic/diagram/{id}/history}. A
 * {@link SessionUnsubscribeEvent} closes the user's sessions for that subscription and
 * broadcasts each updated session. Events with no user, no destination, or a non-numeric
 * destination suffix are ignored.
 */
@Override
public void onApplicationEvent(ApplicationEvent event) {
    if (event instanceof SessionSubscribeEvent) {
        handleSubscribe((SessionSubscribeEvent) event);
    }
    else if (event instanceof SessionUnsubscribeEvent) {
        handleUnsubscribe((SessionUnsubscribeEvent) event);
    }
}

/** Records a history session for a diagram subscription and broadcasts it. */
private void handleSubscribe(SessionSubscribeEvent subscribe) {
    StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(subscribe.getMessage());
    Long diagramId = parseDiagramId(headerAccessor.getDestination());
    if (diagramId == null || subscribe.getUser() == null) {
        return;
    }
    String userName = subscribe.getUser().getName();
    HistorySession session = new HistoryService()
            .insertSession(userName, diagramId, headerAccessor.getSubscriptionId());
    template.convertAndSend(
            "/topic/diagram/" + diagramId + "/history", toHistoryModel(userName, session));
}

/** Closes the user's history sessions for a subscription and broadcasts each of them. */
private void handleUnsubscribe(SessionUnsubscribeEvent unsubscribe) {
    if (unsubscribe.getUser() == null) {
        return;
    }
    StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(unsubscribe.getMessage());
    String userName = unsubscribe.getUser().getName();
    List<HistorySession> sessions = new HistoryService()
            .updateSession(userName, headerAccessor.getSubscriptionId());
    for (HistorySession session : sessions) {
        template.convertAndSend(
                "/topic/diagram/" + session.getDiagram().getDiagramId() + "/history",
                toHistoryModel(userName, session));
    }
}

/** Parses the trailing path segment of a destination as a diagram id; null if absent or not numeric. */
private static Long parseDiagramId(String destination) {
    if (destination == null) {
        return null;
    }
    try {
        return Long.valueOf(destination.substring(destination.lastIndexOf('/') + 1));
    }
    catch (NumberFormatException ex) {
        // Not a diagram topic; deliberately ignored.
        return null;
    }
}

/** Builds the broadcast model for one history session of the given user. */
private static HistoryModel toHistoryModel(String userName, HistorySession session) {
    return new HistoryModel(
            userName,
            session.getDiagram().getName(),
            session.getTimeStart(),
            session.getTimeFinish());
}
/**
 * Encodes a STOMP {@code message} into a {@code byte[]}.
 *
 * <p>For a heartbeat, the payload is written as-is. Any other message is written as the
 * command, the headers, a blank line, the body and a terminating zero byte.
 *
 * @param message the message to encode
 * @return the encoded frame
 * @throws StompConversionException if writing the frame fails
 */
public byte[] encode(Message<byte[]> message) {
    try {
        ByteArrayOutputStream frame = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(frame);
        StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);

        if (!isHeartbeat(headers)) {
            writeCommand(headers, out);
            writeHeaders(headers, message, out);
            out.write(LF);
            writeBody(message, out);
            out.write((byte) 0);
        }
        else {
            out.write(message.getPayload());
        }
        return frame.toByteArray();
    }
    catch (IOException ex) {
        throw new StompConversionException("Failed to encode STOMP frame", ex);
    }
}
/**
 * Writes the native STOMP headers of a frame, one {@code name:value} line per value.
 * When {@code command.requiresContentLength()} is true, a {@code content-length} line follows.
 *
 * <p>Header text is escaped except for CONNECT and CONNECTED frames. The passcode header is
 * written from {@code StompHeaderAccessor.getPasscode(headers)} rather than from its native value.
 * The {@code content-length} line is written even when the message carries no native headers,
 * so bodies of SEND/MESSAGE/ERROR frames stay length-delimited.
 *
 * @param command the frame command
 * @param headers the message headers
 * @param payload the frame body, used for {@code content-length}
 * @param output the stream to write to
 * @throws IOException if writing to {@code output} fails
 */
private void writeHeaders(
        StompCommand command, Map<String, Object> headers, byte[] payload, DataOutputStream output)
        throws IOException {
    @SuppressWarnings("unchecked")
    Map<String, List<String>> nativeHeaders =
            (Map<String, List<String>>) headers.get(NativeMessageHeaderAccessor.NATIVE_HEADERS);

    if (logger.isTraceEnabled()) {
        logger.trace("Encoding STOMP " + command + ", headers=" + nativeHeaders + ".");
    }

    if (nativeHeaders != null) {
        boolean shouldEscape = (command != StompCommand.CONNECT && command != StompCommand.CONNECTED);
        for (Entry<String, List<String>> entry : nativeHeaders.entrySet()) {
            byte[] key = encodeHeaderString(entry.getKey(), shouldEscape);
            List<String> values = entry.getValue();
            if (StompHeaderAccessor.STOMP_PASSCODE_HEADER.equals(entry.getKey())) {
                values = Arrays.asList(StompHeaderAccessor.getPasscode(headers));
            }
            for (String value : values) {
                output.write(key);
                output.write(COLON);
                output.write(encodeHeaderString(value, shouldEscape));
                output.write(LF);
            }
        }
    }

    // Previously skipped by an early return when there were no native headers.
    if (command.requiresContentLength()) {
        int contentLength = payload.length;
        output.write("content-length:".getBytes(StompDecoder.UTF8_CHARSET));
        output.write(Integer.toString(contentLength).getBytes(StompDecoder.UTF8_CHARSET));
        output.write(LF);
    }
}
/**
 * Decodes a single STOMP frame from the given {@code buffer} into a {@link Message}.
 *
 * <p>Calls {@code skipLeadingEol}, then marks the buffer. An empty command line yields a
 * heartbeat message. If the frame is incomplete, the buffer is reset to the mark and
 * {@code null} is returned. In that case, any native headers read so far are copied into
 * {@code headers}, when non-null.
 *
 * @param buffer the buffer to decode from
 * @param headers optional sink for the headers of an incomplete frame; may be {@code null}
 * @return the decoded message, or {@code null} if the frame is incomplete
 * @throws StompConversionException if a command that does not allow a body has a non-empty payload
 */
private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, String> headers) {
    Message<byte[]> decodedMessage = null;
    skipLeadingEol(buffer);
    buffer.mark();

    String command = readCommand(buffer);
    if (command.length() > 0) {
        StompHeaderAccessor headerAccessor = null;
        byte[] payload = null;
        if (buffer.remaining() > 0) {
            StompCommand stompCommand = StompCommand.valueOf(command);
            headerAccessor = StompHeaderAccessor.create(stompCommand);
            initHeaders(headerAccessor);
            readHeaders(buffer, headerAccessor);
            payload = readPayload(buffer, headerAccessor);
        }
        if (payload != null) {
            if ((payload.length > 0) && (!headerAccessor.getCommand().isBodyAllowed())) {
                // Report this frame's headers. The 'headers' parameter is only the sink for
                // incomplete frames and does not describe this frame.
                throw new StompConversionException(headerAccessor.getCommand()
                        + " shouldn't have a payload: length=" + payload.length
                        + ", headers=" + headerAccessor);
            }
            headerAccessor.updateSimpMessageHeadersFromStompHeaders();
            headerAccessor.setLeaveMutable(true);
            decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
            if (logger.isDebugEnabled()) {
                logger.debug("Decoded " + decodedMessage);
            }
        }
        else {
            if (logger.isTraceEnabled()) {
                logger.trace("Received incomplete frame. Resetting buffer.");
            }
            if (headers != null && headerAccessor != null) {
                String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
                @SuppressWarnings("unchecked")
                MultiValueMap<String, String> map =
                        (MultiValueMap<String, String>) headerAccessor.getHeader(name);
                if (map != null) {
                    headers.putAll(map);
                }
            }
            // Rewind so the caller can retry once more data has arrived.
            buffer.reset();
        }
    }
    else {
        if (logger.isTraceEnabled()) {
            logger.trace("Decoded heartbeat");
        }
        StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
        initHeaders(headerAccessor);
        headerAccessor.setLeaveMutable(true);
        decodedMessage = MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
    }
    return decodedMessage;
}
/**
 * Writes the frame's command name as UTF-8, followed by a line feed.
 *
 * @param headers the accessor holding the command
 * @param output the stream to write to
 * @throws IOException if writing to {@code output} fails
 */
private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException {
    String commandName = headers.getCommand().toString();
    byte[] commandBytes = commandName.getBytes(UTF8_CHARSET);
    output.write(commandBytes);
    output.write(LF);
}
/**
 * Tells whether the accessor describes a heartbeat rather than a regular STOMP frame.
 *
 * @param headers the accessor to inspect
 * @return {@code true} if the message type is {@code SimpMessageType.HEARTBEAT}
 */
private boolean isHeartbeat(StompHeaderAccessor headers) {
    return SimpMessageType.HEARTBEAT.equals(headers.getMessageType());
}