/**
   * Replaces the current model with the one stored for {@code mostRecentModelGeneration}, when
   * that generation is strictly newer than the generation currently held.
   *
   * <p>The model description is downloaded from the store to a local temporary file, parsed, and
   * the temporary file is removed again whether or not downloading and parsing succeed.
   *
   * @param mostRecentModelGeneration generation whose model should be loaded; nothing happens if
   *     it is not greater than the current {@code modelGeneration}
   * @throws IOException if the temporary file cannot be created or deleted, or if downloading or
   *     reading the model fails
   */
  @Override
  protected void loadRecentModel(long mostRecentModelGeneration) throws IOException {
    if (mostRecentModelGeneration <= modelGeneration) {
      return;
    }
    if (modelGeneration == NO_GENERATION) {
      log.info("Most recent generation {} is the first available one", mostRecentModelGeneration);
    } else {
      log.info(
          "Most recent generation {} is newer than current {}",
          mostRecentModelGeneration,
          modelGeneration);
    }

    File modelPMMLFile = File.createTempFile("model-", ".pmml.gz");
    modelPMMLFile.deleteOnExit();
    // Only the unique path is kept; the empty placeholder is removed before downloading.
    // NOTE(review): presumably Store.download expects the target not to exist -- confirm.
    IOUtils.delete(modelPMMLFile);

    Config config = ConfigUtils.getDefaultConfig();
    String instanceDir = config.getString("model.instance-dir");

    String generationPrefix =
        Namespaces.getInstanceGenerationPrefix(instanceDir, mostRecentModelGeneration);
    String modelPMMLKey = generationPrefix + "model.pmml.gz";

    Pair<DecisionForest, Map<Integer, BiMap<String, Integer>>> forestAndCatalog;
    try {
      Store.get().download(modelPMMLKey, modelPMMLFile);
      log.info("Loading model description from {}", modelPMMLKey);
      forestAndCatalog = DecisionForestPMML.read(modelPMMLFile);
    } finally {
      // Clean up even when download/parse fails, so repeated failures do not pile up temp files
      // until JVM exit. The existence check covers a download that never created the file.
      if (modelPMMLFile.exists()) {
        IOUtils.delete(modelPMMLFile);
      }
    }
    log.info("Loaded model description");

    modelGeneration = mostRecentModelGeneration;
    currentModel = new Generation(forestAndCatalog.getFirst(), forestAndCatalog.getSecond());
  }
