Example no. 1
0
  /**
   * Classifies a sample as movie / non-movie by feeding five raw coefficients, their sum,
   * and all ten pairwise ratios into a {@link DecisionTree} and returning its verdict.
   *
   * <p>NOTE(review): only indices 1..5 of {@code a} are read, so the array must have length
   * &gt;= 6 and element 0 is ignored — confirm this 1-based convention against the callers.
   *
   * @param a coefficient array; positions 1 through 5 are consumed
   * @return true if the decision tree classifies the sample as a movie
   */
  boolean checkIfMovie(double[] a) {
    double a1 = a[1];
    double a2 = a[2];
    double a3 = a[3];
    double a4 = a[4];
    double a5 = a[5];
    // Aggregate feature: sum of the five coefficients.
    double a6 = a1 + a2 + a3 + a4 + a5;
    // Ratio features: every ordered pair ai / aj with i < j.
    double a7 = a1 / a2;
    double a8 = a1 / a3;
    double a9 = a1 / a4;
    double a10 = a1 / a5;
    double a11 = a2 / a3;
    double a12 = a2 / a4;
    double a13 = a2 / a5;
    double a14 = a3 / a4;
    double a15 = a3 / a5;
    double a16 = a4 / a5;
    DecisionTree dt =
        new DecisionTree(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16);
    return dt.isMovie();
  }
  /**
   * Builds a decision-tree classification model with Spark MLlib and evaluates it on the
   * supplied test set, returning a summary of the evaluation.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (not referenced inside this method body)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow; supplies the hyper-parameters used for training
   * @param mlModel Deployable machine learning model; its model field is populated here
   * @param includedFeatures features that were used for training, keyed by feature index
   * @param categoricalFeatureInfo categorical-feature info, keyed by feature index
   * @return class-classification summary (accuracy, confusion matrix, features, dataset version)
   * @throws MLModelBuilderException if any step of training or evaluation throws
   */
  private ModelSummary buildDecisionTreeModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures,
      Map<Integer, Integer> categoricalFeatureInfo)
      throws MLModelBuilderException {
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      DecisionTree decisionTree = new DecisionTree();
      // Train using the workflow's impurity / max-depth / max-bins hyper-parameters.
      DecisionTreeModel decisionTreeModel =
          decisionTree.train(
              trainingData,
              getNoOfClasses(mlModel),
              categoricalFeatureInfo,
              hyperParameters.get(MLConstants.IMPURITY),
              Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
              Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      // predictionsAndLabels is cached because it is consumed twice below
      // (model summary and multiclass metrics).
      JavaPairRDD<Double, Double> predictionsAndLabels =
          decisionTree.test(decisionTreeModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.DECISION_TREE.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      // Last consumer of predictionsAndLabels has run; safe to evict.
      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building decision tree model: " + e.getMessage(), e);
    }
  }
Example no. 3
0
 /**
  * Averages the positive-class probability over every tree in the ensemble.
  *
  * @param features feature counts for the sample being scored
  * @return mean probability of the TRUE label across all trees
  */
 public double probabilityOfTrue(Counter<String> features) {
   double total = 0.0;
   for (int i = 0; i < trees.length; i++) {
     total += trees[i].probabilityOfTrue(features);
   }
   return total / trees.length;
 }
Example no. 4
0
  /**
   * Recursively renders the tree rooted at the given node through
   * {@code DecisionTree.appendText}, indenting each level by the accumulated offset.
   *
   * @param root the Node that is the root of the (sub)tree to be displayed
   * @param offset indentation prefix accumulated along the path from the real root
   */
  public static void displayTree(Node root, String offset) {
    // Leaf: print the final classification and stop recursing.
    if (root.children.isEmpty()) {
      DecisionTree.appendText("\n" + offset + "    THEN (" + root.label + ")  (Leaf node)");
      return;
    }
    // Interior node: print its label, then each outgoing link condition and subtree.
    // children and linkLabels are parallel collections: element i of linkLabels is the
    // condition leading to child i.
    Enumeration enum1 = root.children.elements();
    Enumeration enum2 = root.linkLabels.elements();

    DecisionTree.appendText("\n" + offset + "   " + root.label + " (Interior node)");
    while (enum1.hasMoreElements()) {
      DecisionTree.appendText("\n" + offset + "   IF (" + (String) enum2.nextElement() + ")");
      displayTree((Node) enum1.nextElement(), offset + "   ");
    }
  }
  /**
   * Demo entry point: classifies one hand-built weather observation with the outlook
   * decision tree, printing "?" when the tree cannot reach a decision.
   */
  public static void main(String args[]) {
    DecisionTree decisionTree = makeOutlookTree();

    Map<String, String> observation = new HashMap<>();
    observation.put("Wind", "Strong");
    observation.put("Humidity", "High");
    observation.put("Temperature", "Hot");
    observation.put("Outlook", "Overcast");

    try {
      System.out.println(decisionTree.classify(observation));
    } catch (UnknownDecisionException e) {
      System.out.println("?");
    }
  }
  /** Trains the tree on the boolean AND truth table and verifies all four classifications. */
  @Test
  public void testClassify() {

    DecisionTree tree = new DecisionTree();
    String[] header = {"x1", "x2", "answer"};

    // AND truth table: only (true, true) carries the TRUE label.
    SimpleDataSample trueTrue =
        newSimpleDataSample("answer", header, Boolean.TRUE, Boolean.TRUE, BooleanLabel.TRUE_LABEL);
    SimpleDataSample trueFalse =
        newSimpleDataSample(
            "answer", header, Boolean.TRUE, Boolean.FALSE, BooleanLabel.FALSE_LABEL);
    SimpleDataSample falseTrue =
        newSimpleDataSample(
            "answer", header, Boolean.FALSE, Boolean.TRUE, BooleanLabel.FALSE_LABEL);
    SimpleDataSample falseFalse =
        newSimpleDataSample(
            "answer", header, Boolean.FALSE, Boolean.FALSE, BooleanLabel.FALSE_LABEL);

    // One candidate split per attribute/value combination.
    Feature x1True = newFeature("x1", Boolean.TRUE);
    Feature x1False = newFeature("x1", Boolean.FALSE);
    Feature x2True = newFeature("x2", Boolean.TRUE);
    Feature x2False = newFeature("x2", Boolean.FALSE);

    tree.train(
        Arrays.asList(trueTrue, trueFalse, falseTrue, falseFalse),
        Arrays.asList(x1True, x1False, x2True, x2False));

    // The trained tree must reproduce the AND truth table exactly.
    String[] classificationHeader = {"x1", "x2"};
    Assert.assertEquals(
        BooleanLabel.TRUE_LABEL,
        tree.classify(
            newClassificationDataSample(classificationHeader, Boolean.TRUE, Boolean.TRUE)));
    Assert.assertEquals(
        BooleanLabel.FALSE_LABEL,
        tree.classify(
            newClassificationDataSample(classificationHeader, Boolean.TRUE, Boolean.FALSE)));
    Assert.assertEquals(
        BooleanLabel.FALSE_LABEL,
        tree.classify(
            newClassificationDataSample(classificationHeader, Boolean.FALSE, Boolean.TRUE)));
    Assert.assertEquals(
        BooleanLabel.FALSE_LABEL,
        tree.classify(
            newClassificationDataSample(classificationHeader, Boolean.FALSE, Boolean.FALSE)));
  }
Example no. 7
0
  /**
   * Ensemble experiment on the census-income data set: pre-processes the train/test files,
   * trains six classifiers (decision tree, AdaBoost, naive Bayes, bagging, SVM, k-NN), then
   * majority-votes their six predictions per test instance and reports accuracy.
   *
   * @param args the command line arguments (unused)
   * @throws Exception if pre-processing, file I/O, or any classifier fails
   */
  public static void main(String[] args) throws Exception {
    PreProcessor p = new PreProcessor("census-income.data", "census-income-preprocessed.arff");
    p.smote();

    PreProcessor p_test =
        new PreProcessor("census-income.test", "census-income-test-preprocessed.arff");
    p_test.run();

    // try-with-resources: the readers are closed even if Instances(...) throws.
    Instances traininstance;
    Instances testinstance;
    try (BufferedReader traindata =
            new BufferedReader(new FileReader("census-income-preprocessed.arff"));
        BufferedReader testdata =
            new BufferedReader(new FileReader("census-income-test-preprocessed.arff"))) {
      traininstance = new Instances(traindata);
      testinstance = new Instances(testdata);
    }

    // Class attribute is the last column in both ARFF files.
    traininstance.setClassIndex(traininstance.numAttributes() - 1);
    testinstance.setClassIndex(testinstance.numAttributes() - 1);
    int numOfAttributes = testinstance.numAttributes();
    int numOfInstances = testinstance.numInstances();

    NaiveBayesClassifier nb = new NaiveBayesClassifier("census-income-preprocessed.arff");
    Classifier cnaive = nb.NBClassify();

    DecisionTree dt = new DecisionTree("census-income-preprocessed.arff");
    Classifier cls = dt.DTClassify();

    AdaBoost ab = new AdaBoost("census-income-preprocessed.arff");
    AdaBoostM1 m1 = ab.AdaBoostDTClassify();

    BaggingMethod b = new BaggingMethod("census-income-preprocessed.arff");
    Bagging bag = b.BaggingDTClassify();

    SVM s = new SVM("census-income-preprocessed.arff");
    SMO svm = s.SMOClassifier();

    knn knnclass = new knn("census-income-preprocessed.arff");
    IBk knnc = knnclass.knnclassifier();

    // NOTE(review): this logistic model is trained but never consulted in the vote
    // below — confirm whether it was meant to be a seventh voter.
    Logistic log = new Logistic();
    log.buildClassifier(traininstance);

    int match = 0;
    int error = 0;

    for (int i = 0; i < numOfInstances; i++) {
      int greater = 0;
      int less = 0;
      // Six voters: decision tree, AdaBoost, naive Bayes, bagging, SVM, k-NN.
      double[] predictions = new double[6];

      predictions[0] = cls.classifyInstance(testinstance.instance(i));
      predictions[1] = m1.classifyInstance(testinstance.instance(i));
      predictions[2] = cnaive.classifyInstance(testinstance.instance(i));
      predictions[3] = bag.classifyInstance(testinstance.instance(i));
      predictions[4] = svm.classifyInstance(testinstance.instance(i));
      predictions[5] = knnc.classifyInstance(testinstance.instance(i));

      // Tally each voter's predicted class label.
      for (int j = 0; j < predictions.length; j++) {
        if ((testinstance.instance(i).classAttribute().value((int) predictions[j]))
                .compareTo(">50K")
            == 0) greater++;
        else less++;
      }

      // Majority vote; ties go to "<=50K".
      String predicted;
      if (greater > less) predicted = ">50K";
      else predicted = "<=50K";

      if ((testinstance.instance(i).stringValue(numOfAttributes - 1)).compareTo(predicted) == 0)
        match++;
      else error++;
    }

    System.out.println("Correctly classified Instances: " + match);
    System.out.println("Misclassified Instances: " + error);

    double accuracy = (double) match / (double) numOfInstances * 100;
    double error_percent = 100 - accuracy;
    System.out.println("Accuracy: " + accuracy + "%");
    System.out.println("Error: " + error_percent + "%");
  }
Example no. 8
0
  /**
   * Entry point for the ID3 homework: builds a decision tree from a training CSV, reports
   * test-set accuracy before and after post-pruning, and optionally prints the tree itself.
   *
   * <p>Expected arguments: {@code L K trainingSet validationSet testSet printTree(yes/no)}.
   */
  public static void main(String args[]) {

    System.out.println("Solution to HomeWork2 -- Machine learning");
    System.out.println("Implementation of ID3 Algorithm : ");
    System.out.println("==========================================");

    // Fail with a usage message instead of an ArrayIndexOutOfBoundsException.
    if (args.length < 6) {
      System.out.println(
          "Usage: <L> <K> <trainingSet> <validationSet> <testSet> <print tree? yes/no>");
      return;
    }

    // Parse arguments
    int L = Integer.parseInt(args[0]);
    int K = Integer.parseInt(args[1]);
    String trainingset = args[2];
    String validationset = args[3];
    String testset = args[4];
    boolean toPrint = args[5].equalsIgnoreCase("yes");

    // Read the three CSV files into the shared attribute tables.
    FileUtility fu1 = new FileUtility();
    parsedTrainingAttributes = fu1.parseCSVFile(trainingset);
    parsedValidationAttributes = fu1.parseCSVFile(validationset);
    parsedTestAttributes = fu1.parseCSVFile(testset);

    System.out.println("Read Successfull..........");
    System.out.println();

    // Test print of the parsed training set.
    fu1.printSets(parsedTrainingAttributes);

    // Build the tree using information gain.
    DecisionTree dtree = new DecisionTree();
    dtree.buildTree(parsedTrainingAttributes, new Node());

    System.out.println(
        "----------------------------------------------------------------------------------");
    AccuracyCalculator ac1 = new AccuracyCalculator();

    // Report accuracy (and optionally the tree itself) before and after pruning.
    if (toPrint) {

      System.out.println("Decision tree before pruning: ");
      System.out.println("");
      System.out.println(dtree);
      System.out.println(
          "Accuracy of decision tree before pruning : "
              + ac1.getAccuracy(parsedTestAttributes, dtree.treeRootNode)
              + "%");
      System.out.println("Total Matched classes : " + (int) ac1.matchCount);
      System.out.println(
          "----------------------------------------------------------------------------------");
      System.out.println("");

      postPrune(dtree, L, K);

      System.out.println("Decision tree after pruning: ");
      System.out.println("");
      System.out.println(dtree);
      System.out.println(
          "Accuracy of decision tree after pruning: "
              + ac1.getAccuracy(parsedTestAttributes, dtree.treeRootNode)
              + "%");
      System.out.println("Total Matched classes : " + (int) ac1.matchCount);
      System.out.println(
          "----------------------------------------------------------------------------------");

    } else {

      System.out.println(
          "Accuracy of decision tree before pruning : "
              + ac1.getAccuracy(parsedTestAttributes, dtree.treeRootNode)
              + "%");
      System.out.println("Total Matched classes : " + (int) ac1.matchCount);
      System.out.println("");

      postPrune(dtree, L, K);

      System.out.println(
          "Accuracy of decision tree after pruning: "
              + ac1.getAccuracy(parsedTestAttributes, dtree.treeRootNode)
              + "%");
      System.out.println("Total Matched classes : " + (int) ac1.matchCount);
      System.out.println(
          "----------------------------------------------------------------------------------");
    }
  }

  /** Post-prunes the tree with the given L/K budget; pruning failures are logged, not fatal. */
  private static void postPrune(DecisionTree dtree, int l, int k) {
    try {
      dtree.performPostPruning(l, k, parsedValidationAttributes);
    } catch (CloneNotSupportedException e) {
      e.printStackTrace();
    }
  }
  /**
   * Demonstrates the DecisionTree API three ways: a tree assembled by hand, a tree parsed
   * from CSV evaluated on one hand-made parameter set, and trees plus case parameters both
   * parsed from CSV files.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    // Lists needed to assemble a decision tree manually.
    List<String> decisions = new LinkedList<String>();
    List<Entry<Integer, Integer>> next_decisions = new LinkedList<Entry<Integer, Integer>>();
    List<String> conclusions = new LinkedList<String>();
    List<String> parameters = new LinkedList<String>();

    System.out.println(
        "\n------------------------Von Hand erstellter Decision Tree---------------------------------------\n");
    // The order of the decisions matters for the transitions below; parameters must be
    // supplied in exactly this order.
    DecisionTree TestTree = new DecisionTree(8, 8);
    TestTree.setParameter("test", 1);
    decisions.add("$1==J");
    TestTree.setParameterDescription(
        1, "Boolean: variable die anzeigt ob das Bild selbst erstellt wurde (J/N)");
    TestTree.setDecisionDescription(1, "Hast du das bild Selbst erstellt?");
    decisions.add("$2==J");
    TestTree.setParameterDescription(
        2,
        "Boolean: variable die anzeigt ob das Bild unter einer freien Lizenz veröffenlicht werdne soll (J/N)");
    TestTree.setDecisionDescription(2, "Willst du es unter einer freien Lizenz veröffenlichen?");
    decisions.add("$3>100");
    TestTree.setParameterDescription(3, "Integer: Das Alter des Bildes (J/N)");
    TestTree.setDecisionDescription(3, "Ist das Bild älter als 100 Jahre?");
    decisions.add("$4==J");
    TestTree.setParameterDescription(4, "Boolean: Urheber bekannt? (J/N)");
    TestTree.setDecisionDescription(4, "Ist der Urheber des Bildes bekannt?");
    decisions.add("$5==J");
    TestTree.setParameterDescription(5, "Boolean: Bildrechte dritter Auszuschließen? (J/N)");
    TestTree.setDecisionDescription(5, "Sind Bildrechte dritter Auszuschließen?");
    decisions.add("$6>70");
    TestTree.setParameterDescription(6, "Integer: vor wieviel Jahren ist der Urheber verstorben?");
    TestTree.setDecisionDescription(6, "Ist der Urheber vor mehr als 70 Jahren verstorben?");
    decisions.add("$7==J");
    TestTree.setParameterDescription(7, "Boolean: Einverständnis aller betroffenen? (J/N)");
    TestTree.setDecisionDescription(7, "Hast du das Einverständnis aller betroffenen?");
    decisions.add("$8==J");
    TestTree.setParameterDescription(8, "Boolean: Zustimmung des Urhebers? (J/N)");
    TestTree.setDecisionDescription(
        8, "Hat der Urheber zugestimmt das Bild unter eine freie Lizenz zu stellen?");
    TestTree.setDecisions(decisions);
    conclusions.add("hochladen");
    conclusions.add("nicht hochladen");
    TestTree.setConclusions(conclusions);
    // Transitions: indices of the next decision, in the order declared above (1-based).
    // Negative numbers: the absolute value is the index of a conclusion (same ordering).
    next_decisions.add(new AbstractMap.SimpleEntry<Integer, Integer>(3, 2));
    next_decisions.add(new AbstractMap.SimpleEntry<Integer, Integer>(-2, 5));
    next_decisions.add(new AbstractMap.SimpleEntry<Integer, Integer>(4, -1));
    next_decisions.add(new AbstractMap.SimpleEntry<Integer, Integer>(-2, 6));
    next_decisions.add(new AbstractMap.SimpleEntry<Integer, Integer>(7, -1));
    next_decisions.add(new AbstractMap.SimpleEntry<Integer, Integer>(8, 5));
    next_decisions.add(new AbstractMap.SimpleEntry<Integer, Integer>(-2, -1));
    next_decisions.add(new AbstractMap.SimpleEntry<Integer, Integer>(-2, -1));
    TestTree.setNext_decisions(next_decisions);
    // Parameters must be added in the order declared above.
    parameters.add("N");
    parameters.add("J");
    parameters.add("70");
    parameters.add("J");
    parameters.add("N");
    parameters.add("80");
    parameters.add("N");
    parameters.add("J");

    System.out.println(TestTree.conclude_verbose(parameters.toArray(new String[0])));

    // Same setup as above, but the tree configuration is parsed from CSV and the parameters
    // use 1/0 instead of J/N; no descriptions this time.
    System.out.println(
        "\n---------------------------Aus der csv geparste Decision trees------------------------------------\n");
    TestTree = Parser.parseTreeCsv(".\\bild.csv");
    parameters = new LinkedList<String>();
    parameters.add("0");
    parameters.add("1");
    parameters.add("70");
    parameters.add("1");
    parameters.add("0");
    parameters.add("80");
    parameters.add("0");
    parameters.add("1");
    System.out.println(TestTree.conclude_verbose(parameters.toArray(new String[0])));

    // Everything parsed from CSV: trees and case parameter sets.
    System.out.println(
        "\n-------------------------------Aus der csv geparster Decision tree und geparste parameter------------------------\n");
    System.out.println("\n--------Bild---------\n");
    printConcludeCases(TestTree, Parser.parseParameters(".\\bild_cases.csv"));

    System.out.println("\n--------Bewerber--------\n");
    TestTree = Parser.parseTreeCsv(".\\bewerber.csv");
    printConcludeCases(TestTree, Parser.parseParameters(".\\bewerber_cases.csv"));

    System.out.println("\n--------Kunde--------\n");
    TestTree = Parser.parseTreeCsv(".\\kunde.csv");
    printVerboseCases(TestTree, Parser.parseParameters(".\\kunde_cases.csv"));

    System.out.println("\n--------Gehalt--------\n");
    TestTree = Parser.parseTreeCsv(".\\gehalt.csv");
    printVerboseCases(TestTree, Parser.parseParameters(".\\gehalt_cases.csv"));
  }

  /** Prints "Case n: &lt;conclusion&gt;" for each parsed parameter set. */
  private static void printConcludeCases(DecisionTree tree, List<List<String>> cases) {
    int caseNo = 1;
    for (List<String> param : cases) {
      System.out.println("Case " + caseNo + ": " + tree.conclude(param));
      caseNo++;
    }
  }

  /** Prints the verbose conclusion trace for each parsed parameter set, separated by rules. */
  private static void printVerboseCases(DecisionTree tree, List<List<String>> cases) {
    int caseNo = 1;
    for (List<String> param : cases) {
      System.out.println(
          "Case " + caseNo + ": " + tree.conclude_verbose(param.toArray(new String[0])));
      System.out.println("\n----------------\n");
      caseNo++;
    }
  }