예제 #1
0
  /**
   * Constructs a tri-training classifier.
   *
   * <p>Creates three instances of the given classifier class, loads the training data through
   * {@code Util.getInstances}, moves the first {@code (int) (numInstances * precentage)} training
   * instances into {@code labeledIns}, keeps the remainder as {@code unlabeledIns}, loads the test
   * data into {@code testIns} and finally calls {@code Init()}.
   *
   * @param classifier fully qualified class name of the base classifier; it is instantiated
   *     through its no-argument constructor
   * @param trainingIns_File location of the training data, handed to {@code Util.getInstances}
   * @param testIns_File location of the test data, handed to {@code Util.getInstances}
   * @param precentage fraction of the training instances to use as labeled data, in {@code [0, 1]}
   * @throws IllegalArgumentException if {@code precentage} is NaN or outside {@code [0, 1]}
   * @throws IllegalStateException if the classifiers cannot be instantiated, the data cannot be
   *     loaded or {@code Init()} fails; the underlying exception is attached as the cause
   */
  public Tritrainer(
      String classifier, String trainingIns_File, String testIns_File, double precentage) {
    // Written as a negated range check so that NaN is rejected as well.
    if (!(precentage >= 0.0 && precentage <= 1.0)) {
      throw new IllegalArgumentException("precentage must be within [0, 1] but was " + precentage);
    }
    try {
      // Resolve the class once; each classifier still gets its own fresh instance.
      Class<?> classifierClass = Class.forName(classifier);
      this.classifier1 = (Classifier) classifierClass.getDeclaredConstructor().newInstance();
      this.classifier2 = (Classifier) classifierClass.getDeclaredConstructor().newInstance();
      this.classifier3 = (Classifier) classifierClass.getDeclaredConstructor().newInstance();

      Instances trainingInstances = Util.getInstances(trainingIns_File);

      // Split the training data into labeledIns and unlabeledIns using the ratio
      // precentage : (1 - precentage); the labeled count is truncated towards zero.
      int length = trainingInstances.numInstances();
      int labeledCount = (int) (length * precentage);
      labeledIns = new Instances(trainingInstances, 0);
      for (int j = 0; j < labeledCount; j++) {
        // Always take the current head: the previous head was deleted in the last iteration.
        labeledIns.add(trainingInstances.firstInstance());
        trainingInstances.delete(0);
      }
      // Whatever was not moved to labeledIns is the unlabeled pool.
      unlabeledIns = trainingInstances;
      testIns = Util.getInstances(testIns_File);

      Init();
    } catch (Exception e) {
      // Fail fast instead of silently leaving a partially initialized trainer behind.
      throw new IllegalStateException(
          "Failed to construct Tritrainer (classifier="
              + classifier
              + ", trainingIns_File="
              + trainingIns_File
              + ", testIns_File="
              + testIns_File
              + ")",
          e);
    }
  }
예제 #2
0
  /**
   * Returns the identifiers of the stored instances that are nearest to the given marks.
   *
   * <p>The method first calls {@code inst.deleteWithMissing(i)} for every attribute index, so the
   * caller's dataset is modified. It then builds a {@code KDTree} over {@code inst}, installs an
   * un-normalized Euclidean distance restricted to attributes {@code "2-last"}, builds a query
   * instance from {@code marks} through {@code CFilter.createInstance} and asks the tree for up to
   * 50 nearest neighbours. For each neighbour, the string form of attribute 0 is parsed as an
   * integer and added to the result (presumably the profile identifier; confirm against the data
   * layout).
   *
   * @param inst dataset to search; modified in place as described above
   * @param marks marks describing the query; copied into an {@code ArrayList} when it is not one
   * @return the parsed attribute-0 values of the neighbours, in the order returned by the tree
   * @throws Exception if building the tree or the nearest-neighbour search fails, or if attribute
   *     0 of a neighbour cannot be parsed as an integer
   */
  public static ArrayList<Integer> getProfiles(Instances inst, List<Integer> marks)
      throws Exception {

    // Remove every instance that has a missing value in any attribute.
    for (int i = 0; i < inst.numAttributes(); i++) {
      inst.deleteWithMissing(i);
    }

    KDTree tree = new KDTree();
    tree.setMeasurePerformance(true);

    // Failures are propagated (the method already declares "throws Exception") rather than
    // printed and ignored, which previously ended in a NullPointerException further down.
    tree.setInstances(inst);

    EuclideanDistance distance = new EuclideanDistance(inst);
    distance.setDontNormalize(true);
    distance.setAttributeIndices("2-last");

    // NOTE(review): the distance function is installed after the tree has been built from inst;
    // confirm that this ordering is intended.
    tree.setDistanceFunction(distance);

    // CFilter.createInstance is called with an ArrayList; copy other List implementations
    // instead of failing with a ClassCastException. A null argument is passed through unchanged.
    ArrayList<Integer> markList =
        (marks == null || marks instanceof ArrayList)
            ? (ArrayList<Integer>) marks
            : new ArrayList<Integer>(marks);

    // NOTE(review): 112121 looks like a placeholder identifier for the query instance; confirm.
    Instances test = CFilter.createInstance(112121, markList);
    Instance query = test.firstInstance();

    Instances neighbors = tree.kNearestNeighbours(query, 50);

    ArrayList<Integer> profiles = new ArrayList<Integer>();
    for (int i = 0; i < neighbors.numInstances(); i++) {
      Instance neighbor = neighbors.instance(i);
      System.out.println(neighbor);
      profiles.add(Integer.valueOf(neighbor.toString(0)));
    }
    return profiles;
  }