Exemple #2
0
 /**
  * Submits one background task per stored ID-mapping file. Each task reads its file line by line,
  * decodes the comma-delimited columns as {@code numericID,stringID} and registers the pair in
  * the generation's ID mapping.
  *
  * <p>The tasks are only scheduled here; their futures are appended to {@code futures} and this
  * method returns without waiting for them.
  *
  * @param generationPrefix store prefix of the generation being loaded
  * @param modelDescription supplies the ID-mapping path relative to {@code generationPrefix}
  * @param generation generation whose ID mapping receives the entries
  * @param futures collection that receives the future of every submitted task
  * @param executor executor on which the tasks run
  * @throws IOException if listing the mapping files fails
  */
 private static void loadIDMapping(
     String generationPrefix,
     ALSModelDescription modelDescription,
     final Generation generation,
     Collection<Future<Object>> futures,
     ExecutorService executor)
     throws IOException {
   for (final String mappingFileKey :
       Store.get().list(generationPrefix + modelDescription.getIDMappingPath(), true)) {
     Callable<Object> loadTask =
         new Callable<Object>() {
           @Override
           public Void call() throws IOException {
             for (CharSequence row :
                 new FileLineIterable(Store.get().readFrom(mappingFileKey))) {
               String[] fields = DelimitedDataUtils.decode(row, ',');
               long numericID = Long.parseLong(fields[0]);
               String stringID = fields[1];
               // Mapping updates are guarded by the known-item write lock.
               Lock mappingLock = generation.getKnownItemLock().writeLock();
               StringLongMapping mapping = generation.getIDMapping();
               mappingLock.lock();
               try {
                 mapping.addMapping(stringID, numericID);
               } finally {
                 mappingLock.unlock();
               }
             }
             return null;
           }
         };
     futures.add(executor.submit(loadTask));
   }
 }
  /**
   * Joins the individual tree models stored under the generation's {@code trees/} prefix into a
   * single PMML document and uploads it as {@code model.pmml.gz} for the generation.
   *
   * <p>The first tree document becomes the base; the segments of the first model of every
   * subsequent document are appended to the segmentation of the base's first model. All
   * temporary files are removed again, also when a step fails.
   *
   * @throws IOException if store access, temporary-file handling, or PMML (un)marshalling fails;
   *     {@code SAXException} and {@code JAXBException} are wrapped as the cause
   */
  @Override
  protected void doPost() throws IOException {

    String instanceGenerationPrefix =
        Namespaces.getInstanceGenerationPrefix(getInstanceDir(), getGenerationID());
    String outputPathKey = instanceGenerationPrefix + "trees/";
    Store store = Store.get();
    PMML joinedForest = null;

    for (String treePrefix : store.list(outputPathKey, true)) {

      File treeTempFile = File.createTempFile("model-", ".pmml.gz");
      treeTempFile.deleteOnExit();

      PMML pmml;
      try {
        store.download(treePrefix, treeTempFile);
        // try-with-resources closes the stream without masking a parse failure.
        try (InputStream in = IOUtils.openMaybeDecompressing(treeTempFile)) {
          pmml = IOUtil.unmarshal(in);
        } catch (SAXException | JAXBException e) {
          throw new IOException(e);
        }
      } finally {
        // Remove the local copy even if download or parsing failed.
        if (treeTempFile.exists()) {
          IOUtils.delete(treeTempFile);
        }
      }

      if (joinedForest == null) {
        joinedForest = pmml;
      } else {
        MiningModel existingModel = (MiningModel) joinedForest.getModels().get(0);
        MiningModel nextModel = (MiningModel) pmml.getModels().get(0);
        existingModel
            .getSegmentation()
            .getSegments()
            .addAll(nextModel.getSegmentation().getSegments());
      }
    }

    // NOTE(review): joinedForest is still null here when no tree files were listed; it is passed
    // to marshal() unchanged, as before -- confirm whether that case can occur.
    File tempJoinedForestFile = File.createTempFile("model-", ".pmml.gz");
    tempJoinedForestFile.deleteOnExit();
    try {
      // The file stream is its own resource so it is closed even if wrapping it for GZIP fails.
      try (FileOutputStream fileOut = new FileOutputStream(tempJoinedForestFile);
          OutputStream out = IOUtils.buildGZIPOutputStream(fileOut)) {
        IOUtil.marshal(joinedForest, out);
      } catch (JAXBException e) {
        throw new IOException(e);
      }

      store.upload(instanceGenerationPrefix + "model.pmml.gz", tempJoinedForestFile, false);
    } finally {
      if (tempJoinedForestFile.exists()) {
        IOUtils.delete(tempJoinedForestFile);
      }
    }
  }
