示例#1
0
 /**
  * Draws two vectors from {@code VectorMath.randomVectorF} using the same generator and checks
  * that each has the requested length and that the two are not element-for-element equal.
  */
 @Test
 public void testRandomF() {
   int dimension = 10;
   RandomGenerator generator = RandomManager.getRandom();

   // Both draws share one generator, so the second continues the first's random sequence.
   float[] first = VectorMath.randomVectorF(dimension, generator);
   float[] second = VectorMath.randomVectorF(dimension, generator);

   assertEquals(dimension, first.length);
   assertEquals(dimension, second.length);
   assertFalse(Arrays.equals(first, second));
 }
示例#2
0
  /**
   * Builds the pipeline for one k-sketch iteration: reads the normalized vectors, computes each
   * vector's distance to the closest vector in the current sketch index, draws a weighted sample
   * of {@code settings.getSketchPoints()} vectors, and writes the updated index under this
   * iteration's output key.
   *
   * @return the configured pipeline, or {@code null} when {@code validOutputPath} rejects the
   *     output key for this iteration
   * @throws IOException propagated from the helpers used while building the pipeline
   */
  @Override
  protected MRPipeline createPipeline() throws IOException {
    JobStepConfig stepConfig = getConfig();
    ClusterSettings settings = ClusterSettings.create(ConfigUtils.getDefaultConfig());

    String instanceDir = stepConfig.getInstanceDir();
    int generationID = stepConfig.getGenerationID();
    int iteration = stepConfig.getIteration();
    String prefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
    // Keys are built by concatenation rather than String.format("%d"): the formatter localizes
    // digits using the JVM's default locale, which could yield a different key on a machine
    // whose locale uses non-ASCII digits. Concatenation always produces ASCII digits.
    String outputKey = prefix + "sketch/" + iteration + "/";
    if (!validOutputPath(outputKey)) {
      return null;
    }

    // Read the normalized vectors as (int, vector) pairs.
    String inputKey = prefix + "normalized/";
    MRPipeline p = createBasicPipeline(DistanceToClosestFn.class);
    AvroType<Pair<Integer, RealVector>> inputType = Avros.pairs(Avros.ints(), MLAvros.vector());
    PCollection<Pair<Integer, RealVector>> in = p.read(avroInput(inputKey, inputType));

    // Either create or load the set of currently chosen k-sketch vectors;
    // they are stored in a KSketchIndex object.
    DistanceToClosestFn<RealVector> distanceToClosestFn;
    UpdateIndexFn updateIndexFn;
    if (iteration == 1) {
      // Iteration 1 is the first real iteration; iteration 0 contains initial state.
      KSketchIndex index = createInitialIndex(settings, in);
      distanceToClosestFn = new DistanceToClosestFn<>(index);
      updateIndexFn = new UpdateIndexFn(index);
    } else {
      // Point both functions at the index location written by the previous iteration.
      String previousIndexKey = prefix + "sketch/" + (iteration - 1) + "/";
      distanceToClosestFn = new DistanceToClosestFn<>(previousIndexKey);
      updateIndexFn = new UpdateIndexFn(previousIndexKey);
    }

    // Compute the distance of each vector in the dataset to the closest vector in the k-sketch.
    PTable<Integer, Pair<RealVector, Double>> weighted =
        in.parallelDo(
            "computeDistances",
            distanceToClosestFn,
            Avros.tableOf(Avros.ints(), Avros.pairs(MLAvros.vector(), Avros.doubles())));

    // Run weighted reservoir sampling over the vectors to select another group of
    // settings.getSketchPoints() vectors to add to the k-sketch.
    PTable<Integer, RealVector> kSketchSample =
        ReservoirSampling.groupedWeightedSample(
            weighted, settings.getSketchPoints(), RandomManager.getRandom());

    // Update the KSketchIndex with the newly chosen vectors and write it to this iteration's key.
    kSketchSample
        .parallelDo("updateIndex", updateIndexFn, Serializables.avro(KSketchIndex.class))
        .write(avroOutput(outputKey));

    return p;
  }
  /**
   * Repeats {@code ITERATIONS} rounds in which {@code NUM_ITEMS} random unit item vectors and one
   * random unit user vector are passed to {@code doTestRandomVecs}. The three values it returns
   * are averaged across rounds, logged, and checked against fixed thresholds.
   */
  @Test
  public void testLSH() {
    RandomGenerator random = RandomManager.getRandom();

    Mean meanTopRecsConsidered = new Mean();
    Mean meanNDCG = new Mean();
    Mean meanAllItemsConsidered = new Mean();

    int round = 0;
    while (round < ITERATIONS) {
      // Item vectors are drawn before the user vector; the draw order is kept so the shared
      // generator is consumed in the same sequence every round.
      LongObjectMap<float[]> itemVectors = new LongObjectMap<float[]>();
      for (int itemID = 0; itemID < NUM_ITEMS; itemID++) {
        float[] itemVector = RandomUtils.randomUnitVector(NUM_FEATURES, random);
        itemVectors.put(itemID, itemVector);
      }
      float[] userVector = RandomUtils.randomUnitVector(NUM_FEATURES, random);

      // Result layout as consumed below: [0] top recs considered, [1] nDCG, [2] all items considered.
      double[] stats = doTestRandomVecs(itemVectors, userVector);
      double topRecsFraction = stats[0];
      double roundNDCG = stats[1];
      double allItemsFraction = stats[2];

      log.info(
          "Considered {}% of all candidates, {} nDCG, got {}% recommendations correct",
          100 * allItemsFraction, roundNDCG, 100 * topRecsFraction);

      meanTopRecsConsidered.increment(topRecsFraction);
      meanNDCG.increment(roundNDCG);
      meanAllItemsConsidered.increment(allItemsFraction);

      round++;
    }

    double overallTopRecs = meanTopRecsConsidered.getResult();
    double overallNDCG = meanNDCG.getResult();
    double overallAllItems = meanAllItemsConsidered.getResult();

    log.info("{}", overallTopRecs);
    log.info("{}", overallNDCG);
    log.info("{}", overallAllItems);

    assertTrue(overallTopRecs > 0.8);
    assertTrue(overallNDCG > 0.8);
    assertTrue(overallAllItems < 0.09);
  }
