@Override
public Object preProcess(DocumentSet training, Object transientData) {
  inverseIndex = new HashMap<>();
  double[] expectedWeightPerLabelHint = new double[2];
  System.out.println("Preparing clustering");

  // class counts: how many documents carry / do not carry the target class
  training.forEach(
      (docId, doc) -> {
        int cls = doc.getClasses().contains(className) ? 1 : 0;
        expectedWeightPerLabelHint[cls] += 1;
      });

  // scores a feature node by the information gain of its values w.r.t. the class
  Function<FeatureNode, Double> estimateInfoGain =
      (f) -> {
        InfoGainCalc calc =
            new InfoGainCalc(2, false, f.featureName, InfoGainCalc.EstimationEnum.INFO_GAIN);
        f.values.forEachEntry(
            (id, v) -> {
              calc.addSample(
                  training.document(id).getClasses().contains(className) ? 1 : 0,
                  v / InfoGainCalc.PRECISION_INV);
              return true;
            });
        calc.setExpectedWeightPerLabelHint(expectedWeightPerLabelHint);
        return calc.estimateRelevance();
      };

  Function<FeatureNode, Double> evaluator = (f) -> f.eval;

  // merges two feature nodes: the merged node's values are the element-wise
  // max or sum of the inputs, and its eval is recomputed from scratch
  BiFunction<FeatureNode, FeatureNode, FeatureNode> combiner =
      (f1, f2) -> {
        FeatureNode combined = new FeatureNode(null);
        f1.values.forEachEntry(
            (id, v) -> {
              combined.values.put(id, v);
              return true;
            });
        f2.values.forEachEntry(
            (id, v) -> {
              // check containsKey explicitly: Trove's get() returns its no-entry
              // value (0) for ids present only in f2, which would distort a max
              // over negative values
              if (combined.values.containsKey(id)) {
                long o = combined.values.get(id);
                combined.values.put(id, isMaxCombiner ? Math.max(v, o) : o + v); // else sum
              } else {
                combined.values.put(id, v);
              }
              return true;
            });
        combined.eval = estimateInfoGain.apply(combined);
        return combined;
      };

  // calculate the feature -> (docId -> scaled value) map
  Map<String, TLongLongHashMap> fmap = new HashMap<>();
  training.forEach(
      (docId, doc) ->
          doc.getFeatureSet(sourceFeatureSet)
              .forEach(
                  (f) -> {
                    TLongLongHashMap fmapt =
                        fmap.computeIfAbsent(f.getName(), (k) -> new TLongLongHashMap());
                    // store values as fixed-point longs to match InfoGainCalc's scaling
                    fmapt.put(docId, (long) (f.doubleValue() * InfoGainCalc.PRECISION_INV));
                  }));

  // convert to FeatureNodes, scoring each feature once
  ArrayList<FeatureNode> arrTemp = new ArrayList<>();
  fmap.forEach(
      (fname, arr) -> {
        FeatureNode node = new FeatureNode(fname, arr);
        node.eval = estimateInfoGain.apply(node);
        arrTemp.add(node);
      });

  // optimization: cluster only the nOfBestToUse features with the highest info gain
  arrTemp.sort((n1, n2) -> -Double.compare(n1.eval, n2.eval));
  for (int i = 0; i < arrTemp.size() && i < nOfBestToUse; i++) {
    inverseIndex.put(arrTemp.get(i).featureName, arrTemp.get(i));
  }

  System.out.println("Doing clustering");
  AgglomerativeSampling<FeatureNode> clustering =
      new AgglomerativeSampling<>(evaluator, combiner, inverseIndex.values());
  clustering.setMaxSamples(nOfBestToUse);
  clustering.doClustering(nClusters);

  // collect the member feature names of each non-empty cluster
  clusters = new HashMap<>();
  clustering.forEachCluster(
      (c) -> {
        Set<String> features = new HashSet<>();
        c.forEachLeaf((l) -> features.add(l.getPoint().featureName));
        if (!features.isEmpty()) {
          clusters.put(clusters.size(), features);
        }
      });

  // release memory
  inverseIndex = null;
  return null;
}
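// For reference, the relevance score driving both the feature pre-selection and
// the cluster merges above is the information gain of a feature X with respect
// to the binary class Y:  IG(Y; X) = H(Y) - H(Y | X).
// The helper below is a minimal sketch of that quantity for a feature reduced
// to present/absent counts; it is an illustration only, not InfoGainCalc's
// actual implementation (which scores real-valued samples and supports weight
// hints).
private static double binaryInfoGain(
    long posWithFeature, long negWithFeature, long posTotal, long negTotal) {
  long total = posTotal + negTotal;
  if (total == 0) {
    return 0;
  }
  long withF = posWithFeature + negWithFeature;
  long withoutF = total - withF;
  double hY = entropy(posTotal, negTotal); // H(Y)
  double hYGivenX = // H(Y|X): branch entropies weighted by branch mass
      (double) withF / total * entropy(posWithFeature, negWithFeature)
          + (double) withoutF / total
              * entropy(posTotal - posWithFeature, negTotal - negWithFeature);
  return hY - hYGivenX;
}

// Shannon entropy (in bits) of a two-outcome distribution given raw counts.
private static double entropy(long a, long b) {
  long n = a + b;
  if (n == 0) {
    return 0;
  }
  double pa = (double) a / n;
  double pb = (double) b / n;
  double h = 0;
  if (pa > 0) h -= pa * Math.log(pa) / Math.log(2);
  if (pb > 0) h -= pb * Math.log(pb) / Math.log(2);
  return h;
}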
@Override
public void reset(DocumentSet docs, String className) {
  this.className = className;
  docs.forEach((id, doc) -> doc.removeFeatureSet(FEATURE_SET));
}
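// Illustrative call order (assumed driver code, not from the original source;
// the `preprocessor` variable and its wiring are hypothetical): reset() runs
// first so cluster features from an earlier pass are dropped, then preProcess()
// rebuilds `clusters`, mapping each cluster id to its member feature names.
//
//   preprocessor.reset(trainingDocs, "positive");
//   preprocessor.preProcess(trainingDocs, null);
//   preprocessor.clusters.forEach(
//       (cid, names) -> System.out.println("cluster " + cid + ": " + names));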