예제 #3
0
  /**
   * Handles an incoming stream element by predicting a square grid of values around the
   * latitude/longitude it carries and publishing the grid as a new stream element.
   *
   * <p>The element is copied, converted to an {@code Instance} through {@code instanceFromStream}
   * and added to a freshly created {@code dataset}. Attribute 1 (latitude) and attribute 2
   * (longitude) of that instance are then overwritten for each of the {@code gridSize x gridSize}
   * cells and {@code ms.predict} is called per cell. The output element carries, in order: the
   * grid size twice, the latitude and longitude of the grid origin, the cell size, a constant
   * {@code 0.0} and the Java-serialized {@code Double[][]} of predictions. Nothing is published
   * when any prediction is {@code null} or serialization fails; a warning is logged instead.
   *
   * @param inputStreamName name of the input stream; not used by this method
   * @param data the incoming stream element
   */
  public void dataAvailable(String inputStreamName, StreamElement data) {

    // Work on a copy of the element so the caller's arrays are never shared.
    String[] fieldNames = data.getFieldNames().clone();
    Byte[] fieldTypes = data.getFieldTypes().clone();
    Serializable[] fieldValues = data.getData().clone();
    StreamElement element =
        new StreamElement(fieldNames, fieldTypes, fieldValues, data.getTimeStamp());

    // Map the element to an instance; the attribute list is derived from the first element seen.
    Instance instance = instanceFromStream(element);
    if (att.size() == 0) {
      att = attFromStream(element);
    }
    dataset = new Instances("input", att, 0);
    dataset.setClassIndex(classIndex);

    if (instance == null) {
      logger.warn(
          "Predicting instance has wrong attibutes, please check the model and the inputs.");
      return;
    }

    // Continue with the dataset's own instance rather than the one that was passed to add().
    dataset.add(instance);
    instance = dataset.firstInstance();

    // The incoming position is the grid center; shift by half the grid span to get the origin.
    double centerLat = instance.value(1);
    double centerLong = instance.value(2);
    double halfSpan = cellSize * gridSize / 2;
    double originLat = centerLat - halfSpan;
    double originLong = centerLong - halfSpan;

    // Fill the grid with predictions/extrapolations, reusing the same instance for every cell.
    boolean success = true;
    Double[][] rawData = new Double[gridSize][gridSize];
    for (int row = 0; row < gridSize; row++) {
      for (int col = 0; col < gridSize; col++) {
        instance.setValue(1, originLat + cellSize * row);
        instance.setValue(2, originLong + cellSize * col);
        rawData[row][col] = ms.predict(instance);
        success = success && (rawData[row][col] != null);
      }
    }

    // Prepare the output fields.
    Serializable[] stream = new Serializable[7];
    try {
      ByteArrayOutputStream bos = new ByteArrayOutputStream();
      ObjectOutputStream oos = new ObjectOutputStream(bos);
      oos.writeObject(rawData);
      oos.flush();
      oos.close();
      bos.close();

      stream[0] = Integer.valueOf(gridSize);
      stream[1] = Integer.valueOf(gridSize);
      stream[2] = Double.valueOf(originLat);
      stream[3] = Double.valueOf(originLong);
      stream[4] = Double.valueOf(cellSize);
      stream[5] = Double.valueOf(0);
      stream[6] = bos.toByteArray();

    } catch (IOException e) {
      logger.warn(e.getMessage(), e);
      success = false;
    }

    if (success) {
      StreamElement se = new StreamElement(getOutputFormat(), stream, element.getTimeStamp());
      dataProduced(se);
    } else {
      logger.warn("Prediction error. Something get wrong with the prediction.");
    }
  }