/**
 * Loads the model for {@code mostRecentModelGeneration} if it is newer than the generation
 * currently held. Downloads {@code model.pmml.gz} from the store to a local temp file, parses
 * it into a {@link DecisionForest} plus its category catalog, and swaps in the new
 * {@code currentModel}. No-op when the given generation is not newer than what is loaded.
 *
 * @param mostRecentModelGeneration newest generation ID available in the store
 * @throws IOException if the download or PMML parse fails
 */
@Override
protected void loadRecentModel(long mostRecentModelGeneration) throws IOException {
  // Already up to date (or asked to load something older) — nothing to do
  if (mostRecentModelGeneration <= modelGeneration) {
    return;
  }
  if (modelGeneration == NO_GENERATION) {
    log.info("Most recent generation {} is the first available one", mostRecentModelGeneration);
  } else {
    log.info("Most recent generation {} is newer than current {}",
             mostRecentModelGeneration, modelGeneration);
  }

  // Create-then-delete so Store.download sees a fresh path; deleteOnExit is only a backstop.
  File modelPMMLFile = File.createTempFile("model-", ".pmml.gz");
  modelPMMLFile.deleteOnExit();
  IOUtils.delete(modelPMMLFile);

  Config config = ConfigUtils.getDefaultConfig();
  String instanceDir = config.getString("model.instance-dir");
  String generationPrefix =
      Namespaces.getInstanceGenerationPrefix(instanceDir, mostRecentModelGeneration);
  String modelPMMLKey = generationPrefix + "model.pmml.gz";

  Pair<DecisionForest, Map<Integer, BiMap<String, Integer>>> forestAndCatalog;
  try {
    Store.get().download(modelPMMLKey, modelPMMLFile);
    log.info("Loading model description from {}", modelPMMLKey);
    forestAndCatalog = DecisionForestPMML.read(modelPMMLFile);
  } finally {
    // Fix: delete the temp file even when download/parse throws, instead of
    // leaking it until JVM exit via deleteOnExit
    IOUtils.delete(modelPMMLFile);
  }
  log.info("Loaded model description");

  // Publish the new generation only after a fully successful load
  modelGeneration = mostRecentModelGeneration;
  currentModel = new Generation(forestAndCatalog.getFirst(), forestAndCatalog.getSecond());
}
/**
 * Builds one iteration of the k-sketch sampling MapReduce pipeline.
 *
 * <p>Reads normalized vectors for this generation, computes each vector's distance to the
 * closest vector already in the k-sketch, draws a weighted reservoir sample of new sketch
 * points, and writes the updated {@code KSketchIndex} under {@code sketch/<iteration>/}.</p>
 *
 * @return the configured pipeline, or {@code null} if the output path for this iteration
 *         already exists (step already ran)
 * @throws IOException on pipeline setup / storage access failure
 */
