コード例 #1
0
ファイル: LinearTest.java プロジェクト: podalv/Classifier
 /**
  * Round-trips a model through {@code saveBinary}/{@code loadBinary} using an in-memory buffer and
  * checks that bias, feature count, labels, class count and weights are preserved (doubles within
  * 0.001).
  */
 @Test
 public void saveBinary_loadBinary() throws Exception {
   final Model expected = createModel();

   final ByteArrayOutputStream sink = new ByteArrayOutputStream();
   expected.saveBinary(sink);
   final byte[] serialized = sink.toByteArray();

   final Model actual = Model.loadBinary(new ByteArrayInputStream(serialized));

   final double tolerance = 0.001;
   Assert.assertEquals(expected.getBias(), actual.getBias(), tolerance);
   Assert.assertEquals(expected.getNrFeature(), actual.getNrFeature());
   Assert.assertArrayEquals(expected.getLabels(), actual.getLabels());
   Assert.assertEquals(expected.getNrClass(), actual.getNrClass());
   Assert.assertArrayEquals(expected.w, actual.w, tolerance);
 }
コード例 #2
0
ファイル: ModelTests.java プロジェクト: gaieepo/HubTurbo
  /**
   * Adding an element to a collection obtained from a copy's getter must leave that copy equal to
   * {@code modelUpdated}, for issues, labels, milestones and users alike.
   */
  @Test
  public void immutability() {
    Model issuesCopy = new Model(modelUpdated);
    issuesCopy.getIssues().add(new TurboIssue(REPO, 11, ""));
    assertEquals(modelUpdated, issuesCopy);

    Model labelsCopy = new Model(modelUpdated);
    labelsCopy.getLabels().add(new TurboLabel(REPO, "aksdjl"));
    assertEquals(modelUpdated, labelsCopy);

    Model milestonesCopy = new Model(modelUpdated);
    milestonesCopy.getMilestones().add(new TurboMilestone(REPO, 11, ""));
    assertEquals(modelUpdated, milestonesCopy);

    Model usersCopy = new Model(modelUpdated);
    usersCopy.getUsers().add(new TurboUser(REPO, ""));
    assertEquals(modelUpdated, usersCopy);
  }