Exemple #4
0
  /**
   * Submits one background task per stored feature-vector file of either the X or the Y matrix.
   * Every line is expected to be {@code id<TAB>vector}; the parsed vector is put into the
   * selected matrix under the matching write lock and the ID is recorded.
   *
   * <p>The returned set is filled by the submitted tasks, so it is only complete once the
   * futures appended to {@code futures} have finished.
   *
   * @param generationPrefix store prefix of the generation being loaded
   * @param modelDescription supplies the X and Y paths relative to {@code generationPrefix}
   * @param isX {@code true} to load into X (with the X lock), {@code false} for Y
   * @param generation generation whose matrix is populated
   * @param futures collection that receives the future of every submitted task
   * @param executor executor on which the tasks run
   * @return the set into which the tasks record every ID they load
   * @throws IOException if listing the files fails
   */
  private static LongSet loadXOrY(
      String generationPrefix,
      ALSModelDescription modelDescription,
      boolean isX,
      Generation generation,
      Collection<Future<Object>> futures,
      ExecutorService executor)
      throws IOException {

    String relativePath;
    if (isX) {
      relativePath = modelDescription.getXPath();
    } else {
      relativePath = modelDescription.getYPath();
    }
    String matrixPrefix = generationPrefix + relativePath;

    final LongSet seenIDs = new LongSet();

    final Lock matrixWriteLock;
    final LongObjectMap<float[]> matrix;
    if (isX) {
      matrixWriteLock = generation.getXLock().writeLock();
      matrix = generation.getX();
    } else {
      matrixWriteLock = generation.getYLock().writeLock();
      matrix = generation.getY();
    }

    for (final String fileKey : Store.get().list(matrixPrefix, true)) {
      Callable<Object> loadTask =
          new Callable<Object>() {
            @Override
            public Void call() throws IOException {
              for (String row : new FileLineIterable(Store.get().readFrom(fileKey))) {
                int separator = row.indexOf('\t');
                Preconditions.checkArgument(
                    separator >= 0, "Bad input line in %s: %s", fileKey, row);
                long id = Long.parseLong(row.substring(0, separator));
                float[] vector = DataUtils.readFeatureVector(row.substring(separator + 1));

                // Both the matrix and the shared ID set are only touched under the write lock.
                matrixWriteLock.lock();
                try {
                  matrix.put(id, vector);
                  seenIDs.add(id);
                } finally {
                  matrixWriteLock.unlock();
                }
              }
              log.info("Loaded feature vectors from {}", fileKey);
              return null;
            }
          };
      futures.add(executor.submit(loadTask));
    }

    return seenIDs;
  }
  /**
   * Builds a decision forest for the current generation and publishes it.
   *
   * <p>Steps, as performed below: download the generation's {@code inbound/} data and, when a
   * previous generation exists ({@code getLastGenerationID() >= 0}), that generation's
   * {@code input/} data into one local directory; run {@code ReadInputs} over it; train a forest
   * from the examples; write it as {@code model.pmml.gz}; upload the combined input under
   * {@code input/} and the model under the generation prefix.
   *
   * <p>Each temporary directory is cleaned up in its own {@code finally}, so a failure creating
   * or deleting one of them does not leave the other behind.
   *
   * @throws IOException if store access, input reading, model writing or cleanup fails
   */
  @Override
  protected void runSteps() throws IOException {

    String instanceDir = getInstanceDir();
    int generationID = getGenerationID();
    String generationPrefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
    int lastGenerationID = getLastGenerationID();

    File currentInputDir = Files.createTempDir();
    currentInputDir.deleteOnExit();
    try {

      File tempOutDir = Files.createTempDir();
      tempOutDir.deleteOnExit();
      try {

        Store store = Store.get();
        store.downloadDirectory(generationPrefix + "inbound/", currentInputDir);
        if (lastGenerationID >= 0) {
          store.downloadDirectory(
              Namespaces.getInstanceGenerationPrefix(instanceDir, lastGenerationID) + "input/",
              currentInputDir);
        }

        List<Example> examples = new ArrayList<>();
        Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping = new HashMap<>();
        // NOTE(review): ReadInputs presumably fills both collections passed in -- confirm.
        new ReadInputs(currentInputDir, examples, columnToCategoryNameToIDMapping).call();

        DecisionForest forest = DecisionForest.fromExamplesWithDefault(examples);

        DecisionForestPMML.write(
            new File(tempOutDir, "model.pmml.gz"), forest, columnToCategoryNameToIDMapping);

        store.uploadDirectory(generationPrefix + "input/", currentInputDir, false);
        store.uploadDirectory(generationPrefix, tempOutDir, false);

      } finally {
        IOUtils.deleteRecursively(tempOutDir);
      }

    } finally {
      IOUtils.deleteRecursively(currentInputDir);
    }
  }
Exemple #6
0
 /**
  * Submits one background task per stored known-items file. Every line is expected to be
  * {@code userID<TAB>itemIDs}; the item IDs are parsed with {@code stringToSet} and stored for
  * the user in the generation's known-item map under the known-item write lock.
  *
  * <p>The returned set is filled by the submitted tasks, so it is only complete once the
  * futures appended to {@code futures} have finished.
  *
  * @param generationPrefix store prefix of the generation being loaded
  * @param modelDescription supplies the known-items path relative to {@code generationPrefix}
  * @param generation generation whose known-item map is populated
  * @param futures collection that receives the future of every submitted task
  * @param executor executor on which the tasks run
  * @return the set into which the tasks record every user ID they load
  * @throws IOException if listing the files fails
  */
 private static LongSet loadKnownItemIDs(
     String generationPrefix,
     ALSModelDescription modelDescription,
     final Generation generation,
     Collection<Future<Object>> futures,
     ExecutorService executor)
     throws IOException {
   final LongSet seenUserIDs = new LongSet();
   String knownItemsPrefix = generationPrefix + modelDescription.getKnownItemsPath();
   for (final String fileKey : Store.get().list(knownItemsPrefix, true)) {
     Callable<Object> loadTask =
         new Callable<Object>() {
           @Override
           public Void call() throws IOException {
             for (String row : new FileLineIterable(Store.get().readFrom(fileKey))) {
               int separator = row.indexOf('\t');
               Preconditions.checkArgument(
                   separator >= 0, "Bad input line in %s: %s", fileKey, row);
               long userID = Long.parseLong(row.substring(0, separator));
               LongSet itemIDs = stringToSet(row.substring(separator + 1));

               Lock knownItemsWriteLock = generation.getKnownItemLock().writeLock();
               LongObjectMap<LongSet> knownItemsByUser = generation.getKnownItemIDs();
               // The map and the shared ID set are only touched under the write lock.
               knownItemsWriteLock.lock();
               try {
                 knownItemsByUser.put(userID, itemIDs);
                 seenUserIDs.add(userID);
               } finally {
                 knownItemsWriteLock.unlock();
               }
             }
             log.info("Loaded known items from {}", fileKey);
             return null;
           }
         };
     futures.add(executor.submit(loadTask));
   }
   return seenUserIDs;
 }
