// Example #1
// 0
  /**
   * Loads the model for {@code mostRecentModelGeneration} if it is newer than the currently loaded
   * {@code modelGeneration}; otherwise returns without doing anything.
   *
   * <p>The generation's {@code model.pmml.gz} (located under the prefix built from the default
   * config's {@code model.instance-dir}) is downloaded to a local temp file and parsed into a
   * {@link DecisionForest} plus per-column category-name/ID mappings. These are then installed as
   * {@code currentModel}. {@code modelGeneration} is only advanced after parsing succeeds. The
   * local copy is deleted once parsed, even when parsing fails.
   *
   * @param mostRecentModelGeneration newest generation ID known to be available
   * @throws IOException if the model cannot be downloaded or read
   */
  @Override
  protected void loadRecentModel(long mostRecentModelGeneration) throws IOException {
    if (mostRecentModelGeneration <= modelGeneration) {
      return;
    }
    if (modelGeneration == NO_GENERATION) {
      log.info("Most recent generation {} is the first available one", mostRecentModelGeneration);
    } else {
      log.info(
          "Most recent generation {} is newer than current {}",
          mostRecentModelGeneration,
          modelGeneration);
    }

    File modelPMMLFile = File.createTempFile("model-", ".pmml.gz");
    modelPMMLFile.deleteOnExit();
    // Only a unique temp path is wanted here; the empty placeholder is removed before downloading
    // (presumably so the download creates the file itself — confirm against Store.download).
    IOUtils.delete(modelPMMLFile);

    Config config = ConfigUtils.getDefaultConfig();
    String instanceDir = config.getString("model.instance-dir");

    String generationPrefix =
        Namespaces.getInstanceGenerationPrefix(instanceDir, mostRecentModelGeneration);
    String modelPMMLKey = generationPrefix + "model.pmml.gz";
    Store.get().download(modelPMMLKey, modelPMMLFile);
    log.info("Loading model description from {}", modelPMMLKey);

    Pair<DecisionForest, Map<Integer, BiMap<String, Integer>>> forestAndCatalog;
    try {
      forestAndCatalog = DecisionForestPMML.read(modelPMMLFile);
    } finally {
      // Remove the local copy even if parsing fails, instead of leaving it until JVM exit.
      IOUtils.delete(modelPMMLFile);
    }
    log.info("Loaded model description");

    modelGeneration = mostRecentModelGeneration;
    currentModel = new Generation(forestAndCatalog.getFirst(), forestAndCatalog.getSecond());
  }
  /**
   * Joins the per-tree PMML files stored under {@code <generation prefix>trees/} into a single
   * PMML document and uploads it as {@code <generation prefix>model.pmml.gz}.
   *
   * <p>The first tree file's PMML document is used as the base. For every later file, the segments
   * of its first model's segmentation are appended to the base's first {@link MiningModel}. Both
   * models are assumed to be {@code MiningModel}s, as the casts require. Each local temp file is
   * deleted once consumed, including when parsing, writing or uploading fails.
   *
   * @throws IOException if a tree file cannot be downloaded or parsed, or if the joined model
   *     cannot be written or uploaded
   * @throws IllegalStateException if no tree files are listed under the trees prefix
   */
  @Override
  protected void doPost() throws IOException {

    String instanceGenerationPrefix =
        Namespaces.getInstanceGenerationPrefix(getInstanceDir(), getGenerationID());
    String outputPathKey = instanceGenerationPrefix + "trees/";
    Store store = Store.get();
    PMML joinedForest = null;

    for (String treePrefix : store.list(outputPathKey, true)) {

      File treeTempFile = File.createTempFile("model-", ".pmml.gz");
      treeTempFile.deleteOnExit();
      store.download(treePrefix, treeTempFile);

      PMML pmml;
      try (InputStream in = IOUtils.openMaybeDecompressing(treeTempFile)) {
        pmml = IOUtil.unmarshal(in);
      } catch (SAXException | JAXBException e) {
        throw new IOException(e);
      } finally {
        // The stream is already closed here; delete the local copy whether or not parsing worked.
        IOUtils.delete(treeTempFile);
      }

      if (joinedForest == null) {
        joinedForest = pmml;
      } else {
        MiningModel existingModel = (MiningModel) joinedForest.getModels().get(0);
        MiningModel nextModel = (MiningModel) pmml.getModels().get(0);
        existingModel
            .getSegmentation()
            .getSegments()
            .addAll(nextModel.getSegmentation().getSegments());
      }
    }

    if (joinedForest == null) {
      // Nothing to join: fail clearly rather than trying to marshal a null document.
      throw new IllegalStateException("No trees found under " + outputPathKey);
    }

    File tempJoinedForestFile = File.createTempFile("model-", ".pmml.gz");
    tempJoinedForestFile.deleteOnExit();
    try {
      // Open the file stream as its own resource so it is closed even if wrapping it in GZIP fails.
      try (FileOutputStream fileOut = new FileOutputStream(tempJoinedForestFile);
          OutputStream out = IOUtils.buildGZIPOutputStream(fileOut)) {
        IOUtil.marshal(joinedForest, out);
      } catch (JAXBException e) {
        throw new IOException(e);
      }
      store.upload(instanceGenerationPrefix + "model.pmml.gz", tempJoinedForestFile, false);
    } finally {
      IOUtils.delete(tempJoinedForestFile);
    }
  }
