/**
 * Draws two float vectors from {@code VectorMath.randomVectorF} using the same generator and
 * checks that each has the requested length and that the two draws are not element-wise equal.
 */
@Test
public void testRandomF() {
  RandomGenerator rng = RandomManager.getRandom();
  int dimension = 10;

  float[] first = VectorMath.randomVectorF(dimension, rng);
  float[] second = VectorMath.randomVectorF(dimension, rng);

  assertEquals(dimension, first.length);
  assertEquals(dimension, second.length);
  // Consecutive draws from one generator are expected to differ.
  assertFalse(Arrays.equals(first, second));
}
@Override protected MRPipeline createPipeline() throws IOException { JobStepConfig stepConfig = getConfig(); ClusterSettings settings = ClusterSettings.create(ConfigUtils.getDefaultConfig()); String instanceDir = stepConfig.getInstanceDir(); int generationID = stepConfig.getGenerationID(); int iteration = stepConfig.getIteration(); String prefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID); String outputKey = prefix + String.format("sketch/%d/", iteration); if (!validOutputPath(outputKey)) { return null; } // get normalized vectors String inputKey = prefix + "normalized/"; MRPipeline p = createBasicPipeline(DistanceToClosestFn.class); AvroType<Pair<Integer, RealVector>> inputType = Avros.pairs(Avros.ints(), MLAvros.vector()); PCollection<Pair<Integer, RealVector>> in = p.read(avroInput(inputKey, inputType)); // either create or load the set of currently chosen k-sketch vectors // they are stored in a KSketchIndex object DistanceToClosestFn<RealVector> distanceToClosestFn; UpdateIndexFn updateIndexFn; if (iteration == 1) { // Iteration 1 is the first real iteration; iteration 0 contains initial state KSketchIndex index = createInitialIndex(settings, in); distanceToClosestFn = new DistanceToClosestFn<>(index); updateIndexFn = new UpdateIndexFn(index); } else { // Get the index location from the previous iteration String previousIndexKey = prefix + String.format("sketch/%d/", iteration - 1); distanceToClosestFn = new DistanceToClosestFn<>(previousIndexKey); updateIndexFn = new UpdateIndexFn(previousIndexKey); } // compute distance of each vector in dataset to closest vector in k-sketch PTable<Integer, Pair<RealVector, Double>> weighted = in.parallelDo( "computeDistances", distanceToClosestFn, Avros.tableOf(Avros.ints(), Avros.pairs(MLAvros.vector(), Avros.doubles()))); // run weighted reservoir sampling on the vector to select another group of // settings.getSketchPoints() // to add to the k-sketch PTable<Integer, RealVector> kSketchSample = 
ReservoirSampling.groupedWeightedSample( weighted, settings.getSketchPoints(), RandomManager.getRandom()); // update the KSketchIndex with the newly-chosen vectors kSketchSample .parallelDo("updateIndex", updateIndexFn, Serializables.avro(KSketchIndex.class)) .write(avroOutput(outputKey)); return p; }
/**
 * Repeats {@code doTestRandomVecs} for {@code ITERATIONS} rounds, each against {@code NUM_ITEMS}
 * freshly generated random unit item vectors and one random unit user vector, and averages the
 * three statistics it returns.
 *
 * <p>The averages must satisfy: first statistic &gt; 0.8, second (logged as nDCG) &gt; 0.8, and
 * third &lt; 0.09. The statistic meanings are taken from the log message; presumably they are
 * fractions in [0, 1] -- confirm against {@code doTestRandomVecs}.
 */