@Override protected MRPipeline createPipeline() throws IOException {
  JobStepConfig stepConfig = getConfig();
  ClusterSettings settings = ClusterSettings.create(ConfigUtils.getDefaultConfig());
  String instanceDir = stepConfig.getInstanceDir();
  int generationID = stepConfig.getGenerationID();
  int iteration = stepConfig.getIteration();
  String prefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
  String outputKey = prefix + String.format("sketch/%d/", iteration);
  // Skip the whole step if this iteration's output already exists
  if (!validOutputPath(outputKey)) {
    return null;
  }
  // get normalized vectors
  String inputKey = prefix + "normalized/";
  MRPipeline p = createBasicPipeline(DistanceToClosestFn.class);
  AvroType<Pair<Integer, RealVector>> inputType = Avros.pairs(Avros.ints(), MLAvros.vector());
  PCollection<Pair<Integer, RealVector>> in = p.read(avroInput(inputKey, inputType));
  // either create or load the set of currently chosen k-sketch vectors;
  // they are stored in a KSketchIndex object
  DistanceToClosestFn<RealVector> distanceToClosestFn;
  UpdateIndexFn updateIndexFn;
  if (iteration == 1) {
    // Iteration 1 is the first real iteration; iteration 0 contains initial state
    KSketchIndex index = createInitialIndex(settings, in);
    distanceToClosestFn = new DistanceToClosestFn<>(index);
    updateIndexFn = new UpdateIndexFn(index);
  } else {
    // Get the index location from the previous iteration
    String previousIndexKey = prefix + String.format("sketch/%d/", iteration - 1);
    distanceToClosestFn = new DistanceToClosestFn<>(previousIndexKey);
    updateIndexFn = new UpdateIndexFn(previousIndexKey);
  }
  // compute distance of each vector in dataset to closest vector in k-sketch;
  // the distance becomes the sampling weight below
  PTable<Integer, Pair<RealVector, Double>> weighted =
      in.parallelDo(
          "computeDistances",
          distanceToClosestFn,
          Avros.tableOf(Avros.ints(), Avros.pairs(MLAvros.vector(), Avros.doubles())));
  // run weighted reservoir sampling on the vectors to select another group of
  // settings.getSketchPoints() to add to the k-sketch
  PTable<Integer, RealVector> kSketchSample =
      ReservoirSampling.groupedWeightedSample(
          weighted, settings.getSketchPoints(), RandomManager.getRandom());
  // update the KSketchIndex with the newly-chosen vectors and persist it
  kSketchSample
      .parallelDo("updateIndex", updateIndexFn, Serializables.avro(KSketchIndex.class))
      .write(avroOutput(outputKey));
  return p;
}
/**
 * Joins the per-tree PMML model files written under {@code trees/} into a single forest.
 * The first tree's PMML document becomes the base model; each subsequent tree's
 * {@code MiningModel} segments are appended to the base model's segmentation. The merged
 * document is gzipped and uploaded as {@code model.pmml.gz} for this generation.
 *
 * @throws IOException if download/upload fails, a PMML file cannot be parsed,
 *         or no tree models are found to join
 */
@Override
protected void doPost() throws IOException {
  String instanceGenerationPrefix =
      Namespaces.getInstanceGenerationPrefix(getInstanceDir(), getGenerationID());
  String outputPathKey = instanceGenerationPrefix + "trees/";
  Store store = Store.get();

  PMML joinedForest = null;
  for (String treePrefix : store.list(outputPathKey, true)) {
    File treeTempFile = File.createTempFile("model-", ".pmml.gz");
    treeTempFile.deleteOnExit();
    store.download(treePrefix, treeTempFile);
    PMML pmml;
    // try-with-resources + multi-catch replace the previous manual close and
    // duplicated catch blocks
    try (InputStream in = IOUtils.openMaybeDecompressing(treeTempFile)) {
      pmml = IOUtil.unmarshal(in);
    } catch (SAXException | JAXBException e) {
      throw new IOException(e);
    }
    IOUtils.delete(treeTempFile);
    if (joinedForest == null) {
      // First tree: its document is the base of the joined forest
      joinedForest = pmml;
    } else {
      // Append this tree's segments to the base model's segmentation
      MiningModel existingModel = (MiningModel) joinedForest.getModels().get(0);
      MiningModel nextModel = (MiningModel) pmml.getModels().get(0);
      existingModel
          .getSegmentation()
          .getSegments()
          .addAll(nextModel.getSegmentation().getSegments());
    }
  }

  // Fix: previously a missing/empty trees/ prefix caused marshal(null, ...) to NPE;
  // fail with a clear message instead
  if (joinedForest == null) {
    throw new IOException("No tree models found at " + outputPathKey);
  }

  File tempJoinedForestFile = File.createTempFile("model-", ".pmml.gz");
  tempJoinedForestFile.deleteOnExit();
  try (OutputStream out =
      IOUtils.buildGZIPOutputStream(new FileOutputStream(tempJoinedForestFile))) {
    IOUtil.marshal(joinedForest, out);
  } catch (JAXBException e) {
    throw new IOException(e);
  }
  store.upload(instanceGenerationPrefix + "model.pmml.gz", tempJoinedForestFile, false);
  IOUtils.delete(tempJoinedForestFile);
}
/**
 * Runs one model-build generation locally: pulls this generation's inbound data (plus the
 * previous generation's accumulated input, when one exists) into a temp directory, reads it
 * into examples, trains a {@link DecisionForest}, writes it as {@code model.pmml.gz}, and
 * uploads both the combined input and the model back to the store.
 *
 * @throws IOException on any storage or file I/O failure
 */