// Example #3
// 0
  /**
   * Builds a decision forest for the current generation and publishes it.
   *
   * <p>The training input is this generation's {@code inbound/} data plus, when a previous
   * generation exists (ID {@code >= 0}), that generation's stored {@code input/} data. The
   * combined input is uploaded as this generation's {@code input/}. The forest is written as
   * {@code model.pmml.gz} directly under this generation's prefix. Both local staging directories
   * are removed recursively whether or not the step succeeds.
   *
   * @throws IOException if downloading, reading, writing or uploading fails
   */
  @Override
  protected void runSteps() throws IOException {

    String instanceDir = getInstanceDir();
    int thisGeneration = getGenerationID();
    String thisPrefix = Namespaces.getInstanceGenerationPrefix(instanceDir, thisGeneration);
    int previousGeneration = getLastGenerationID();

    File localInput = Files.createTempDir();
    localInput.deleteOnExit();
    File localOutput = Files.createTempDir();
    localOutput.deleteOnExit();

    try {
      Store store = Store.get();

      // Stage the new data for this generation...
      store.downloadDirectory(thisPrefix + "inbound/", localInput);
      // ...together with whatever input the previous generation retained, if there was one.
      boolean hasPrevious = previousGeneration >= 0;
      if (hasPrevious) {
        String previousPrefix =
            Namespaces.getInstanceGenerationPrefix(instanceDir, previousGeneration);
        store.downloadDirectory(previousPrefix + "input/", localInput);
      }

      List<Example> trainingExamples = new ArrayList<>();
      Map<Integer, BiMap<String, Integer>> categoryMappings = new HashMap<>();
      ReadInputs reader = new ReadInputs(localInput, trainingExamples, categoryMappings);
      reader.call();

      DecisionForest trainedForest = DecisionForest.fromExamplesWithDefault(trainingExamples);
      File modelFile = new File(localOutput, "model.pmml.gz");
      DecisionForestPMML.write(modelFile, trainedForest, categoryMappings);

      store.uploadDirectory(thisPrefix + "input/", localInput, false);
      store.uploadDirectory(thisPrefix, localOutput, false);
    } finally {
      IOUtils.deleteRecursively(localInput);
      IOUtils.deleteRecursively(localOutput);
    }
  }
