コード例 #1
0
  /**
   * Trains one model per k-means cluster and prints the resulting test-set errors.
   *
   * <p>The instances are clustered on a small window of their attributes (indices 8, 9 and 10).
   * Every training instance is routed to the model belonging to the cluster predicted for its
   * windowed copy, and every test instance is scored by the model of its predicted cluster. The
   * mean absolute error and root mean squared error, both scaled by {@code divErr}, are printed to
   * standard output.
   *
   * <p>The values {@code numOfClusters}, {@code seed}, {@code modelName}, {@code prefix}, {@code
   * numberOfRatings} and {@code divErr} are read from outside this method.
   *
   * @param trainingSet instances and labels used to fit the clustering and the per-cluster models
   * @param testSet instances and labels used for the final error computation
   * @throws Exception if the model class cannot be loaded or instantiated, or if any of the used
   *     components fails
   */
  protected static void golf(InstanceHolder trainingSet, InstanceHolder testSet) throws Exception {
    // Only the attributes in [fromAtt, fromAtt + numAtts) are used for clustering.
    final int numAtts = 3;
    final int fromAtt = 8;
    InstanceHolder filteredTrain = projectAttributes(trainingSet, fromAtt, numAtts);
    InstanceHolder filteredTest = projectAttributes(testSet, fromAtt, numAtts);

    // Fit the clustering on randomly drawn (with replacement) windowed training instances.
    KMeans kMeans = new KMeans(numOfClusters, 250.0);
    Random r = new Random(seed);
    for (int i = 0; i < 2 * filteredTrain.size(); i++) {
      int index = r.nextInt(filteredTrain.size());
      kMeans.update(filteredTrain.getInstance(index), filteredTrain.getLabel(index));
    }

    // One model per cluster, created reflectively from the configured class name.
    Model[] models = new Model[numOfClusters];
    for (int i = 0; i < numOfClusters; i++) {
      // Class.newInstance() is deprecated; going through the constructor reports constructor
      // failures wrapped in an InvocationTargetException instead of rethrowing them undeclared.
      models[i] = (Model) Class.forName(modelName).getDeclaredConstructor().newInstance();
      models[i].init(prefix);
      models[i].setNumberOfClasses((int) numberOfRatings);
    }

    // Each randomly drawn training instance updates only the model of its predicted cluster.
    // The cluster is predicted from the windowed copy, the model sees the full instance.
    for (int i = 0; i < 10 * trainingSet.size(); i++) {
      int index = r.nextInt(trainingSet.size());
      int clusterId = (int) kMeans.predict(filteredTrain.getInstance(index));
      models[clusterId].update(trainingSet.getInstance(index), trainingSet.getLabel(index));
    }

    double maError = 0.0;
    double rmsError = 0.0;
    for (int i = 0; i < testSet.size(); i++) {
      int clusterId = (int) kMeans.predict(filteredTest.getInstance(i));
      double predicted = models[clusterId].predict(testSet.getInstance(i));
      double expected = testSet.getLabel(i);
      maError += Math.abs(expected - predicted);
      rmsError += Math.pow(((expected - predicted) / divErr), 2);
    }
    // NOTE: an empty test set makes both results NaN (0.0 / 0).
    maError /= testSet.size() * divErr;
    rmsError /= testSet.size();
    rmsError = Math.sqrt(rmsError);
    System.out.println("GoLF MAE: " + maError);
    System.out.println("GoLF RMSE: " + rmsError);
  }

  /**
   * Builds a holder whose instances contain only a window of the source attributes.
   *
   * <p>Attribute {@code fromAtt + j} of every source instance is stored at index {@code j} of the
   * corresponding new instance; labels and the number of classes are carried over unchanged.
   *
   * @param source the holder to copy the attribute window from
   * @param fromAtt index of the first source attribute to copy
   * @param numAtts number of consecutive attributes to copy
   * @return a new holder with one windowed instance per source instance, in the same order
   * @throws Exception if any of the used holder or vector operations fails
   */
  private static InstanceHolder projectAttributes(InstanceHolder source, int fromAtt, int numAtts)
      throws Exception {
    InstanceHolder projected = new InstanceHolder(source.getNumberOfClasses(), numAtts);
    for (int i = 0; i < source.size(); i++) {
      SparseVector instance = new SparseVector(numAtts);
      for (int j = 0; j < numAtts; j++) {
        instance.put(j, source.getInstance(i).get(j + fromAtt));
      }
      projected.add(instance, source.getLabel(i));
    }
    return projected;
  }
