Example #1
0
  /**
   * Runs the ALS speed layer end to end against mock input and a mock model, then checks every
   * message it publishes.
   *
   * <p>The published stream is expected to contain, in order:
   *
   * <ol>
   *   <li>one "MODEL" message whose PMML declares 2 features;
   *   <li>nine "UP" messages whose vectors and known-ID sets match those in
   *       {@code MockALSModelUpdateGenerator} (X/A for "X" updates, Y/At otherwise);
   *   <li>nine "UP" messages for brand-new users/items, each with a single known ID.
   * </ol>
   */
  @Test
  public void testALSSpeed() throws Exception {
    Map<String, Object> overlay = new HashMap<>();
    overlay.put("oryx.speed.model-manager-class", ALSSpeedModelManager.class.getName());
    overlay.put("oryx.speed.streaming.generation-interval-sec", 5);
    overlay.put("oryx.als.hyperparams.features", 2);
    Config speedConfig = ConfigUtils.overlayOn(overlay, getConfig());

    startMessaging();

    List<Pair<String, String>> produced = startServerProduceConsumeTopics(
        speedConfig,
        new MockALSInputGenerator(),
        new MockALSModelUpdateGenerator(),
        9,
        10);

    if (log.isDebugEnabled()) {
      for (Pair<String, String> message : produced) {
        log.debug("{}", message);
      }
    }

    // One model plus nine replayed updates make ten; each of the nine inputs then adds exactly
    // one update, because its user or item has not been seen before.
    assertEquals(19, produced.size());

    Pair<String, String> modelMessage = produced.get(0);
    assertEquals("MODEL", modelMessage.getFirst());
    String featureCount = AppPMMLUtils.getExtensionValue(
        PMMLUtils.fromString(modelMessage.getSecond()), "features");
    assertEquals(2, Integer.parseInt(featureCount));

    // Messages 1..9 must reproduce the mock model's vectors and known-ID sets exactly.
    for (Pair<String, String> message : produced.subList(1, 10)) {
      assertEquals("UP", message.getFirst());
      List<?> fields = MAPPER.readValue(message.getSecond(), List.class);
      boolean isXUpdate = "X".equals(fields.get(0).toString());
      String id = fields.get(1).toString();

      float[] expectedVector;
      Collection<String> expectedKnown;
      if (isXUpdate) {
        expectedVector = MockALSModelUpdateGenerator.X.get(id);
        expectedKnown = MockALSModelUpdateGenerator.A.get(id);
      } else {
        expectedVector = MockALSModelUpdateGenerator.Y.get(id);
        expectedKnown = MockALSModelUpdateGenerator.At.get(id);
      }

      assertArrayEquals(expectedVector, MAPPER.convertValue(fields.get(2), float[].class));
      @SuppressWarnings("unchecked")
      Collection<String> actualKnown = (Collection<String>) fields.get(3);
      // Mutual containment: same distinct elements, regardless of order.
      assertTrue(actualKnown.containsAll(expectedKnown));
      assertTrue(expectedKnown.containsAll(actualKnown));
    }

    /*
     * Expected vectors for the brand-new entries. Users 100-104 would nominally be
     * eye(5)*Y*pinv(Y'*Y), but because they are new the default scaling shrinks them to 3/4,
     * i.e. (0.75*eye(5))*Y*pinv(Y'*Y). Items 105-108 are likewise (0.75*eye(4))*X*pinv(X'*X).
     */
    Map<String, float[]> expectedNewX = MockALSModelUpdateGenerator.buildMatrix(
        100,
        new float[][] {
          {-0.20859924f, 0.25232133f},
          {-0.22472803f, -0.1929485f},
          {-0.15592135f, 0.3977631f},
          {-0.3006522f, -0.12239703f},
          {-0.09205295f, -0.37471837f},
        });
    Map<String, float[]> expectedNewY = MockALSModelUpdateGenerator.buildMatrix(
        105,
        new float[][] {
          {-0.19663288f, 0.09574106f},
          {-0.23840417f, -0.50850725f},
          {-0.34360975f, 0.2466687f},
          {-0.060204573f, 0.29311115f},
        });

    // Messages 10..18 are the fold-in updates for the new users/items.
    for (Pair<String, String> message : produced.subList(10, 19)) {
      assertEquals("UP", message.getFirst());
      List<?> fields = MAPPER.readValue(message.getSecond(), List.class);
      boolean isXUpdate = "X".equals(fields.get(0).toString());
      String id = fields.get(1).toString();

      float[] expectedVector;
      if (isXUpdate) {
        expectedVector = expectedNewX.get(id);
      } else {
        expectedVector = expectedNewY.get(id);
      }
      assertArrayEquals(
          expectedVector, MAPPER.convertValue(fields.get(2), float[].class), 1.0e-5f);

      // The single known ID is this entry's numeric ID minus 99, mapped back through
      // ALSUtilsTest's ID <-> string conversion.
      String counterpartID = ALSUtilsTest.idToStringID(ALSUtilsTest.stringIDtoID(id) - 99);
      @SuppressWarnings("unchecked")
      Collection<String> actualKnown = (Collection<String>) fields.get(3);
      assertEquals(1, actualKnown.size());
      assertEquals(counterpartID, actualKnown.iterator().next());
    }
  }