Example #1
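Checks whether a newer model generation is available and, if so, downloads that generation's PMML file from the store, parses it into a DecisionForest plus its per-column category mappings, and swaps it in as the current in-memory model.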
  @Override
  protected void loadRecentModel(long mostRecentModelGeneration) throws IOException {
    if (mostRecentModelGeneration <= modelGeneration) {
      return;
    }
    if (modelGeneration == NO_GENERATION) {
      log.info("Most recent generation {} is the first available one", mostRecentModelGeneration);
    } else {
      log.info(
          "Most recent generation {} is newer than current {}",
          mostRecentModelGeneration,
          modelGeneration);
    }

    // Reserve a local temp file name, then delete it so the download below writes a fresh copy
    File modelPMMLFile = File.createTempFile("model-", ".pmml.gz");
    modelPMMLFile.deleteOnExit();
    IOUtils.delete(modelPMMLFile);

    Config config = ConfigUtils.getDefaultConfig();
    String instanceDir = config.getString("model.instance-dir");

    String generationPrefix =
        Namespaces.getInstanceGenerationPrefix(instanceDir, mostRecentModelGeneration);
    String modelPMMLKey = generationPrefix + "model.pmml.gz";
    Store.get().download(modelPMMLKey, modelPMMLFile);
    log.info("Loading model description from {}", modelPMMLKey);

    Pair<DecisionForest, Map<Integer, BiMap<String, Integer>>> forestAndCatalog =
        DecisionForestPMML.read(modelPMMLFile);
    IOUtils.delete(modelPMMLFile);
    log.info("Loaded model description");

    modelGeneration = mostRecentModelGeneration;
    currentModel = new Generation(forestAndCatalog.getFirst(), forestAndCatalog.getSecond());
  }
Example #2
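A Crunch MRPipeline step for one iteration of k-sketch sampling: it reads the normalized vectors, computes each vector's distance to the closest vector already in the sketch (creating an initial KSketchIndex on the first iteration, otherwise loading the previous iteration's index), draws additional sketch points by weighted reservoir sampling, and writes the updated index for the next iteration.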
  @Override
  protected MRPipeline createPipeline() throws IOException {
    JobStepConfig stepConfig = getConfig();
    ClusterSettings settings = ClusterSettings.create(ConfigUtils.getDefaultConfig());

    String instanceDir = stepConfig.getInstanceDir();
    int generationID = stepConfig.getGenerationID();
    int iteration = stepConfig.getIteration();
    String prefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
    String outputKey = prefix + String.format("sketch/%d/", iteration);
    if (!validOutputPath(outputKey)) {
      return null;
    }

    // get normalized vectors
    String inputKey = prefix + "normalized/";
    MRPipeline p = createBasicPipeline(DistanceToClosestFn.class);
    AvroType<Pair<Integer, RealVector>> inputType = Avros.pairs(Avros.ints(), MLAvros.vector());
    PCollection<Pair<Integer, RealVector>> in = p.read(avroInput(inputKey, inputType));

    // either create or load the set of currently chosen k-sketch vectors
    // they are stored in a KSketchIndex object
    DistanceToClosestFn<RealVector> distanceToClosestFn;
    UpdateIndexFn updateIndexFn;
    // Iteration 1 is the first real iteration; iteration 0 contains the initial state
    if (iteration == 1) {
      KSketchIndex index = createInitialIndex(settings, in);
      distanceToClosestFn = new DistanceToClosestFn<>(index);
      updateIndexFn = new UpdateIndexFn(index);
    } else {
      // Get the index location from the previous iteration
      String previousIndexKey = prefix + String.format("sketch/%d/", iteration - 1);
      distanceToClosestFn = new DistanceToClosestFn<>(previousIndexKey);
      updateIndexFn = new UpdateIndexFn(previousIndexKey);
    }

    // compute distance of each vector in dataset to closest vector in k-sketch
    PTable<Integer, Pair<RealVector, Double>> weighted =
        in.parallelDo(
            "computeDistances",
            distanceToClosestFn,
            Avros.tableOf(Avros.ints(), Avros.pairs(MLAvros.vector(), Avros.doubles())));

    // run weighted reservoir sampling over the vectors to select another
    // settings.getSketchPoints() vectors to add to the k-sketch
    PTable<Integer, RealVector> kSketchSample =
        ReservoirSampling.groupedWeightedSample(
            weighted, settings.getSketchPoints(), RandomManager.getRandom());

    // update the KSketchIndex with the newly-chosen vectors
    kSketchSample
        .parallelDo("updateIndex", updateIndexFn, Serializables.avro(KSketchIndex.class))
        .write(avroOutput(outputKey));

    return p;
  }
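Example #3
Post-processing step that merges the per-tree PMML files under trees/ into a single forest: the segments of each tree's MiningModel are appended to the first model's Segmentation, and the combined PMML is gzipped and uploaded as model.pmml.gz.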
  @Override
  protected void doPost() throws IOException {

    String instanceGenerationPrefix =
        Namespaces.getInstanceGenerationPrefix(getInstanceDir(), getGenerationID());
    String outputPathKey = instanceGenerationPrefix + "trees/";
    Store store = Store.get();
    PMML joinedForest = null;

    for (String treePrefix : store.list(outputPathKey, true)) {

      File treeTempFile = File.createTempFile("model-", ".pmml.gz");
      treeTempFile.deleteOnExit();
      store.download(treePrefix, treeTempFile);

      PMML pmml;
      InputStream in = IOUtils.openMaybeDecompressing(treeTempFile);
      try {
        pmml = IOUtil.unmarshal(in);
      } catch (SAXException | JAXBException e) {
        throw new IOException(e);
      } finally {
        in.close();
      }

      IOUtils.delete(treeTempFile);

      if (joinedForest == null) {
        joinedForest = pmml;
      } else {
        MiningModel existingModel = (MiningModel) joinedForest.getModels().get(0);
        MiningModel nextModel = (MiningModel) pmml.getModels().get(0);
        existingModel
            .getSegmentation()
            .getSegments()
            .addAll(nextModel.getSegmentation().getSegments());
      }
    }

    File tempJoinedForestFile = File.createTempFile("model-", ".pmml.gz");
    tempJoinedForestFile.deleteOnExit();
    OutputStream out = IOUtils.buildGZIPOutputStream(new FileOutputStream(tempJoinedForestFile));
    try {
      IOUtil.marshal(joinedForest, out);
    } catch (JAXBException e) {
      throw new IOException(e);
    } finally {
      out.close();
    }

    store.upload(instanceGenerationPrefix + "model.pmml.gz", tempJoinedForestFile, false);
    IOUtils.delete(tempJoinedForestFile);
  }
Example #4
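Runs a generation locally without MapReduce: it downloads the generation's inbound data (plus the previous generation's input, if any), reads it into Examples, builds a DecisionForest, writes the forest out as gzipped PMML, and uploads both the input and the model back to the store, cleaning up the temporary directories afterwards.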
  @Override
  protected void runSteps() throws IOException {

    String instanceDir = getInstanceDir();
    int generationID = getGenerationID();
    String generationPrefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
    int lastGenerationID = getLastGenerationID();

    File currentInputDir = Files.createTempDir();
    currentInputDir.deleteOnExit();
    File tempOutDir = Files.createTempDir();
    tempOutDir.deleteOnExit();

    try {

      Store store = Store.get();
      store.downloadDirectory(generationPrefix + "inbound/", currentInputDir);
      if (lastGenerationID >= 0) {
        store.downloadDirectory(
            Namespaces.getInstanceGenerationPrefix(instanceDir, lastGenerationID) + "input/",
            currentInputDir);
      }

      List<Example> examples = new ArrayList<>();
      Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping = new HashMap<>();
      new ReadInputs(currentInputDir, examples, columnToCategoryNameToIDMapping).call();

      DecisionForest forest = DecisionForest.fromExamplesWithDefault(examples);

      DecisionForestPMML.write(
          new File(tempOutDir, "model.pmml.gz"), forest, columnToCategoryNameToIDMapping);

      store.uploadDirectory(generationPrefix + "input/", currentInputDir, false);
      store.uploadDirectory(generationPrefix, tempOutDir, false);

    } finally {
      IOUtils.deleteRecursively(currentInputDir);
      IOUtils.deleteRecursively(tempOutDir);
    }
  }
Example #5
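Pipeline that weights the k-sketch vectors: it assigns every normalized input vector to its closest sketch vector to measure the size of each Voronoi partition, aggregates those counts into a single ClosestSketchVectorData, and then emits each sketch vector paired with its weight.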
  @Override
  protected MRPipeline createPipeline() throws IOException {
    JobStepConfig stepConfig = getConfig();
    Config config = ConfigUtils.getDefaultConfig();
    ClusterSettings clusterSettings = ClusterSettings.create(config);

    String instanceDir = stepConfig.getInstanceDir();
    long generationID = stepConfig.getGenerationID();
    String prefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
    String outputKey = prefix + "weighted/";
    if (!validOutputPath(outputKey)) {
      return null;
    }

    String indexKey = prefix + "sketch/" + clusterSettings.getSketchIterations();
    String inputKey = prefix + "normalized/";
    MRPipeline p = createBasicPipeline(ClosestSketchVectorFn.class);

    // first, compute the weight of each k-sketch vector (the size of its Voronoi partition),
    // aggregate the per-vector counts together, and persist the result to the store
    PCollection<ClosestSketchVectorData> weights =
        PTables.asPTable(
                inputPairs(p, inputKey, MLAvros.vector())
                    .parallelDo(
                        "computingSketchVectorWeights",
                        new ClosestSketchVectorFn<RealVector>(indexKey, clusterSettings),
                        Avros.pairs(Avros.ints(), Avros.reflects(ClosestSketchVectorData.class))))
            .groupByKey(1)
            .combineValues(new ClosestSketchVectorAggregator(clusterSettings))
            .values()
            .write(avroOutput(outputKey + "kSketchVectorWeights/"));

    // this "pipeline" takes a single ClosestSketchVectorData and returns weighted vectors
    // could be done outside MapReduce, but that would require me to materialize the
    // ClosestSketchVectorData
    weights
        .parallelDo(
            "generatingWeightedSketchVectors",
            new WeightVectorsFn(indexKey),
            KMeansTypes.FOLD_WEIGHTED_VECTOR)
        .write(avroOutput(outputKey + "weightedKSketchVectors/"));

    return p;
  }
Example #6
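Loads an ALS model generation: the user (X) and item (Y) factors, the known-item IDs, and the ID mapping are fetched concurrently on a small executor, entries that were not reloaded and not recently active are pruned under the generation's write locks, and the generation's derived state is recomputed.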
  void loadModel(int generationID, Generation currentGeneration) throws IOException {

    File modelPMMLFile = File.createTempFile("oryx-model", ".pmml.gz");
    modelPMMLFile.deleteOnExit();
    IOUtils.delete(modelPMMLFile);

    String generationPrefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
    String modelPMMLKey = generationPrefix + "model.pmml.gz";
    Store.get().download(modelPMMLKey, modelPMMLFile);
    log.info("Loading model description from {}", modelPMMLKey);

    ALSModelDescription modelDescription = ALSModelDescription.read(modelPMMLFile);
    IOUtils.delete(modelPMMLFile);

    Collection<Future<Object>> futures = new ArrayList<>();
    // Limit this fairly sharply to 2 so as to not saturate the network link
    ExecutorService executor = ExecutorUtils.buildExecutor("LoadModel", 2);

    LongSet loadedUserIDs;
    LongSet loadedItemIDs;
    LongSet loadedUserIDsForKnownItems;
    try {
      loadedUserIDs =
          loadXOrY(generationPrefix, modelDescription, true, currentGeneration, futures, executor);
      loadedItemIDs =
          loadXOrY(generationPrefix, modelDescription, false, currentGeneration, futures, executor);

      if (currentGeneration.getKnownItemIDs() == null) {
        loadedUserIDsForKnownItems = null;
      } else {
        loadedUserIDsForKnownItems =
            loadKnownItemIDs(
                generationPrefix, modelDescription, currentGeneration, futures, executor);
      }

      loadIDMapping(generationPrefix, modelDescription, currentGeneration, futures, executor);

      ExecutorUtils.getResults(futures);
      log.info("Finished all load tasks");

    } finally {
      ExecutorUtils.shutdownNowAndAwait(executor);
    }

    log.info("Pruning old entries...");
    synchronized (lockForRecent) {
      removeNotUpdated(
          currentGeneration.getX().keySetIterator(),
          loadedUserIDs,
          recentlyActiveUsers,
          currentGeneration.getXLock().writeLock());
      removeNotUpdated(
          currentGeneration.getY().keySetIterator(),
          loadedItemIDs,
          recentlyActiveItems,
          currentGeneration.getYLock().writeLock());
      if (loadedUserIDsForKnownItems != null && currentGeneration.getKnownItemIDs() != null) {
        removeNotUpdated(
            currentGeneration.getKnownItemIDs().keySetIterator(),
            loadedUserIDsForKnownItems,
            recentlyActiveUsers,
            currentGeneration.getKnownItemLock().writeLock());
      }
      this.recentlyActiveItems.clear();
      this.recentlyActiveUsers.clear();
    }

    log.info("Recomputing generation state...");
    currentGeneration.recomputeState();

    log.info(
        "All model elements loaded, {} users and {} items",
        currentGeneration.getNumUsers(),
        currentGeneration.getNumItems());
  }