コード例 #3
0
ファイル: Predict.java プロジェクト: Thunder1989/IR_Base
  /**
   * Predicts a label for every instance read from {@code reader}, writes the predictions to
   * {@code writer}, and finally reports either accuracy (classification) or mean squared error and
   * squared correlation coefficient (support vector regression) through {@code info}.
   *
   * <p>Each input line must start with the true target label, followed by whitespace-separated
   * {@code index:value} pairs. Pairs whose index exceeds the model's feature count are ignored,
   * because such features were not used in training. If {@code model.bias >= 0}, a bias feature
   * with index {@code nr_feature + 1} and value {@code model.bias} is appended to every instance.
   *
   * <p>When {@code flag_predict_probability} is set, the output starts with a {@code labels ...}
   * header line. Each following line holds the predicted label and one probability per class.
   * Otherwise each line holds only the predicted label.
   *
   * <p><b>Note: The streams are NOT closed</b>
   *
   * @param reader source of instances, one per line
   * @param writer destination for the prediction lines
   * @param model trained model used for prediction
   * @throws IllegalArgumentException if probability output is requested but the model is not a
   *     probability model
   * @throws RuntimeException if a line is malformed (empty, non-numeric label, or a bad
   *     {@code index:value} pair); the message carries the 1-based line number
   * @throws IOException if reading from {@code reader} fails (and presumably if {@code printf}
   *     reports a write failure -- confirm against its definition)
   */
  static void doPredict(BufferedReader reader, Writer writer, Model model) throws IOException {
    int correct = 0;
    int total = 0;
    double error = 0;
    double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;

    int nr_class = model.getNrClass();
    double[] prob_estimates = null;
    int nr_feature = model.getNrFeature();
    // With a bias term, the extra constant feature sits just past the last real feature index.
    int n = model.bias >= 0 ? nr_feature + 1 : nr_feature;

    if (flag_predict_probability && !model.isProbabilityModel()) {
      throw new IllegalArgumentException(
          "probability output is only supported for logistic regression");
    }

    Formatter out = new Formatter(writer);

    if (flag_predict_probability) {
      int[] labels = model.getLabels();
      prob_estimates = new double[nr_class];

      printf(out, "labels");
      for (int j = 0; j < nr_class; j++) {
        printf(out, " %d", labels[j]);
      }
      printf(out, "\n");
    }

    String line;
    while ((line = reader.readLine()) != null) {
      // Every line is either fully processed (++total) or throws, so this is the 1-based line no.
      int lineNumber = total + 1;
      List<Feature> x = new ArrayList<Feature>();
      StringTokenizer st = new StringTokenizer(line, " \t\n");

      double target_label;
      try {
        target_label = atof(st.nextToken());
      } catch (NoSuchElementException | NumberFormatException e) {
        // Missing label (empty line) or non-numeric label: report it with the line number,
        // exactly like a malformed index:value pair below.
        throw new RuntimeException("Wrong input format at line " + lineNumber, e);
      }

      while (st.hasMoreTokens()) {
        String[] split = COLON.split(st.nextToken(), 2);
        if (split == null || split.length < 2) {
          throw new RuntimeException("Wrong input format at line " + lineNumber);
        }

        try {
          int idx = atoi(split[0]);
          double val = atof(split[1]);

          // feature indices larger than those in training are not used
          if (idx <= nr_feature) {
            x.add(new FeatureNode(idx, val));
          }
        } catch (NumberFormatException e) {
          throw new RuntimeException("Wrong input format at line " + lineNumber, e);
        }
      }

      if (model.bias >= 0) {
        x.add(new FeatureNode(n, model.bias));
      }

      Feature[] nodes = x.toArray(new Feature[x.size()]);

      double predict_label;
      if (flag_predict_probability) {
        assert prob_estimates != null;
        predict_label = Linear.predictProbability(model, nodes, prob_estimates);
        printf(out, "%g", predict_label);
        // Iterate the array itself so the loop bound always matches its allocated length.
        for (double probability : prob_estimates) {
          printf(out, " %g", probability);
        }
        printf(out, "\n");
      } else {
        predict_label = Linear.predict(model, nodes);
        printf(out, "%g\n", predict_label);
      }

      if (predict_label == target_label) {
        ++correct;
      }

      double diff = predict_label - target_label;
      error += diff * diff;
      sump += predict_label;
      sumt += target_label;
      sumpp += predict_label * predict_label;
      sumtt += target_label * target_label;
      sumpt += predict_label * target_label;
      ++total;
    }

    if (model.solverType.isSupportVectorRegression()) {
      info("Mean squared error = %g (regression)%n", error / total);
      double sxy = total * sumpt - sump * sumt;
      info(
          "Squared correlation coefficient = %g (regression)%n",
          (sxy * sxy) / ((total * sumpp - sump * sump) * (total * sumtt - sumt * sumt)));
    } else {
      info("Accuracy = %g%% (%d/%d)%n", (double) correct / total * 100, correct, total);
    }
  }