@Override
protected void runSteps() throws IOException {
  String instanceDir = getInstanceDir();
  int generationID = getGenerationID();
  int lastGenerationID = getLastGenerationID();
  String generationPrefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);

  // Local scratch space for input and output; cleaned up in the finally block,
  // with deleteOnExit as a backstop
  File inputDir = Files.createTempDir();
  inputDir.deleteOnExit();
  File outDir = Files.createTempDir();
  outDir.deleteOnExit();

  try {
    Store store = Store.get();
    // New data for this generation...
    store.downloadDirectory(generationPrefix + "inbound/", inputDir);
    // ...plus everything accumulated by the previous generation, if any
    if (lastGenerationID >= 0) {
      String lastInputPrefix =
          Namespaces.getInstanceGenerationPrefix(instanceDir, lastGenerationID) + "input/";
      store.downloadDirectory(lastInputPrefix, inputDir);
    }

    // Parse all downloaded files into examples and the categorical-feature mapping
    List<Example> examples = new ArrayList<>();
    Map<Integer, BiMap<String, Integer>> categoryMapping = new HashMap<>();
    new ReadInputs(inputDir, examples, categoryMapping).call();

    // Train and serialize the forest
    DecisionForest forest = DecisionForest.fromExamplesWithDefault(examples);
    DecisionForestPMML.write(new File(outDir, "model.pmml.gz"), forest, categoryMapping);

    // Persist the combined input for the next generation, then the model itself
    store.uploadDirectory(generationPrefix + "input/", inputDir, false);
    store.uploadDirectory(generationPrefix, outDir, false);
  } finally {
    IOUtils.deleteRecursively(inputDir);
    IOUtils.deleteRecursively(outDir);
  }
}
/**
 * Builds the pipeline that weights the final k-sketch vectors.
 *
 * <p>For every normalized input vector, finds its closest sketch vector (the sketch from the
 * last sketch iteration), aggregates those assignments into per-sketch-vector weights
 * (a Voronoi partition count), persists the weights, and then emits the sketch vectors paired
 * with their weights under {@code weighted/}.</p>
 *
 * @return the configured pipeline, or {@code null} if the output path already exists
 * @throws IOException on pipeline setup / storage access failure
 */
@Override protected MRPipeline createPipeline() throws IOException {
  JobStepConfig stepConfig = getConfig();
  Config config = ConfigUtils.getDefaultConfig();
  ClusterSettings clusterSettings = ClusterSettings.create(config);
  String instanceDir = stepConfig.getInstanceDir();
  long generationID = stepConfig.getGenerationID();
  String prefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
  String outputKey = prefix + "weighted/";
  // Skip the whole step if output already exists
  if (!validOutputPath(outputKey)) {
    return null;
  }
  // The final sketch index is the one produced by the last sketch iteration
  String indexKey = prefix + "sketch/" + clusterSettings.getSketchIterations();
  String inputKey = prefix + "normalized/";
  MRPipeline p = createBasicPipeline(ClosestSketchVectorFn.class);
  // First compute the weight of each k-sketch vector (i.e., its Voronoi partition size),
  // aggregate all partial results together into a single group, and persist on disk
  PCollection<ClosestSketchVectorData> weights =
      PTables.asPTable(
          inputPairs(p, inputKey, MLAvros.vector())
              .parallelDo(
                  "computingSketchVectorWeights",
                  new ClosestSketchVectorFn<RealVector>(indexKey, clusterSettings),
                  Avros.pairs(Avros.ints(), Avros.reflects(ClosestSketchVectorData.class))))
          // single reducer: all partial weight data must be combined into one record
          .groupByKey(1)
          .combineValues(new ClosestSketchVectorAggregator(clusterSettings))
          .values()
          .write(avroOutput(outputKey + "kSketchVectorWeights/"));
  // This second stage takes a single ClosestSketchVectorData and returns weighted vectors.
  // It could be done outside MapReduce, but that would require materializing the
  // ClosestSketchVectorData locally
  weights
      .parallelDo(
          "generatingWeightedSketchVectors",
          new WeightVectorsFn(indexKey),
          KMeansTypes.FOLD_WEIGHTED_VECTOR)
      .write(avroOutput(outputKey + "weightedKSketchVectors/"));
  return p;
}
/**
 * Loads an ALS model generation from the store into {@code currentGeneration}.
 *
 * <p>Downloads and parses the generation's {@code model.pmml.gz} description, then loads the
 * X (user) and Y (item) factor matrices, known-item IDs, and ID mappings concurrently on a
 * small executor. After loading, entries that were not refreshed by this load — and are not
 * recently active — are pruned, and the generation's derived state is recomputed.</p>
 *
 * @param generationID generation to load
 * @param currentGeneration generation object to populate/update in place
 * @throws IOException if download or parsing fails
 */
