// TODO save dan load @Override public void writeHypothesis(OutputStream str) { PrintStream printStream = new PrintStream(str); printStream.println(k); // save tipe dari classifier printStream.println(classifier.getClass().getName()); printStream.flush(); classifier.writeHypothesis(str); }
@Override public void train( int[] numInputCategory, int[][] inputCategory, int numOutputClass, int[] outputClass) throws Exception { shuffleTrainingSet(inputCategory, outputClass); OfflineLearningNominalDataClassifier[] candidate = new OfflineLearningNominalDataClassifier[k]; double[] accuracy = new double[k]; int selected = 0; for (int i = 0; i < k; i++) { candidate[i] = classifier.copy(); // pisah set menjadi tr dan cv ArrayList<int[]> trInputCategory = new ArrayList<int[]>(); ArrayList<Integer> trOutputClass = new ArrayList<Integer>(); ArrayList<int[]> cvInputCategory = new ArrayList<int[]>(); ArrayList<Integer> cvOutputClass = new ArrayList<Integer>(); for (int j = 0; j < i * inputCategory.length / k; j++) { trInputCategory.add(inputCategory[j]); trOutputClass.add(outputClass[j]); } for (int j = i * inputCategory.length / k; j < (i + 1) * inputCategory.length / k; j++) { cvInputCategory.add(inputCategory[j]); cvOutputClass.add(outputClass[j]); } for (int j = (i + 1) * inputCategory.length / k; j < inputCategory.length; j++) { trInputCategory.add(inputCategory[j]); trOutputClass.add(outputClass[j]); } candidate[i].train( numInputCategory, (int[][]) trInputCategory.toArray(new int[trInputCategory.size()][]), numOutputClass, trOutputClass.stream().mapToInt(Integer::intValue).toArray()); accuracy[i] = calculateAccuracy( candidate[i], cvInputCategory.toArray(new int[cvInputCategory.size()][]), cvOutputClass.stream().mapToInt(Integer::intValue).toArray()); if (accuracy[i] > accuracy[selected]) selected = i; } classifier = candidate[selected]; }
/**
 * Restores a hypothesis in the layout produced by {@code writeHypothesis}.
 *
 * <p>Expected input: the fold count {@code k} on its own line, the fully qualified class name
 * of the wrapped classifier on the next line, then that classifier's own serialized data,
 * which is read by a freshly instantiated classifier's {@code loadHypothesis}.
 *
 * @param sc scanner positioned at the start of the serialized hypothesis
 * @throws IllegalStateException if the stored classifier class cannot be found or instantiated
 *     through a no-argument constructor
 * @throws ClassCastException if the stored class is not an
 *     {@code OfflineLearningNominalDataClassifier}
 */
@Override
public void loadHypothesis(Scanner sc) {
    k = sc.nextInt();
    // nextInt() stops just before the line terminator written after k. Consume the rest of
    // that line; otherwise the next nextLine() returns "" instead of the class name.
    sc.nextLine();
    String className = sc.nextLine().trim();
    try {
        classifier = Class.forName(className)
                .asSubclass(OfflineLearningNominalDataClassifier.class)
                .getDeclaredConstructor()
                .newInstance();
    } catch (ReflectiveOperationException ex) {
        // Continuing would feed the stored data into the wrong (or a null) classifier.
        throw new IllegalStateException(
                "Cannot instantiate stored classifier '" + className + "'", ex);
    }
    classifier.loadHypothesis(sc);
}
/**
 * Classifies one sample by delegating to the currently held classifier (a copy stored by the
 * constructor, replaced by {@code train} or {@code loadHypothesis}).
 *
 * @param inputCategory the sample's input values
 * @return the class index predicted by the wrapped classifier
 * @throws Exception whatever the wrapped classifier's {@code predict} throws
 */
@Override
public int predict(int[] inputCategory) throws Exception {
    final int predictedClass = classifier.predict(inputCategory);
    return predictedClass;
}
/**
 * Creates a k-fold cross-validation wrapper around a classifier.
 *
 * @param k number of folds that {@code train} splits the data into
 * @param classifier prototype classifier; the wrapper stores {@code classifier.copy()} rather
 *     than the caller's instance
 */
public kFold(int k, OfflineLearningNominalDataClassifier classifier) {
    this.classifier = classifier.copy();
    this.k = k;
}