コード例 #4
0
  /**
   * Generates features, trains a logistic-regression model ({@code SolverType.L2R_LR}) on them,
   * saves it to the file {@code model} in the working directory, and writes label predictions to
   * {@code <rootDirectory>\dataset\dataset_aspectCategorization\predictedHotelsLabels.txt}.
   *
   * <p>{@code option == 1} writes 5-fold cross-validation predictions for the training records.
   * {@code option == 3} reloads the saved model and writes thresholded probability predictions for
   * the test records. For any other option the prediction file is still created (and therefore
   * truncated), but nothing is written to it.
   *
   * @param option which prediction pass to run (1 = cross-validation, 3 = test set); also passed
   *     through to {@code generateFeature}
   * @param trainFile name of the '|'-separated training record file under
   *     {@code dataset\dataset_aspectCategorization}
   * @param testFile name of the '|'-separated test record file under the same directory
   * @param ddgFile passed through to {@code generateFeature}
   * @throws IOException if a file cannot be read or written, if the label file has more lines than
   *     there are training instances, or if a record file ends before every feature row is matched
   */
  private void mainClassifierFunction(int option, String trainFile, String testFile, String ddgFile)
      throws IOException {
    int finalSize = this.generateFeature(option, trainFile, testFile, ddgFile);
    System.out.println("Hello aspectCategorizationSemEval2016!");

    Problem problem = new Problem();
    double[] targetLabels = readTrainingLabels();

    System.out.println("Training Instances: " + trainingFeature.size());
    System.out.println("Feature Length: " + finalSize);
    System.out.println("Test Instances: " + testFeature.size());

    problem.l = trainingFeature.size(); // number of training examples
    problem.n = finalSize; // number of features
    problem.x = buildDenseTrainingVectors(finalSize); // feature nodes
    problem.y = targetLabels; // target values

    // NOTE(review): the original comment said "-s 7", but liblinear documents L2R_LR as "-s 0"
    // ("-s 7" is L2R_LR_DUAL). Confirm which solver was intended.
    SolverType solver = SolverType.L2R_LR;
    double C = 0.75; // cost of constraints violation
    double eps = 0.0001; // stopping criteria

    Parameter parameter = new Parameter(solver, C, eps);
    Model model = Linear.train(problem, parameter);
    File modelFile = new File("model");
    model.save(modelFile);

    // Opened for every option (as before), so the prediction file is always truncated;
    // try-with-resources closes it exactly once on every path, including exceptions.
    try (PrintWriter write =
        new PrintWriter(
            new BufferedWriter(
                new FileWriter(
                    rootDirectory
                        + "\\dataset\\dataset_aspectCategorization\\predictedHotelsLabels.txt")))) {
      if (option == 1) {
        writeCrossValidationPredictions(problem, parameter, trainFile, write);
      } else if (option == 3) {
        writeTestPredictions(modelFile, testFile, write);
      }
    }
  }

  /**
   * Reads one numeric target label per line from {@code <rootDirectory>\dataset\trainingLabels.txt}
   * into an array sized to the number of training instances.
   * Instances without a corresponding line keep the value 0.0.
   */
  private double[] readTrainingLabels() throws IOException {
    double[] labels = new double[trainingFeature.size()];
    File file = new File(rootDirectory + "\\dataset\\trainingLabels.txt");
    try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
      int count = 0;
      String line;
      while ((line = reader.readLine()) != null) {
        if (count >= labels.length) {
          throw new IOException(
              "More labels than training instances (" + labels.length + ") in " + file);
        }
        labels[count++] = Double.parseDouble(line);
      }
    }
    return labels;
  }

  /**
   * Builds one dense row of {@code finalSize} feature nodes per training instance.
   * Feature {@code j + 1} takes the instance's stored value, or 0.0 when the instance has no entry.
   */
  private Feature[][] buildDenseTrainingVectors(int finalSize) {
    Feature[][] vectors = new Feature[trainingFeature.size()][finalSize];
    for (int i = 0; i < trainingFeature.size(); i++) {
      System.out.println(i + " trained.");
      for (int j = 0; j < finalSize; j++) {
        int index = j + 1; // feature indices are 1-based
        if (trainingFeature.get(i).containsKey(index)) {
          vectors[i][j] = new FeatureNode(index, trainingFeature.get(i).get(index));
        } else {
          vectors[i][j] = new FeatureNode(index, 0.0);
        }
      }
    }
    return vectors;
  }

  /**
   * Runs 5-fold cross-validation and writes the predicted label, then a {@code next} line, for each
   * training record whose key (field 1) has not been written yet and whose field 2 is not "True".
   * Records in {@code trainFile} are matched to feature rows by position.
   */
  private void writeCrossValidationPredictions(
      Problem problem, Parameter parameter, String trainFile, PrintWriter write)
      throws IOException {
    try (BufferedReader trainReader =
        new BufferedReader(
            new FileReader(
                new File(
                    rootDirectory + "\\dataset\\dataset_aspectCategorization\\" + trainFile)))) {
      HashMap<String, Integer> id = new HashMap<String, Integer>();
      double[] val = new double[trainingFeature.size()];
      double[] tempVal = new double[trainingFeature.size()];
      LinearCopy.crossValidation(problem, parameter, 5, val, tempVal);
      for (int i = 0; i < trainingFeature.size(); i++) {
        String[] tokens = readPipeRecord(trainReader, trainFile);
        // Field 1 is the de-duplication key (looks like a review id); field 2 == "True"
        // presumably marks an out-of-scope record. Both cases produce no output.
        if (id.containsKey(tokens[1]) || tokens[2].equalsIgnoreCase("True")) {
          continue;
        }
        write.println((int) val[i]);
        write.println("next");
        id.put(tokens[1], 1);
        System.out.println(tokens[1] + "\t" + (int) val[i]);
      }
    }
  }

  /**
   * Reloads the model from {@code modelFile} and writes predictions for each test feature row.
   * Each row is matched positionally to a record of {@code testFile}.
   * For each new, in-scope record it writes every class label whose predicted probability is at
   * least the threshold, or "-1" when none qualifies or field 3 is "abc", followed by a
   * {@code next} line.
   */
  private void writeTestPredictions(File modelFile, String testFile, PrintWriter write)
      throws IOException {
    // Probability cut-off carried over from the original code; its derivation is not documented.
    final double probabilityThreshold = 0.128;
    System.out.println(rootDirectory);
    try (BufferedReader testReader =
        new BufferedReader(
            new FileReader(
                new File(
                    rootDirectory + "\\dataset\\dataset_aspectCategorization\\" + testFile)))) {
      HashMap<String, Integer> id = new HashMap<String, Integer>();
      Model model = Model.load(modelFile);

      // The label map is fixed for the loaded model, so print it once rather than per instance.
      int[] labelMap = model.getLabels();
      for (int ar = 0; ar < labelMap.length; ar++) {
        System.out.println("********************** " + ar + ": " + labelMap[ar]);
      }

      int countNext = 0;
      for (int i = 0; i < testFeature.size(); i++) {
        Feature[] instance = new Feature[testFeature.get(i).size()];
        int j = 0;
        for (Map.Entry<Integer, Double> entry : testFeature.get(i).entrySet()) {
          instance[j++] = new FeatureNode(entry.getKey(), entry.getValue());
        }

        // Sized to the model's class count instead of a hard-coded 85, so it always matches
        // the number of probabilities produced and the length of labelMap.
        double[] predict = new double[model.getNrClass()];
        LinearCopy.predictProbability(model, instance, predict);

        String[] tokens = readPipeRecord(testReader, testFile);
        if (id.containsKey(tokens[1]) || tokens[2].equalsIgnoreCase("True")) {
          continue; // already written, or out of scope: no output
        }

        if (tokens[3].equalsIgnoreCase("abc")) {
          System.out.println(tokens[1]);
          write.println("-1");
        } else {
          boolean anyLabel = false;
          int classes = Math.min(predict.length, labelMap.length);
          for (int p = 0; p < classes; p++) {
            if (predict[p] >= probabilityThreshold) {
              anyLabel = true;
              write.println(labelMap[p]);
            }
          }
          if (!anyLabel) {
            write.println("-1");
          }
        }
        write.println("next");
        countNext++;
        id.put(tokens[1], 1);
      }
      System.out.println("count " + countNext);
    }
  }

  /** Reads the next line of {@code reader} and splits it on '|', failing clearly at end of input. */
  private String[] readPipeRecord(BufferedReader reader, String fileName) throws IOException {
    String line = reader.readLine();
    if (line == null) {
      throw new IOException("Unexpected end of " + fileName + ": fewer records than feature rows");
    }
    return line.split("\\|");
  }
