/**
   * Builds and initializes a {@link ThreadPoolTaskScheduler} of the requested size.
   *
   * <p>The scheduler is configured with a {@code CallerRunsPolicy} as its rejected-execution
   * handler and {@code afterPropertiesSet()} is invoked before it is handed back, so callers
   * receive an already-initialized instance.
   *
   * @param poolSize the pool size passed to {@code setPoolSize}
   * @return the initialized scheduler
   */
  public static ThreadPoolTaskScheduler createTaskScheduler(int poolSize) {
    ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
    taskScheduler.setRejectedExecutionHandler(new CallerRunsPolicy());
    taskScheduler.setPoolSize(poolSize);
    taskScheduler.afterPropertiesSet();
    return taskScheduler;
  }
 public RabbitTestMessageBus(ConnectionFactory connectionFactory, Codec codec) {
   RabbitMessageBus messageBus = new RabbitMessageBus(connectionFactory, codec);
   GenericApplicationContext context = new GenericApplicationContext();
   ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
   scheduler.setPoolSize(1);
   scheduler.afterPropertiesSet();
   context
       .getBeanFactory()
       .registerSingleton(IntegrationContextUtils.TASK_SCHEDULER_BEAN_NAME, scheduler);
   context.refresh();
   messageBus.setApplicationContext(context);
   messageBus.setLongStringLimit(8193);
   this.setMessageBus(messageBus);
   assertEquals(
       8193,
       TestUtils.getPropertyValue(
           messageBus, "inboundMessagePropertiesConverter.longStringLimit"));
   this.rabbitAdmin = new RabbitAdmin(connectionFactory);
 }
  public void testObtainConnectionIds(AbstractServerConnectionFactory serverFactory)
      throws Exception {
    final List<IpIntegrationEvent> events =
        Collections.synchronizedList(new ArrayList<IpIntegrationEvent>());
    int expectedEvents =
        serverFactory instanceof TcpNetServerConnectionFactory
            ? 7 // Listening, + OPEN, CLOSE, EXCEPTION for each side
            : 5; // Listening, + OPEN, CLOSE (but we *might* get exceptions, depending on timing).
    final CountDownLatch serverListeningLatch = new CountDownLatch(1);
    final CountDownLatch eventLatch = new CountDownLatch(expectedEvents);
    ApplicationEventPublisher publisher =
        new ApplicationEventPublisher() {

          @Override
          public void publishEvent(ApplicationEvent event) {
            LogFactory.getLog(this.getClass()).trace("Received: " + event);
            events.add((IpIntegrationEvent) event);
            if (event instanceof TcpConnectionServerListeningEvent) {
              serverListeningLatch.countDown();
            }
            eventLatch.countDown();
          }

          @Override
          public void publishEvent(Object event) {}
        };
    serverFactory.setBeanName("serverFactory");
    serverFactory.setApplicationEventPublisher(publisher);
    serverFactory = spy(serverFactory);
    final CountDownLatch serverConnectionInitLatch = new CountDownLatch(1);
    doAnswer(
            invocation -> {
              Object result = invocation.callRealMethod();
              serverConnectionInitLatch.countDown();
              return result;
            })
        .when(serverFactory)
        .wrapConnection(any(TcpConnectionSupport.class));
    ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
    scheduler.setPoolSize(10);
    scheduler.afterPropertiesSet();
    BeanFactory bf = mock(BeanFactory.class);
    when(bf.containsBean(IntegrationContextUtils.TASK_SCHEDULER_BEAN_NAME)).thenReturn(true);
    when(bf.getBean(IntegrationContextUtils.TASK_SCHEDULER_BEAN_NAME, TaskScheduler.class))
        .thenReturn(scheduler);
    serverFactory.setBeanFactory(bf);
    TcpReceivingChannelAdapter adapter = new TcpReceivingChannelAdapter();
    adapter.setOutputChannel(new NullChannel());
    adapter.setConnectionFactory(serverFactory);
    adapter.start();
    assertTrue("Listening event not received", serverListeningLatch.await(10, TimeUnit.SECONDS));
    assertThat(events.get(0), instanceOf(TcpConnectionServerListeningEvent.class));
    assertThat(
        ((TcpConnectionServerListeningEvent) events.get(0)).getPort(),
        equalTo(serverFactory.getPort()));
    int port = serverFactory.getPort();
    TcpNetClientConnectionFactory clientFactory =
        new TcpNetClientConnectionFactory("localhost", port);
    clientFactory.registerListener(message -> false);
    clientFactory.setBeanName("clientFactory");
    clientFactory.setApplicationEventPublisher(publisher);
    clientFactory.start();
    TcpConnectionSupport client = clientFactory.getConnection();
    List<String> clients = clientFactory.getOpenConnectionIds();
    assertEquals(1, clients.size());
    assertTrue(clients.contains(client.getConnectionId()));
    assertTrue(
        "Server connection failed to register",
        serverConnectionInitLatch.await(1, TimeUnit.SECONDS));
    List<String> servers = serverFactory.getOpenConnectionIds();
    assertEquals(1, servers.size());
    assertTrue(serverFactory.closeConnection(servers.get(0)));
    servers = serverFactory.getOpenConnectionIds();
    assertEquals(0, servers.size());
    int n = 0;
    clients = clientFactory.getOpenConnectionIds();
    while (n++ < 100 && clients.size() > 0) {
      Thread.sleep(100);
      clients = clientFactory.getOpenConnectionIds();
    }
    assertEquals(0, clients.size());
    assertTrue(eventLatch.await(10, TimeUnit.SECONDS));
    assertThat(
        "Expected at least " + expectedEvents + " events; got: " + events.size() + " : " + events,
        events.size(),
        greaterThanOrEqualTo(expectedEvents));

    FooEvent event = new FooEvent(client, "foo");
    client.publishEvent(event);
    assertThat(
        "Expected at least " + expectedEvents + " events; got: " + events.size() + " : " + events,
        events.size(),
        greaterThanOrEqualTo(expectedEvents + 1));

    try {
      event = new FooEvent(mock(TcpConnectionSupport.class), "foo");
      client.publishEvent(event);
      fail("Expected exception");
    } catch (IllegalArgumentException e) {
      assertTrue("Can only publish events with this as the source".equals(e.getMessage()));
    }

    SocketAddress address = serverFactory.getServerSocketAddress();
    if (address instanceof InetSocketAddress) {
      InetSocketAddress inetAddress = (InetSocketAddress) address;
      assertEquals(port, inetAddress.getPort());
    }
    serverFactory.stop();
    scheduler.shutdown();
  }
  /**
   * Runs the STOMP-over-WebSocket load test against a running server.
   *
   * <p>After a warm-up HTTP request to {@code /home}, {@code NUMBER_OF_USERS} consumer sessions
   * are connected through a SockJS client (WebSocket transport first, Jetty XHR transport second).
   * A producer session is then connected and the method waits for the message and disconnect
   * latches, logging the elapsed time of each phase. Any {@link Throwable} raised inside the test
   * is printed, the Jetty HTTP client is stopped, and the JVM exits.
   *
   * @param args optional: {@code args[0]} is the server host (default {@code localhost}), {@code
   *     args[1]} is the server port (default {@code 37232})
   * @throws Exception if the warm-up request, starting the Jetty client, or stopping it fails
   */
  public static void main(String[] args) throws Exception {

    // Modify host and port below to match wherever StompWebSocketServer.java is running!!
    // When StompWebSocketServer starts it prints the selected available port.

    String host = "localhost";
    if (args.length > 0) {
      host = args[0];
    }

    int port = 37232;
    if (args.length > 1) {
      port = Integer.parseInt(args[1]);
    }

    String homeUrl = "http://{host}:{port}/home";
    logger.debug("Sending warm-up HTTP request to " + homeUrl);
    HttpStatus status =
        new RestTemplate().getForEntity(homeUrl, Void.class, host, port).getStatusCode();
    Assert.state(status == HttpStatus.OK);

    // One latch per phase; each is created with a count of NUMBER_OF_USERS.
    final CountDownLatch connectLatch = new CountDownLatch(NUMBER_OF_USERS);
    final CountDownLatch subscribeLatch = new CountDownLatch(NUMBER_OF_USERS);
    final CountDownLatch messageLatch = new CountDownLatch(NUMBER_OF_USERS);
    final CountDownLatch disconnectLatch = new CountDownLatch(NUMBER_OF_USERS);

    // Shared slot the session handlers are given for reporting a failure.
    final AtomicReference<Throwable> failure = new AtomicReference<>();

    StandardWebSocketClient webSocketClient = new StandardWebSocketClient();

    HttpClient jettyHttpClient = new HttpClient();
    jettyHttpClient.setMaxConnectionsPerDestination(1000);
    jettyHttpClient.setExecutor(new QueuedThreadPool(1000));
    jettyHttpClient.start();

    List<Transport> transports = new ArrayList<>();
    transports.add(new WebSocketTransport(webSocketClient));
    transports.add(new JettyXhrTransport(jettyHttpClient));

    SockJsClient sockJsClient = new SockJsClient(transports);

    try {
      ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
      taskScheduler.afterPropertiesSet();

      String stompUrl = "ws://{host}:{port}/stomp";
      WebSocketStompClient stompClient = new WebSocketStompClient(sockJsClient);
      stompClient.setMessageConverter(new StringMessageConverter());
      stompClient.setTaskScheduler(taskScheduler);
      // Heartbeats are switched off after the scheduler is set; do not set the scheduler again
      // afterwards.
      stompClient.setDefaultHeartbeat(new long[] {0, 0});

      logger.debug("Connecting and subscribing " + NUMBER_OF_USERS + " users ");
      StopWatch stopWatch = new StopWatch("STOMP Broker Relay WebSocket Load Tests");
      stopWatch.start();

      List<ConsumerStompSessionHandler> consumers = new ArrayList<>();
      for (int i = 0; i < NUMBER_OF_USERS; i++) {
        consumers.add(
            new ConsumerStompSessionHandler(
                BROADCAST_MESSAGE_COUNT,
                connectLatch,
                subscribeLatch,
                messageLatch,
                disconnectLatch,
                failure));
        stompClient.connect(stompUrl, consumers.get(i), host, port);
      }

      if (failure.get() != null) {
        throw new AssertionError("Test failed", failure.get());
      }
      if (!connectLatch.await(5000, TimeUnit.MILLISECONDS)) {
        fail("Not all users connected, remaining: " + connectLatch.getCount());
      }
      if (!subscribeLatch.await(5000, TimeUnit.MILLISECONDS)) {
        fail("Not all users subscribed, remaining: " + subscribeLatch.getCount());
      }

      stopWatch.stop();
      logger.debug("Finished: " + stopWatch.getLastTaskTimeMillis() + " millis");

      logger.debug(
          "Broadcasting "
              + BROADCAST_MESSAGE_COUNT
              + " messages to "
              + NUMBER_OF_USERS
              + " users ");
      stopWatch.start();

      ProducerStompSessionHandler producer =
          new ProducerStompSessionHandler(BROADCAST_MESSAGE_COUNT, failure);
      stompClient.connect(stompUrl, producer, host, port);

      if (failure.get() != null) {
        throw new AssertionError("Test failed", failure.get());
      }
      // First wait: on timeout, log the consumers that are still behind before waiting again.
      if (!messageLatch.await(60 * 1000, TimeUnit.MILLISECONDS)) {
        for (ConsumerStompSessionHandler consumer : consumers) {
          if (consumer.messageCount.get() < consumer.expectedMessageCount) {
            logger.debug(consumer);
          }
        }
      }
      // Second wait: give the stragglers another minute, then fail.
      if (!messageLatch.await(60 * 1000, TimeUnit.MILLISECONDS)) {
        fail("Not all handlers received every message, remaining: " + messageLatch.getCount());
      }

      producer.session.disconnect();
      if (!disconnectLatch.await(5000, TimeUnit.MILLISECONDS)) {
        fail("Not all disconnects completed, remaining: " + disconnectLatch.getCount());
      }

      stopWatch.stop();
      logger.debug("Finished: " + stopWatch.getLastTaskTimeMillis() + " millis");

      System.out.println("\nPress any key to exit...");
      System.in.read();
    } catch (Throwable t) {
      t.printStackTrace();
    } finally {
      jettyHttpClient.stop();
    }

    // NOTE(review): the exit status is 0 even when a Throwable was caught above.
    logger.debug("Exiting");
    System.exit(0);
  }