/**
 * Computes the top-level predicted label for every tree.
 *
 * <p>Each tree is forward-propagated first so that its {@code prediction()} is populated.
 * The label recorded for that tree is then the index chosen by {@code SimpleBlas.iamax}.
 *
 * <p>NOTE(review): {@code iamax} picks the entry with the largest <em>absolute</em> value.
 * This equals the ordinary argmax only if the prediction scores are non-negative (e.g. softmax
 * output). Confirm that this holds for {@code Tree#prediction()}.
 *
 * @param trees the trees to predict; visited in iteration order
 * @return one predicted label index per tree, in the same order as {@code trees}
 */
public List<Integer> predict(List<Tree> trees) {
    List<Integer> labels = new ArrayList<>();
    for (Tree tree : trees) {
        // populate tree.prediction() before reading it
        forwardPropagateTree(tree);
        int bestLabel = SimpleBlas.iamax(tree.prediction());
        labels.add(bestLabel);
    }
    return labels;
}
/**
 * Returns the label index for a data set row, as chosen by {@code SimpleBlas.iamax} over the
 * data set's second component ({@code getSecond()}, presumably the label/outcome matrix —
 * confirm against {@code FloatDataSet}).
 *
 * @param data the data set whose label is extracted
 * @return index of the entry with the largest absolute value in {@code data.getSecond()}
 */
private int getLabel(FloatDataSet data) {
    final int labelIndex = SimpleBlas.iamax(data.getSecond());
    return labelIndex;
}
/**
 * Returns the outcome index of this single-row data set.
 *
 * <p>The outcome is the index chosen by {@code SimpleBlas.iamax} over {@code getSecond()}.
 * That is the entry with the largest absolute value.
 *
 * @return the outcome index
 * @throws IllegalStateException if {@code numExamples()} is greater than one
 */
public int outcome() {
    if (this.numExamples() <= 1) {
        return SimpleBlas.iamax(getSecond());
    }
    throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
}