コード例 #5
0
ファイル: ModelTests.java プロジェクト: gaieepo/HubTurbo
  /**
   * Checks {@code Model.equals}/{@code hashCode}:
   * <ul>
   *   <li>reflexivity, and inequality with {@code null} and an unrelated type;
   *   <li>equality of models with empty signatures and of copies;
   *   <li>that changing the repo id, or adding one issue, label, milestone or user, makes both the
   *       model and its hash code differ from {@code modelUpdated}.
   * </ul>
   */
  @Test
  public void equality() {
    // Basic contract
    assertEquals(modelUpdated, modelUpdated);
    assertNotEquals(modelUpdated, null);
    assertNotEquals(modelUpdated, 1);

    // Models with empty update signatures
    assertEquals(modelEmptySig, modelEmptySig2);

    // Copies compare equal to their sources
    assertEquals(modelCopyUpdated, modelUpdated);
    assertEquals(modelCopyNotUpdated, modelEmptySig);

    // Hash codes agree with the equalities above
    assertEquals(modelEmptySig.hashCode(), modelEmptySig2.hashCode());
    assertEquals(modelCopyUpdated.hashCode(), modelUpdated.hashCode());
    assertNotEquals(modelEmptySig.hashCode(), modelUpdated.hashCode());

    // Changing exactly one component must break equality
    Model differentRepo =
        new Model(
            "something",
            modelUpdated.getIssues(),
            modelUpdated.getLabels(),
            modelUpdated.getMilestones(),
            modelUpdated.getUsers(),
            modelUpdated.getUpdateSignature());
    assertDiffersFromModelUpdated(differentRepo);

    List<TurboIssue> extraIssues = new ArrayList<>(modelUpdated.getIssues());
    extraIssues.add(new TurboIssue(REPO, 11, "something"));
    Model withExtraIssue =
        new Model(
            REPO,
            extraIssues,
            modelUpdated.getLabels(),
            modelUpdated.getMilestones(),
            modelUpdated.getUsers(),
            modelUpdated.getUpdateSignature());
    assertDiffersFromModelUpdated(withExtraIssue);

    List<TurboLabel> extraLabels = new ArrayList<>(modelUpdated.getLabels());
    extraLabels.add(new TurboLabel(REPO, "Label 11"));
    Model withExtraLabel =
        new Model(
            REPO,
            modelUpdated.getIssues(),
            extraLabels,
            modelUpdated.getMilestones(),
            modelUpdated.getUsers(),
            modelUpdated.getUpdateSignature());
    assertDiffersFromModelUpdated(withExtraLabel);

    List<TurboMilestone> extraMilestones = new ArrayList<>(modelUpdated.getMilestones());
    extraMilestones.add(new TurboMilestone(REPO, 11, "something"));
    Model withExtraMilestone =
        new Model(
            REPO,
            modelUpdated.getIssues(),
            modelUpdated.getLabels(),
            extraMilestones,
            modelUpdated.getUsers(),
            modelUpdated.getUpdateSignature());
    assertDiffersFromModelUpdated(withExtraMilestone);

    List<TurboUser> extraUsers = new ArrayList<>(modelUpdated.getUsers());
    extraUsers.add(new TurboUser(REPO, "someone"));
    Model withExtraUser =
        new Model(
            REPO,
            modelUpdated.getIssues(),
            modelUpdated.getLabels(),
            modelUpdated.getMilestones(),
            extraUsers,
            modelUpdated.getUpdateSignature());
    assertDiffersFromModelUpdated(withExtraUser);
  }

  /** Asserts that {@code candidate} differs from {@code modelUpdated} in hash code and equality. */
  private void assertDiffersFromModelUpdated(Model candidate) {
    assertNotEquals(candidate.hashCode(), modelUpdated.hashCode());
    assertNotEquals(candidate, modelUpdated);
  }