Exemple #7
0
  /**
   * Loads the model of the given generation from the store into {@code currentGeneration}.
   *
   * <p>The model description is downloaded to a temporary file (removed again even on failure)
   * and read. X, Y, known items (only if the generation has a known-item map) and the ID mapping
   * are then loaded by tasks on a two-thread executor, and this method waits for all of them.
   * Afterwards, while holding {@code lockForRecent}, {@code removeNotUpdated} is applied to X, Y
   * and the known items, the recently-active sets are cleared, and the generation state is
   * recomputed.
   *
   * @param generationID generation whose model is loaded
   * @param currentGeneration generation object that is updated in place
   * @throws IOException if temporary-file handling, store access, or any load task fails
   */
  void loadModel(int generationID, Generation currentGeneration) throws IOException {

    File modelPMMLFile = File.createTempFile("oryx-model", ".pmml.gz");
    modelPMMLFile.deleteOnExit();
    // Only the unique path is kept; the empty placeholder is removed before downloading.
    // NOTE(review): presumably Store.download expects the target not to exist -- confirm.
    IOUtils.delete(modelPMMLFile);

    String generationPrefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
    String modelPMMLKey = generationPrefix + "model.pmml.gz";

    ALSModelDescription modelDescription;
    try {
      Store.get().download(modelPMMLKey, modelPMMLFile);
      log.info("Loading model description from {}", modelPMMLKey);
      modelDescription = ALSModelDescription.read(modelPMMLFile);
    } finally {
      // Clean up even when download/parse fails, so repeated failures do not pile up temp files
      // until JVM exit. The existence check covers a download that never created the file.
      if (modelPMMLFile.exists()) {
        IOUtils.delete(modelPMMLFile);
      }
    }

    Collection<Future<Object>> futures = new ArrayList<>();
    // Limit this fairly sharply to 2 so as to not saturate the network link
    ExecutorService executor = ExecutorUtils.buildExecutor("LoadModel", 2);

    LongSet loadedUserIDs;
    LongSet loadedItemIDs;
    LongSet loadedUserIDsForKnownItems;
    try {
      loadedUserIDs =
          loadXOrY(generationPrefix, modelDescription, true, currentGeneration, futures, executor);
      loadedItemIDs =
          loadXOrY(generationPrefix, modelDescription, false, currentGeneration, futures, executor);

      if (currentGeneration.getKnownItemIDs() == null) {
        loadedUserIDsForKnownItems = null;
      } else {
        loadedUserIDsForKnownItems =
            loadKnownItemIDs(
                generationPrefix, modelDescription, currentGeneration, futures, executor);
      }

      loadIDMapping(generationPrefix, modelDescription, currentGeneration, futures, executor);

      // The loaders above only schedule work; block here until every task has completed.
      ExecutorUtils.getResults(futures);
      log.info("Finished all load tasks");

    } finally {
      ExecutorUtils.shutdownNowAndAwait(executor);
    }

    log.info("Pruning old entries...");
    synchronized (lockForRecent) {
      // NOTE(review): removeNotUpdated presumably drops keys that were neither just loaded nor
      // recently active -- confirm against its definition.
      removeNotUpdated(
          currentGeneration.getX().keySetIterator(),
          loadedUserIDs,
          recentlyActiveUsers,
          currentGeneration.getXLock().writeLock());
      removeNotUpdated(
          currentGeneration.getY().keySetIterator(),
          loadedItemIDs,
          recentlyActiveItems,
          currentGeneration.getYLock().writeLock());
      if (loadedUserIDsForKnownItems != null && currentGeneration.getKnownItemIDs() != null) {
        removeNotUpdated(
            currentGeneration.getKnownItemIDs().keySetIterator(),
            loadedUserIDsForKnownItems,
            recentlyActiveUsers,
            currentGeneration.getKnownItemLock().writeLock());
      }
      this.recentlyActiveItems.clear();
      this.recentlyActiveUsers.clear();
    }

    log.info("Recomputing generation state...");
    currentGeneration.recomputeState();

    log.info(
        "All model elements loaded, {} users and {} items",
        currentGeneration.getNumUsers(),
        currentGeneration.getNumItems());
  }