示例#4
0
  /**
   * Compares the top-N results of two {@code ALSServingModel} instances that are loaded with
   * identical user vectors, item vectors and known items, and differ only in their third
   * constructor argument (1.0 vs 0.5 — presumably the LSH sample rate; confirm against
   * {@code ALSServingModel}). For every user (scored with {@code DotsFunction}) and every item
   * (scored with {@code CosineAverageFunction}) the length of the leading run of identical
   * results is averaged and checked against a minimum.
   */
  @Test
  public void testLSHEffect() {
    RandomGenerator random = RandomManager.getRandom();
    PoissonDistribution itemPerUserDist =
        new PoissonDistribution(
            random,
            20,
            PoissonDistribution.DEFAULT_EPSILON,
            PoissonDistribution.DEFAULT_MAX_ITERATIONS);
    int features = 20;
    ALSServingModel mainModel = new ALSServingModel(features, true, 1.0, null);
    ALSServingModel lshModel = new ALSServingModel(features, true, 0.5, null);

    // The same count is used for both the number of users and the number of items.
    int userItemCount = 20000;
    for (int user = 0; user < userItemCount; user++) {
      String userID = "U" + user;
      float[] vec = VectorMath.randomVectorF(features, random);
      mainModel.setUserVector(userID, vec);
      lshModel.setUserVector(userID, vec);
      // Number of known items per user is sampled from the Poisson distribution above.
      int itemsPerUser = itemPerUserDist.sample();
      Collection<String> knownIDs = new ArrayList<>(itemsPerUser);
      for (int i = 0; i < itemsPerUser; i++) {
        knownIDs.add("I" + random.nextInt(userItemCount));
      }
      mainModel.addKnownItems(userID, knownIDs);
      lshModel.addKnownItems(userID, knownIDs);
    }

    for (int item = 0; item < userItemCount; item++) {
      String itemID = "I" + item;
      float[] vec = VectorMath.randomVectorF(features, random);
      mainModel.setItemVector(itemID, vec);
      lshModel.setItemVector(itemID, vec);
    }

    int numRecs = 10;
    Mean meanMatchLength = new Mean();

    // Pass 1: per-user results ranked by DotsFunction.
    for (int user = 0; user < userItemCount; user++) {
      String userID = "U" + user;
      List<Pair<String, Double>> mainRecs =
          mainModel.topN(new DotsFunction(mainModel.getUserVector(userID)), null, numRecs, null);
      List<Pair<String, Double>> lshRecs =
          lshModel.topN(new DotsFunction(lshModel.getUserVector(userID)), null, numRecs, null);
      meanMatchLength.increment(matchingPrefixLength(lshRecs, mainRecs));
    }
    log.info("Mean matching prefix: {}", meanMatchLength.getResult());
    assertTrue(meanMatchLength.getResult() >= 4.0);

    // Pass 2: per-item results ranked by CosineAverageFunction; the accumulator is reset first.
    meanMatchLength.clear();
    for (int item = 0; item < userItemCount; item++) {
      String itemID = "I" + item;
      List<Pair<String, Double>> mainRecs =
          mainModel.topN(
              new CosineAverageFunction(mainModel.getItemVector(itemID)), null, numRecs, null);
      List<Pair<String, Double>> lshRecs =
          lshModel.topN(
              new CosineAverageFunction(lshModel.getItemVector(itemID)), null, numRecs, null);
      meanMatchLength.increment(matchingPrefixLength(lshRecs, mainRecs));
    }
    log.info("Mean matching prefix: {}", meanMatchLength.getResult());
    assertTrue(meanMatchLength.getResult() >= 5.0);
  }

  /**
   * Counts how many leading positions of the two lists hold equal elements, stopping at the first
   * mismatch or at the end of the shorter list.
   *
   * @param first list whose elements are compared via {@code equals} against {@code second}
   * @param second list compared position by position with {@code first}
   * @return the length of the common leading run; 0 if either list is empty
   */
  private static <T> int matchingPrefixLength(List<T> first, List<T> second) {
    int limit = Math.min(first.size(), second.size());
    int matched = 0;
    while (matched < limit && first.get(matched).equals(second.get(matched))) {
      matched++;
    }
    return matched;
  }