コード例 #6
0
ファイル: ModelTests.java プロジェクト: gaieepo/HubTurbo
  /**
   * Checks the repo id, the update signature, and the per-kind getters/lookups for issues, labels,
   * milestones and users.
   * Each collection is compared against the names or ids generated for
   * {@code DummyRepoState.NO_OF_DUMMY_ISSUES} entries.
   */
  @Test
  public void getters() {
    // Repo id
    assertEquals(REPO, modelUpdated.getRepoId());
    assertEquals(modelUpdated.getRepoId(), modelUpdated.getRepoId());

    // Update signature
    assertEquals(true, modelEmptySig.getUpdateSignature().isEmpty());
    assertEquals(modelEmptySig.getUpdateSignature(), UpdateSignature.EMPTY);
    assertEquals(modelEmptySig.getUpdateSignature(), modelEmptySig2.getUpdateSignature());

    // Issues: expected ids 1..N in numeric order
    ArrayList<Integer> expectedIssueIds = new ArrayList<>();
    for (int k = 1; k <= DummyRepoState.NO_OF_DUMMY_ISSUES; k++) {
      expectedIssueIds.add(k);
    }
    Collections.sort(expectedIssueIds);
    int issuePos = 1;
    for (TurboIssue issue : modelUpdated.getIssues()) {
      assertEquals(issuePos, modelUpdated.getIssueById(issuePos).get().getId());
      assertEquals(expectedIssueIds.get(issuePos - 1).intValue(), issue.getId());
      issuePos++;
    }

    // Labels: expected names in lexicographic order ("Label 1", "Label 10", ..., "Label 2", ...)
    ArrayList<String> expectedLabelNames = new ArrayList<>();
    for (int k = 1; k <= DummyRepoState.NO_OF_DUMMY_ISSUES; k++) {
      expectedLabelNames.add("Label " + k);
    }
    Collections.sort(expectedLabelNames);
    int labelPos = 1;
    for (TurboLabel label : modelUpdated.getLabels()) {
      if (!label.getFullName().startsWith("Label")) {
        continue;
      }
      assertEquals(expectedLabelNames.get(labelPos - 1), label.getFullName());
      String lookupName = "Label " + labelPos;
      assertEquals(lookupName, modelUpdated.getLabelByActualName(lookupName).get().getFullName());
      labelPos++;
    }

    // Milestones: expected ids 1..N in numeric order
    ArrayList<Integer> expectedMilestoneIds = new ArrayList<>();
    for (int k = 1; k <= DummyRepoState.NO_OF_DUMMY_ISSUES; k++) {
      expectedMilestoneIds.add(k);
    }
    Collections.sort(expectedMilestoneIds);
    int milestonePos = 1;
    for (TurboMilestone milestone : modelUpdated.getMilestones()) {
      assertEquals(milestonePos, milestone.getId());
      assertEquals(
          expectedMilestoneIds.get(milestonePos - 1).intValue(),
          modelUpdated.getMilestoneById(milestonePos).get().getId());
      String title = "Milestone " + milestonePos;
      assertEquals(title, modelUpdated.getMilestoneByTitle(title).get().getTitle());
      milestonePos++;
    }

    // Users: expected logins in lexicographic order ("User 1", "User 10", ..., "User 2", ...)
    ArrayList<String> expectedLogins = new ArrayList<>();
    for (int k = 1; k <= DummyRepoState.NO_OF_DUMMY_ISSUES; k++) {
      expectedLogins.add("User " + k);
    }
    Collections.sort(expectedLogins);
    int userPos = 1;
    for (TurboUser user : modelUpdated.getUsers()) {
      assertEquals(expectedLogins.get(userPos - 1), user.getLoginName());
      String login = "User " + userPos;
      assertEquals(login, modelUpdated.getUserByLogin(login).get().getLoginName());
      userPos++;
    }
  }