// Example #4
// 0
  /**
   * Loads generation {@code generationID}'s model files into {@code currentGeneration}, prunes old
   * entries, and recomputes the generation's derived state.
   *
   * <p>The generation's {@code model.pmml.gz} description is downloaded to a local temp file and
   * parsed. The local copy is deleted once parsed, even if parsing fails. The description then
   * drives loading of X, Y, known item IDs (only when {@code currentGeneration} tracks them) and
   * the ID mapping. Those loaders share an executor limited to 2 threads and a futures list, and
   * all results are collected before pruning. Pruning runs while holding {@code lockForRecent}.
   * Each {@code removeNotUpdated} call receives the freshly loaded ID set, a recently-active set
   * and the matching write lock. Presumably it drops keys that are neither reloaded nor recently
   * active — confirm against {@code removeNotUpdated}. Both recently-active sets are cleared
   * afterwards.
   *
   * @param generationID generation whose files are loaded
   * @param currentGeneration in-memory generation to update
   * @throws IOException if the model description cannot be downloaded or read
   */
  void loadModel(int generationID, Generation currentGeneration) throws IOException {

    File modelPMMLFile = File.createTempFile("oryx-model", ".pmml.gz");
    modelPMMLFile.deleteOnExit();
    // Only a unique temp path is wanted; the empty placeholder is removed before downloading.
    IOUtils.delete(modelPMMLFile);

    String generationPrefix = Namespaces.getInstanceGenerationPrefix(instanceDir, generationID);
    String modelPMMLKey = generationPrefix + "model.pmml.gz";
    Store.get().download(modelPMMLKey, modelPMMLFile);
    log.info("Loading model description from {}", modelPMMLKey);

    ALSModelDescription modelDescription;
    try {
      modelDescription = ALSModelDescription.read(modelPMMLFile);
    } finally {
      // Remove the local copy even if parsing fails, instead of leaving it until JVM exit.
      IOUtils.delete(modelPMMLFile);
    }

    Collection<Future<Object>> futures = new ArrayList<>();
    // Limit this fairly sharply to 2 so as to not saturate the network link
    ExecutorService executor = ExecutorUtils.buildExecutor("LoadModel", 2);

    LongSet loadedUserIDs;
    LongSet loadedItemIDs;
    LongSet loadedUserIDsForKnownItems;
    try {
      loadedUserIDs =
          loadXOrY(generationPrefix, modelDescription, true, currentGeneration, futures, executor);
      loadedItemIDs =
          loadXOrY(generationPrefix, modelDescription, false, currentGeneration, futures, executor);

      // Known item IDs are only loaded when this generation tracks them.
      if (currentGeneration.getKnownItemIDs() == null) {
        loadedUserIDsForKnownItems = null;
      } else {
        loadedUserIDsForKnownItems =
            loadKnownItemIDs(
                generationPrefix, modelDescription, currentGeneration, futures, executor);
      }

      loadIDMapping(generationPrefix, modelDescription, currentGeneration, futures, executor);

      // Wait for every load task before pruning.
      ExecutorUtils.getResults(futures);
      log.info("Finished all load tasks");

    } finally {
      ExecutorUtils.shutdownNowAndAwait(executor);
    }

    log.info("Pruning old entries...");
    synchronized (lockForRecent) {
      removeNotUpdated(
          currentGeneration.getX().keySetIterator(),
          loadedUserIDs,
          recentlyActiveUsers,
          currentGeneration.getXLock().writeLock());
      removeNotUpdated(
          currentGeneration.getY().keySetIterator(),
          loadedItemIDs,
          recentlyActiveItems,
          currentGeneration.getYLock().writeLock());
      if (loadedUserIDsForKnownItems != null && currentGeneration.getKnownItemIDs() != null) {
        removeNotUpdated(
            currentGeneration.getKnownItemIDs().keySetIterator(),
            loadedUserIDsForKnownItems,
            recentlyActiveUsers,
            currentGeneration.getKnownItemLock().writeLock());
      }
      this.recentlyActiveItems.clear();
      this.recentlyActiveUsers.clear();
    }

    log.info("Recomputing generation state...");
    currentGeneration.recomputeState();

    log.info(
        "All model elements loaded, {} users and {} items",
        currentGeneration.getNumUsers(),
        currentGeneration.getNumItems());
  }