/**
 * Groups the features of {@code sourceFeatureSet} into clusters by their information gain
 * with respect to {@code className}, and stores the result in {@code clusters}.
 *
 * <p>Steps:
 *
 * <ol>
 *   <li>Count training documents per binary label. Label 1 means the document's classes
 *       contain {@code className}; label 0 means they do not.
 *   <li>For every feature, build a map from document id to its value. Each value is multiplied
 *       by {@code InfoGainCalc.PRECISION_INV} and truncated to {@code long}.
 *   <li>Keep at most {@code nOfBestToUse} features with the highest info gain in
 *       {@code inverseIndex}.
 *   <li>Run {@code AgglomerativeSampling} to build {@code nClusters} clusters. Merged nodes
 *       use the max combiner when {@code isMaxCombiner} is set, and the sum combiner
 *       otherwise.
 *   <li>Store each non-empty cluster's feature names in {@code clusters}, under consecutive
 *       integer keys.
 * </ol>
 *
 * <p>{@code inverseIndex} is set to {@code null} before returning, to release memory.
 *
 * @param training documents used to compute the clustering
 * @param transientData not used by this method
 * @return always {@code null}
 */
@Override
  public Object preProcess(DocumentSet training, Object transientData) {
    inverseIndex = new HashMap<>();
    double[] expectedWeightPerLabelHint = new double[2];
    // docId -> binary label (1 if the document has className, else 0). Computed once here.
    // Info-gain estimation runs again for every merge during clustering, so this avoids
    // re-fetching each document and re-scanning its class set for every sample.
    TLongLongHashMap labelByDoc = new TLongLongHashMap();
    System.out.println("Preparing clustering");

    // class counts (and the label cache used by estimateInfoGain)
    training.forEach(
        (docId, doc) -> {
          int cls = doc.getClasses().contains(className) ? 1 : 0;
          expectedWeightPerLabelHint[cls] += 1;
          labelByDoc.put(docId, cls);
        });

    // Binary info gain of a node's (docId -> scaled value) samples against the cached labels.
    Function<FeatureNode, Double> estimateInfoGain =
        f -> {
          InfoGainCalc calc =
              new InfoGainCalc(2, false, f.featureName, InfoGainCalc.EstimationEnum.INFO_GAIN);
          f.values.forEachEntry(
              (id, v) -> {
                // NOTE(review): if PRECISION_INV has an integral type, this is integer
                // division and drops the fractional part. Confirm it is floating-point.
                calc.addSample((int) labelByDoc.get(id), v / InfoGainCalc.PRECISION_INV);
                return true;
              });
          calc.setExpectedWeightPerLabelHint(expectedWeightPerLabelHint);
          return calc.estimateRelevance();
        };

    Function<FeatureNode, Double> evaluator = f -> f.eval;

    // Merges two nodes into an unnamed node and scores it with estimateInfoGain.
    BiFunction<FeatureNode, FeatureNode, FeatureNode> combiner =
        (f1, f2) -> {
          FeatureNode combined = new FeatureNode(null);

          f1.values.forEachEntry(
              (id, v) -> {
                combined.values.put(id, v);
                return true;
              });
          f2.values.forEachEntry(
              (id, v) -> {
                // get() returns the map's no-entry value (0 by default) for ids absent from f1.
                // NOTE(review): under the max combiner, a negative value present only in f2
                // therefore becomes 0. Confirm this is intended.
                long o = combined.values.get(id);
                if (isMaxCombiner) {
                  combined.values.put(id, v > o ? v : o);
                } else { // if isSumCombiner
                  combined.values.put(id, o + v);
                }
                return true;
              });

          combined.eval = estimateInfoGain.apply(combined);
          return combined;
        };

    // calculate feature-doc map: feature name -> (docId -> value * PRECISION_INV, truncated)
    Map<String, TLongLongHashMap> fmap = new HashMap<>();
    training.forEach(
        (docId, doc) -> {
          doc.getFeatureSet(sourceFeatureSet)
              .forEach(
                  f -> {
                    fmap.computeIfAbsent(f.getName(), name -> new TLongLongHashMap())
                        .put(docId, (long) (f.doubleValue() * InfoGainCalc.PRECISION_INV));
                  });
        });

    // convert to featureNodes, scoring each one
    ArrayList<FeatureNode> ranked = new ArrayList<>(fmap.size());
    fmap.forEach(
        (fname, docValues) -> {
          FeatureNode node = new FeatureNode(fname, docValues);
          node.eval = estimateInfoGain.apply(node);
          ranked.add(node);
        });

    // optimization: keep only the nOfBestToUse features with the highest info gain
    ranked.sort((n1, n2) -> Double.compare(n2.eval, n1.eval));
    for (int i = 0; i < ranked.size() && i < nOfBestToUse; i++) {
      FeatureNode best = ranked.get(i);
      inverseIndex.put(best.featureName, best);
    }

    System.out.println("Doing clustering");

    AgglomerativeSampling<FeatureNode> clustering =
        new AgglomerativeSampling<>(evaluator, combiner, inverseIndex.values());

    clustering.setMaxSamples(nOfBestToUse);
    clustering.doClustering(nClusters);

    // collect clusters: consecutive integer ids -> feature names of the cluster's leaves
    clusters = new HashMap<>();
    clustering.forEachCluster(
        c -> {
          Set<String> features = new HashSet<>();
          c.forEachLeaf(l -> features.add(l.getPoint().featureName));
          if (!features.isEmpty()) {
            clusters.put(clusters.size(), features);
          }
        });

    // release memory
    inverseIndex = null;
    return null;
  }
 /**
  * Sets this instance's {@code className} and removes the {@code FEATURE_SET} feature set
  * from every document in {@code docs}.
  *
  * @param docs documents whose {@code FEATURE_SET} feature set is removed
  * @param className the class name to store in this instance
  */
 @Override
 public void reset(DocumentSet docs, String className) {
   this.className = className;
   docs.forEach(
       (docId, document) -> {
         document.removeFeatureSet(FEATURE_SET);
       });
 }