@Test
public void testLSH() {
  RandomGenerator rng = RandomManager.getRandom();
  Mean meanTopRecsFraction = new Mean();
  Mean meanNdcg = new Mean();
  Mean meanAllItemsFraction = new Mean();

  for (int round = 0; round < ITERATIONS; round++) {
    // Fresh item vectors for every round, keyed by item index.
    LongObjectMap<float[]> itemVectors = new LongObjectMap<float[]>();
    for (int itemIndex = 0; itemIndex < NUM_ITEMS; itemIndex++) {
      itemVectors.put(itemIndex, RandomUtils.randomUnitVector(NUM_FEATURES, rng));
    }
    float[] userVector = RandomUtils.randomUnitVector(NUM_FEATURES, rng);

    double[] stats = doTestRandomVecs(itemVectors, userVector);
    double topRecsFraction = stats[0];
    double ndcg = stats[1];
    double allItemsFraction = stats[2];

    log.info(
        "Considered {}% of all candidates, {} nDCG, got {}% recommendations correct",
        100 * allItemsFraction,
        ndcg,
        100 * topRecsFraction);

    meanTopRecsFraction.increment(topRecsFraction);
    meanNdcg.increment(ndcg);
    meanAllItemsFraction.increment(allItemsFraction);
  }

  log.info("{}", meanTopRecsFraction.getResult());
  log.info("{}", meanNdcg.getResult());
  log.info("{}", meanAllItemsFraction.getResult());

  assertTrue(meanTopRecsFraction.getResult() > 0.8);
  assertTrue(meanNdcg.getResult() > 0.8);
  assertTrue(meanAllItemsFraction.getResult() < 0.09);
}
/**
 * Compares the top-N results of two {@link ALSServingModel}s that are loaded with identical user
 * vectors, item vectors and known items, and differ only in their third constructor argument
 * (1.0 vs. 0.5 -- presumably the LSH sample rate; confirm against {@code ALSServingModel}).
 *
 * <p>For every user (scored with {@code DotsFunction}) and then every item (scored with
 * {@code CosineAverageFunction}), the length of the common leading run of the two result lists is
 * recorded. The mean run length must be at least 4.0 for users and at least 5.0 for items.
 */
@Test
public void testLSHEffect() {
  RandomGenerator rng = RandomManager.getRandom();
  PoissonDistribution knownItemCountDist =
      new PoissonDistribution(
          rng,
          20,
          PoissonDistribution.DEFAULT_EPSILON,
          PoissonDistribution.DEFAULT_MAX_ITERATIONS);
  int features = 20;
  ALSServingModel exactModel = new ALSServingModel(features, true, 1.0, null);
  ALSServingModel sampledModel = new ALSServingModel(features, true, 0.5, null);

  // The same count is used for the number of users and the number of items.
  int userItemCount = 20000;

  // Populate both models with identical user vectors and known-item sets.
  for (int u = 0; u < userItemCount; u++) {
    String userID = "U" + u;
    float[] userVector = VectorMath.randomVectorF(features, rng);
    exactModel.setUserVector(userID, userVector);
    sampledModel.setUserVector(userID, userVector);

    int knownCount = knownItemCountDist.sample();
    Collection<String> knownIDs = new ArrayList<>(knownCount);
    for (int k = 0; k < knownCount; k++) {
      knownIDs.add("I" + rng.nextInt(userItemCount));
    }
    exactModel.addKnownItems(userID, knownIDs);
    sampledModel.addKnownItems(userID, knownIDs);
  }

  // Populate both models with identical item vectors.
  for (int it = 0; it < userItemCount; it++) {
    String itemID = "I" + it;
    float[] itemVector = VectorMath.randomVectorF(features, rng);
    exactModel.setItemVector(itemID, itemVector);
    sampledModel.setItemVector(itemID, itemVector);
  }

  int numRecs = 10;
  Mean meanMatchLength = new Mean();

  // Per-user recommendations: count how many leading results agree between the two models.
  for (int u = 0; u < userItemCount; u++) {
    String userID = "U" + u;
    List<Pair<String, Double>> exactRecs =
        exactModel.topN(
            new DotsFunction(exactModel.getUserVector(userID)), null, numRecs, null);
    List<Pair<String, Double>> sampledRecs =
        sampledModel.topN(
            new DotsFunction(sampledModel.getUserVector(userID)), null, numRecs, null);

    int comparable = Math.min(sampledRecs.size(), exactRecs.size());
    int matched = 0;
    while (matched < comparable && sampledRecs.get(matched).equals(exactRecs.get(matched))) {
      matched++;
    }
    meanMatchLength.increment(matched);
  }
  log.info("Mean matching prefix: {}", meanMatchLength.getResult());
  assertTrue(meanMatchLength.getResult() >= 4.0);

  // Reuse the accumulator for the per-item similarity comparison.
  meanMatchLength.clear();
  for (int it = 0; it < userItemCount; it++) {
    String itemID = "I" + it;
    List<Pair<String, Double>> exactRecs =
        exactModel.topN(
            new CosineAverageFunction(exactModel.getItemVector(itemID)), null, numRecs, null);
    List<Pair<String, Double>> sampledRecs =
        sampledModel.topN(
            new CosineAverageFunction(sampledModel.getItemVector(itemID)), null, numRecs, null);

    int comparable = Math.min(sampledRecs.size(), exactRecs.size());
    int matched = 0;
    while (matched < comparable && sampledRecs.get(matched).equals(exactRecs.get(matched))) {
      matched++;
    }
    meanMatchLength.increment(matched);
  }
  log.info("Mean matching prefix: {}", meanMatchLength.getResult());
  assertTrue(meanMatchLength.getResult() >= 5.0);
}