void loadModel(int generationID, Generation currentGeneration) throws IOException {
  // Create-then-delete so Store.download sees a fresh path; deleteOnExit is a backstop
  File modelPMMLFile = File.createTempFile("oryx-model", ".pmml.gz");
  modelPMMLFile.deleteOnExit();
  IOUtils.delete(modelPMMLFile);
  String generationPrefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
  String modelPMMLKey = generationPrefix + "model.pmml.gz";
  Store.get().download(modelPMMLKey, modelPMMLFile);
  log.info("Loading model description from {}", modelPMMLKey);
  ALSModelDescription modelDescription = ALSModelDescription.read(modelPMMLFile);
  IOUtils.delete(modelPMMLFile);

  Collection<Future<Object>> futures = new ArrayList<>();
  // Limit this fairly sharply to 2 so as to not saturate the network link
  ExecutorService executor = ExecutorUtils.buildExecutor("LoadModel", 2);
  // IDs actually (re)loaded by this pass; used below to prune stale entries
  LongSet loadedUserIDs;
  LongSet loadedItemIDs;
  LongSet loadedUserIDsForKnownItems;
  try {
    // Kick off concurrent loads of the X (user) and Y (item) factor matrices
    loadedUserIDs =
        loadXOrY(generationPrefix, modelDescription, true, currentGeneration, futures, executor);
    loadedItemIDs =
        loadXOrY(generationPrefix, modelDescription, false, currentGeneration, futures, executor);
    if (currentGeneration.getKnownItemIDs() == null) {
      // Known items are not tracked for this generation
      loadedUserIDsForKnownItems = null;
    } else {
      loadedUserIDsForKnownItems =
          loadKnownItemIDs(generationPrefix, modelDescription, currentGeneration, futures, executor);
    }
    loadIDMapping(generationPrefix, modelDescription, currentGeneration, futures, executor);
    // Block until every load task completes (propagates task failures)
    ExecutorUtils.getResults(futures);
    log.info("Finished all load tasks");
  } finally {
    ExecutorUtils.shutdownNowAndAwait(executor);
  }

  log.info("Pruning old entries...");
  // Hold lockForRecent so the recently-active sets are stable while pruning;
  // per-structure write locks are taken inside removeNotUpdated
  synchronized (lockForRecent) {
    removeNotUpdated(
        currentGeneration.getX().keySetIterator(),
        loadedUserIDs,
        recentlyActiveUsers,
        currentGeneration.getXLock().writeLock());
    removeNotUpdated(
        currentGeneration.getY().keySetIterator(),
        loadedItemIDs,
        recentlyActiveItems,
        currentGeneration.getYLock().writeLock());
    if (loadedUserIDsForKnownItems != null && currentGeneration.getKnownItemIDs() != null) {
      removeNotUpdated(
          currentGeneration.getKnownItemIDs().keySetIterator(),
          loadedUserIDsForKnownItems,
          recentlyActiveUsers,
          currentGeneration.getKnownItemLock().writeLock());
    }
    // Reset activity tracking for the next load cycle
    this.recentlyActiveItems.clear();
    this.recentlyActiveUsers.clear();
  }

  log.info("Recomputing generation state...");
  currentGeneration.recomputeState();
  log.info(
      "All model elements loaded, {} users and {} items",
      currentGeneration.getNumUsers(),
      currentGeneration.getNumItems());
}