コード例 #2
0
  /**
   * Parses a ratings file into one sparse instance per user.
   *
   * <p>Every line that is not blank after removing its comment must consist of exactly three
   * whitespace-separated tokens: a 1-based user ID, a 1-based item ID and a rating. Everything from
   * the first {@code #} to the end of a line is treated as a comment. Both IDs are shifted to be
   * zero-based; the rating is stored in the user's instance under the zero-based item ID, and the
   * label of every user that occurs in the file is set to its zero-based user ID. Users that never
   * occur but have a smaller ID than some occurring user get an empty instance with label
   * {@code 0.0}.
   *
   * @param file the file that has to be parsed
   * @return a holder built from the parsed instances and labels, the largest rating truncated to an
   *     integer as the number of classes (passed as 0 when it equals 1), and the largest item ID as
   *     the number of features
   * @throws IOException if file reading error occurs.
   * @throws RuntimeException if the file is null or does not exist, if a line does not have exactly
   *     three tokens, or if a user or item ID is smaller than 1 in the file
   * @throws NumberFormatException if a token cannot be parsed as a number
   */
  protected InstanceHolder parseFile(final File file) throws IOException {
    // throw exception if the file does not exist or null; the message is built by implicit
    // conversion so that a null file is reported instead of causing a NullPointerException
    if (file == null || !file.exists()) {
      throw new RuntimeException("The file \"" + file + "\" is null or does not exist!");
    }
    Vector<SparseVector> instances = new Vector<SparseVector>();
    Vector<Double> labels = new Vector<Double>();
    int numberOfClasses = -1;
    int numberOfFeatures = -1;
    BufferedReader br = new BufferedReader(new FileReader(file));
    try {
      String line;
      int lineNumber = 0;
      while ((line = br.readLine()) != null) {
        lineNumber++;
        // removing the comment part and the surrounding white spaces first, so that indented
        // comment lines and white-space-only lines are skipped as well
        int commentStart = line.indexOf('#');
        if (commentStart >= 0) {
          line = line.substring(0, commentStart);
        }
        line = line.trim();
        if (line.length() == 0) {
          continue;
        }
        // splitting line at runs of white spaces
        String[] split = line.split("\\s+");
        // a valid line has exactly three tokens: user ID, item ID and rating
        if (split.length != 3) {
          throw new RuntimeException(
              "The file \"" + file.toString() + "\" has invalid structure at line " + lineNumber);
        }
        // IDs are 1-based in the file and 0-based in memory
        int userId = Integer.parseInt(split[0]) - 1;
        int itemId = Integer.parseInt(split[1]) - 1;
        double rate = Double.parseDouble(split[2]);

        if (userId < 0) {
          throw new RuntimeException(
              "The user ID has to be integer and greater than or equal to 0, line " + lineNumber);
        }
        if (itemId < 0) {
          throw new RuntimeException(
              "The item ID has to be integer and greater than or equal to 0, line " + lineNumber);
        }
        // tracking the largest rating and the largest item ID seen so far
        if (rate > numberOfClasses) {
          numberOfClasses = (int) rate;
        }
        if (itemId >= numberOfFeatures) {
          numberOfFeatures = itemId + 1;
        }
        // growing the collections with empty placeholders up to and including this user
        for (int i = instances.size(); i <= userId; i++) {
          instances.add(new SparseVector(1));
          labels.add(0.0);
        }

        instances.get(userId).put(itemId, rate);
        labels.set(userId, (double) userId);
      }
    } finally {
      // closing the reader even if a line turned out to be invalid
      br.close();
    }
    return new InstanceHolder(
        instances,
        labels,
        (numberOfClasses == 1) ? 0 : numberOfClasses,
        numberOfFeatures); // 1-> indicating clustering
  }