コード例 #1
0
 /**
   * Evaluates an iterative recommender on a fresh cross-validation split of its own ratings,
   * printing per-fold results to STDOUT.
   *
   * <p>The split is a {@code RatingCrossValidationSplit} built from {@code
   * recommender.getRatings()}; all of the actual work is done by the {@code ISplit}-based overload.
   *
   * @param recommender a rating predictor (the overload rejects it unless it also implements {@code
   *     IIterativeModel})
   * @param num_folds the number of folds to create
   * @param max_iter the maximum number of iterations
   * @param find_iter the report interval in iterations; {@code null} is treated as 1 by the overload
   * @throws Exception if building the split, training, or evaluation fails
   */
  public static void doIterativeCrossValidation(
      RatingPredictor recommender, int num_folds, int max_iter, Integer find_iter)
      throws Exception {
    doIterativeCrossValidation(
        recommender,
        new RatingCrossValidationSplit(recommender.getRatings(), num_folds),
        max_iter,
        find_iter);
  }
コード例 #2
0
  /**
   * Cross-validates {@code recommender} on its own ratings and returns the averaged results.
   *
   * <p>A {@code RatingCrossValidationSplit} with {@code num_folds} folds is created from {@code
   * recommender.getRatings()} and handed to the {@code ISplit}-based overload.
   *
   * @param recommender a rating predictor
   * @param num_folds the number of folds; {@code null} means 5
   * @param compute_fit whether to also measure fit on the training data; {@code null} means false
   * @param show_results whether to print each fold's results to STDOUT; {@code null} means false
   * @return the results averaged over the folds of the split
   * @throws Exception if building the split, training, or evaluation fails
   */
  public static RatingPredictionEvaluationResults doCrossValidation(
      RatingPredictor recommender, Integer num_folds, Boolean compute_fit, Boolean show_results)
      throws Exception {
    // Keep the boxed types so the same overloads are selected as before.
    Integer folds = (num_folds != null) ? num_folds : Integer.valueOf(5);
    Boolean fit = (compute_fit != null) ? compute_fit : Boolean.FALSE;
    Boolean show = (show_results != null) ? show_results : Boolean.FALSE;

    return doCrossValidation(
        recommender, new RatingCrossValidationSplit(recommender.getRatings(), folds), fit, show);
  }
コード例 #3
0
  /**
   * Evaluates an iterative recommender on every fold of a dataset split, printing results to
   * STDOUT.
   *
   * <p>For each fold a clone of {@code recommender} is trained on the fold's training data and
   * evaluated on its test data. All clones are then advanced one iteration at a time, starting
   * after the iteration count reached by the first fold's model, up to and including {@code
   * max_iter}. Whenever the iteration number is a multiple of {@code find_iter}, every fold is
   * evaluated again.
   *
   * @param recommender a rating predictor that also implements {@code IIterativeModel}; it is
   *     cloned per fold instead of being trained directly
   * @param split a rating dataset split with at least one fold
   * @param max_iter the maximum number of iterations
   * @param find_iter the report interval in iterations; {@code null} means 1, otherwise it must be
   *     positive
   * @throws IllegalArgumentException if {@code recommender} is not an {@code IIterativeModel}, if
   *     {@code find_iter} is not positive, or if {@code split} has no folds
   * @throws Exception if cloning, training or evaluating any fold fails (the message is also
   *     printed to STDERR)
   */
  public static void doIterativeCrossValidation(
      RatingPredictor recommender, ISplit<IRatings> split, int max_iter, Integer find_iter)
      throws Exception {
    if (find_iter == null) find_iter = 1;

    if (!(recommender instanceof IIterativeModel))
      throw new IllegalArgumentException("recommender must be of type IIterativeModel");

    // find_iter is used as a modulus below; zero would cause an ArithmeticException mid-run.
    if (find_iter <= 0)
      throw new IllegalArgumentException("find_iter must be positive, but was " + find_iter);

    int num_folds = split.numberOfFolds();
    // The iteration loop below starts from the first fold's model, so at least one is required.
    if (num_folds < 1)
      throw new IllegalArgumentException("split must contain at least one fold");

    RatingPredictor[] split_recommenders = new RatingPredictor[num_folds];
    IIterativeModel[] iterative_recommenders = new IIterativeModel[num_folds];

    // Initial training and evaluation
    for (int i = 0; i < num_folds; i++) {
      try {
        split_recommenders[i] = recommender.clone(); // train a copy rather than recommender itself
        split_recommenders[i].setRatings(split.train().get(i));
        split_recommenders[i].train();
        iterative_recommenders[i] = (IIterativeModel) split_recommenders[i];
        HashMap<String, Double> fold_results =
            Ratings.evaluate(split_recommenders[i], split.test().get(i));
        System.out.println(
            "fold "
                + i
                + " "
                + fold_results
                + " iteration "
                + iterative_recommenders[i].getNumIter());
      } catch (Exception e) {
        System.err.println("===> ERROR: " + e.getMessage());
        throw e;
      }
    }

    // Iterative training and evaluation
    for (int it = iterative_recommenders[0].getNumIter() + 1; it <= max_iter; it++) {
      for (int i = 0; i < num_folds; i++) {
        try {
          iterative_recommenders[i].iterate();

          if (it % find_iter == 0) {
            HashMap<String, Double> fold_results =
                Ratings.evaluate(split_recommenders[i], split.test().get(i));
            System.out.println("fold " + i + " " + fold_results + " iteration " + it);
          }
        } catch (Exception e) {
          System.err.println("===> ERROR: " + e.getMessage());
          throw e;
        }
      }
    }
  }
コード例 #4
0
  /**
   * Evaluates a recommender on every fold of a dataset split and averages the results.
   *
   * <p>For each fold a clone of {@code recommender} is trained on the fold's training data and
   * evaluated on its test data. Every measure returned by {@code Ratings.evaluate} (plus {@code
   * "fit"} when requested) is summed over the folds and divided by the number of folds.
   *
   * @param recommender a rating predictor; it is cloned per fold instead of being trained directly
   * @param split a rating dataset split
   * @param compute_fit if true, also store {@code Ratings.computeFit} of each fold's model under
   *     the key {@code "fit"}; {@code null} means false
   * @param show_results if true, print each fold's results to STDOUT; {@code null} means false
   * @return the results averaged over the folds of the split (no entries are added if the split
   *     has no folds)
   * @throws Exception if cloning, training or evaluating any fold fails (the message is also
   *     printed to STDERR)
   */
  public static RatingPredictionEvaluationResults doCrossValidation(
      RatingPredictor recommender,
      ISplit<IRatings> split,
      Boolean compute_fit,
      Boolean show_results)
      throws Exception {

    if (compute_fit == null) compute_fit = false;
    if (show_results == null) show_results = false;

    int num_folds = split.numberOfFolds();
    // Per-key sums over all folds. Every key collected here is averaged at the end, so "fit" is
    // averaged as well instead of being returned as a sum over the folds.
    HashMap<String, Double> sums = new HashMap<String, Double>();

    for (int i = 0; i < num_folds; i++) {
      try {
        RatingPredictor split_recommender = recommender.clone(); // train a copy, not recommender
        split_recommender.setRatings(split.train().get(i));
        split_recommender.train();
        HashMap<String, Double> fold_results =
            Ratings.evaluate(split_recommender, split.test().get(i));
        if (compute_fit)
          fold_results.put("fit", Double.valueOf(Ratings.computeFit(split_recommender)));

        for (String key : fold_results.keySet())
          if (sums.containsKey(key)) sums.put(key, sums.get(key) + fold_results.get(key));
          else sums.put(key, fold_results.get(key));

        if (show_results) System.out.println("fold " + i + " " + fold_results);
      } catch (Exception e) {
        System.err.println("===> ERROR: " + e.getMessage());
        throw e;
      }
    }

    RatingPredictionEvaluationResults avg_results = new RatingPredictionEvaluationResults();
    for (String key : sums.keySet()) avg_results.put(key, sums.get(key) / num_folds);
    return avg_results;
  }
コード例 #5
0
  /**
   * Reads the training data, the optional attribute and relation data, and the optional test data,
   * and hands them to {@code recommender}.
   *
   * <p>Results are stored in the static fields {@code training_data}, {@code test_data}, {@code
   * user_attributes} and {@code item_attributes}. Loading time and memory usage are printed to
   * STDOUT.
   *
   * @param data_dir directory that all file names are combined with via {@code Utils.combine}
   * @param user_attributes_file user attribute file, or {@code null} for none
   * @param item_attributes_file item attribute file, or {@code null} for none
   * @param user_relation_file user relation file; read only for user-relation-aware recommenders
   * @param item_relation_file item relation file; read only for item-relation-aware recommenders
   * @param static_data if true, DEFAULT / IGNORE_FIRST_LINE training data is read with {@code
   *     StaticRatingData}, otherwise with {@code RatingData}
   * @throws Exception if any of the files cannot be read or parsed
   */
  static void loadData(
      String data_dir,
      String user_attributes_file,
      String item_attributes_file,
      String user_relation_file,
      String item_relation_file,
      boolean static_data)
      throws Exception {

    long start = System.currentTimeMillis();

    // Read training data
    if ((recommender instanceof TimeAwareRatingPredictor || chronological_split != null)
        && file_format != RatingFileFormat.MOVIELENS_1M) {
      // Time-aware recommenders and chronological splits need timestamped ratings.
      // NOTE(review): the last argument is always false here; if it is the ignore-first-line flag
      // (as with RatingData.read below), IGNORE_FIRST_LINE is not honoured for timed data -- confirm.
      training_data =
          TimedRatingData.read(
              Utils.combine(data_dir, training_file), user_mapping, item_mapping, false);
    } else {
      if (file_format == RatingFileFormat.DEFAULT)
        training_data =
            static_data
                ? StaticRatingData.read(
                    Utils.combine(data_dir, training_file),
                    user_mapping,
                    item_mapping,
                    rating_type,
                    false)
                : RatingData.read(
                    Utils.combine(data_dir, training_file), user_mapping, item_mapping, false);
      else if (file_format == RatingFileFormat.IGNORE_FIRST_LINE)
        training_data =
            static_data
                ? StaticRatingData.read(
                    Utils.combine(data_dir, training_file),
                    user_mapping,
                    item_mapping,
                    rating_type,
                    true)
                : RatingData.read(
                    Utils.combine(data_dir, training_file), user_mapping, item_mapping, true);
      else if (file_format == RatingFileFormat.MOVIELENS_1M)
        training_data =
            MovieLensRatingData.read(
                Utils.combine(data_dir, training_file), user_mapping, item_mapping);
      else if (file_format == RatingFileFormat.KDDCUP_2011)
        training_data =
            org.mymedialite.io.kddcup2011.Ratings.read(Utils.combine(data_dir, training_file));
    }
    recommender.setRatings(training_data);

    // User attributes
    if (user_attributes_file != null)
      user_attributes =
          AttributeData.read(
              Utils.combine(data_dir, user_attributes_file), user_mapping, attribute_mapping);

    if (recommender instanceof IUserAttributeAwareRecommender)
      ((IUserAttributeAwareRecommender) recommender).setUserAttributes(user_attributes);

    // Item attributes
    if (item_attributes_file != null)
      item_attributes =
          AttributeData.read(
              Utils.combine(data_dir, item_attributes_file), item_mapping, attribute_mapping);

    if (recommender instanceof IItemAttributeAwareRecommender)
      ((IItemAttributeAwareRecommender) recommender).setItemAttributes(item_attributes);

    // User relation
    if (recommender instanceof IUserRelationAwareRecommender) {
      ((IUserRelationAwareRecommender) recommender)
          .setUserRelation(
              RelationData.read(Utils.combine(data_dir, user_relation_file), user_mapping));
      System.out.println(
          "relation over " + ((IUserRelationAwareRecommender) recommender).numUsers() + " users");
    }

    // Item relation
    if (recommender instanceof IItemRelationAwareRecommender) {
      ((IItemRelationAwareRecommender) recommender)
          .setItemRelation(
              RelationData.read(Utils.combine(data_dir, item_relation_file), item_mapping));
      System.out.println(
          "relation over "
              + ((IItemRelationAwareRecommender) recommender).getNumItems()
              + " items");
    }

    // Read test data
    if (test_file != null) {
      if (recommender instanceof TimeAwareRatingPredictor
          && file_format != RatingFileFormat.MOVIELENS_1M)
        test_data =
            TimedRatingData.read(
                Utils.combine(data_dir, test_file), user_mapping, item_mapping, false);
      else if (file_format == RatingFileFormat.MOVIELENS_1M)
        test_data =
            MovieLensRatingData.read(
                Utils.combine(data_dir, test_file), user_mapping, item_mapping);
      else if (file_format == RatingFileFormat.KDDCUP_2011)
        // Read the test file here, not the training file.
        test_data = org.mymedialite.io.kddcup2011.Ratings.read(Utils.combine(data_dir, test_file));
      else
        test_data =
            StaticRatingData.read(
                Utils.combine(data_dir, test_file),
                user_mapping,
                item_mapping,
                rating_type,
                file_format == RatingFileFormat.IGNORE_FIRST_LINE);
    }

    System.out.println(
        "Loading time: " + (double) (System.currentTimeMillis() - start) / 1000 + " seconds");
    System.out.println("Memory: " + Memory.getUsage() + " MB");
  }
コード例 #6
0
  /**
   * Validates the parsed command-line options, reporting conflicts or missing options via {@code
   * usage(String)}.
   *
   * <p>Also resolves {@code --chronological-split}. A numeric argument is stored in {@code
   * chronological_split_ratio}. Any other argument sets the ratio to -1 and is parsed with {@code
   * dateFormat} into {@code chronological_split_time}.
   */
  static void checkParameters() {
    if (online_eval && !(recommender instanceof IIncrementalRatingPredictor))
      usage(
          "Recommender "
              + recommender.getClass().getName()
              + " does not support incremental updates, which are necessary for an online experiment.");

    if (training_file == null && load_model_file == null)
      usage("Please provide either --training-file=FILE or --load-model=FILE.");

    if (cross_validation == 1) usage("--cross-validation=K requires K to be at least 2.");

    if (show_fold_results && cross_validation == 0)
      usage("--show-fold-results only works with --cross-validation=K.");

    if (cross_validation > 1 && test_ratio != 0)
      usage("--cross-validation=K and --test-ratio=NUM are mutually exclusive.");

    if (cross_validation > 1 && prediction_file != null)
      usage("--cross-validation=K and --prediction-file=FILE are mutually exclusive.");

    if (test_file == null
        && test_ratio == 0
        && cross_validation == 0
        && save_model_file == null
        && chronological_split == null)
      usage(
          "Please provide either test-file=FILE, --test-ratio=NUM, --cross-validation=K, --chronological-split=NUM|DATETIME, or --save-model=FILE.");

    if (recommender instanceof IUserAttributeAwareRecommender && user_attributes_file == null)
      usage("Recommender expects --user-attributes=FILE.");

    if (recommender instanceof IItemAttributeAwareRecommender && item_attributes_file == null)
      usage("Recommender expects --item-attributes=FILE.");

    if (recommender instanceof IUserRelationAwareRecommender && user_relations_file == null)
      usage("Recommender expects --user-relations=FILE.");

    // Item-relation-aware recommenders need the item relation file, not the user one.
    if (recommender instanceof IItemRelationAwareRecommender && item_relations_file == null)
      usage("Recommender expects --item-relations=FILE.");

    // handling of --chronological-split: the argument is either a ratio or a point in time
    if (chronological_split != null) {
      try {
        chronological_split_ratio = Double.parseDouble(chronological_split);
      } catch (NumberFormatException e) {
        // Not a number: mark the ratio as unset (-1) so the argument is parsed as a date below.
        chronological_split_ratio = -1.0;
      }

      if (chronological_split_ratio == -1)
        try {
          chronological_split_time = dateFormat.parse(chronological_split);
        } catch (ParseException e) {
          usage(
              "Could not interpret argument of --chronological-split as number or date and time: "
                  + chronological_split);
        }

      // check for conflicts
      if (cross_validation > 1)
        usage(
            "--cross-validation=K and --chronological-split=NUM|DATETIME are mutually exclusive.");

      // Both options produce a train/test split, so any given test ratio conflicts.
      if (test_ratio != 0)
        usage("--test-ratio=NUM and --chronological-split=NUM|DATETIME are mutually exclusive.");
    }
  }
コード例 #7
0
  /**
   * Command-line entry point for rating prediction experiments.
   *
   * <p>Parses {@code --name=value} options and {@code --flag} switches; unrecognised arguments are
   * ignored. It then:
   *
   * <ul>
   *   <li>loads or creates the recommender and loads the data;
   *   <li>optionally splits the data ({@code --test-ratio}, {@code --chronological-split});
   *   <li>runs an iteration search ({@code --find-iter=N}), a cross-validation, or a single
   *       train/evaluate cycle;
   *   <li>finally calls {@code Model.save} with {@code save_model_file}.
   * </ul>
   *
   * @param args command-line arguments
   * @throws Exception if loading, training, evaluation or saving fails
   */
  public static void main(String[] args) throws Exception {

    // Handler for uncaught exceptions; statistics are displayed when the JVM shuts down.
    Thread.setDefaultUncaughtExceptionHandler(new Handlers());
    Runtime.getRuntime()
        .addShutdownHook(
            new Thread() {
              @Override
              public void run() {
                displayStats();
              }
            });

    // Recommender arguments
    String method = null;
    String recommender_options = "";

    // Help/version
    boolean show_help = false;
    boolean show_version = false;

    // Arguments for iteration search
    int max_iter = 100;
    double epsilon = 0;
    double rmse_cutoff = Double.MAX_VALUE;
    double mae_cutoff = Double.MAX_VALUE;

    // Data arguments
    String data_dir = "";

    // Other arguments
    boolean search_hp = false;
    int random_seed = -1;
    String prediction_line = "{0}\t{1}\t{2}";

    for (String arg : args) {
      // "--name=value" is split into "--name=" (keeping the '=') and "value"; flags get no value.
      int div = arg.indexOf("=") + 1;
      String name;
      String value;
      if (div > 0) {
        name = arg.substring(0, div);
        value = arg.substring(div);
      } else {
        name = arg;
        value = null;
      }

      // String-valued options
      if (name.equals("--training-file=")) training_file = value;
      else if (name.equals("--test-file=")) test_file = value;
      else if (name.equals("--recommender=")) method = value;
      else if (name.equals("--recommender-options=")) recommender_options += " " + value;
      else if (name.equals("--data-dir=")) data_dir = value;
      else if (name.equals("--user-attributes=")) user_attributes_file = value;
      else if (name.equals("--item-attributes=")) item_attributes_file = value;
      else if (name.equals("--user-relations=")) user_relations_file = value;
      else if (name.equals("--item-relations=")) item_relations_file = value;
      else if (name.equals("--save-model=")) save_model_file = value;
      else if (name.equals("--load-model=")) load_model_file = value;
      else if (name.equals("--prediction-file=")) prediction_file = value;
      else if (name.equals("--prediction-line=")) prediction_line = value;
      else if (name.equals("--chronological-split=")) chronological_split = value;

      // Integer-valued options
      else if (name.equals("--find-iter=")) find_iter = Integer.parseInt(value);
      else if (name.equals("--max-iter=")) max_iter = Integer.parseInt(value);
      else if (name.equals("--random-seed=")) random_seed = Integer.parseInt(value);
      else if (name.equals("--cross-validation=")) cross_validation = Integer.parseInt(value);

      // Double-valued options
      else if (name.equals("--epsilon=")) epsilon = Double.parseDouble(value);
      else if (name.equals("--rmse-cutoff=")) rmse_cutoff = Double.parseDouble(value);
      else if (name.equals("--mae-cutoff=")) mae_cutoff = Double.parseDouble(value);
      else if (name.equals("--test-ratio=")) test_ratio = Double.parseDouble(value);

      // Enum options
      else if (name.equals("--rating-type=")) rating_type = RatingType.valueOf(value);
      else if (name.equals("--file-format=")) file_format = RatingFileFormat.valueOf(value);

      // Boolean options
      else if (name.equals("--compute-fit")) compute_fit = true;
      else if (name.equals("--online-evaluation")) online_eval = true;
      else if (name.equals("--show-fold-results")) show_fold_results = true;
      else if (name.equals("--search-hp")) search_hp = true;
      else if (name.equals("--help")) show_help = true;
      else if (name.equals("--version")) show_version = true;
    }

    // Evaluate only if one of the options below will provide test data.
    boolean no_eval = true;
    if (test_ratio > 0 || test_file != null || chronological_split != null) no_eval = false;

    if (show_version) showVersion();

    if (show_help) usage(0);

    if (random_seed != -1) org.mymedialite.util.Random.initInstance(random_seed);

    // Set up recommender
    if (load_model_file != null) recommender = (RatingPredictor) Model.load(load_model_file);
    else if (method != null) recommender = Recommender.createRatingPredictor(method);
    else recommender = Recommender.createRatingPredictor("BiasedMatrixFactorization");

    // In case something went wrong ...
    if (recommender == null && method != null) usage("Unknown rating prediction method: " + method);
    if (recommender == null && load_model_file != null)
      usage("Could not load model from file " + load_model_file);

    checkParameters();

    try {
      recommender = Recommender.configure(recommender, recommender_options, new ErrorHandler());
    } catch (IllegalAccessException e) {
      System.err.println("Unable to instantiate recommender: " + recommender.toString());
      // Exit with a non-zero status so that calling scripts can detect the failure.
      System.exit(1);
    }

    // ID mapping objects
    if (file_format == RatingFileFormat.KDDCUP_2011) {
      user_mapping = new IdentityMapping();
      item_mapping = new IdentityMapping();
    }

    // Load all the data
    loadData(
        data_dir,
        user_attributes_file,
        item_attributes_file,
        user_relations_file,
        item_relations_file,
        !online_eval);

    System.out.println(
        "Ratings range: " + recommender.getMinRating() + ", " + recommender.getMaxRating());

    if (test_ratio > 0) {
      RatingsSimpleSplit split = new RatingsSimpleSplit(training_data, test_ratio);
      // TODO check
      training_data = split.train().get(0);
      recommender.setRatings(training_data);
      // TODO check
      test_data = split.test().get(0);
      System.out.println("Test ratio: " + test_ratio);
    }

    if (chronological_split != null) {
      RatingsChronologicalSplit split =
          chronological_split_ratio != -1
              ? new RatingsChronologicalSplit(
                  (ITimedRatings) training_data, chronological_split_ratio)
              : new RatingsChronologicalSplit(
                  (ITimedRatings) training_data, chronological_split_time);
      training_data = split.train().get(0);
      recommender.setRatings(training_data);
      test_data = split.test().get(0);
      // Report whichever kind of split was used: a ratio, or a point in time (ratio == -1).
      if (chronological_split_ratio != -1)
        System.out.println("Test ratio (chronological): " + chronological_split_ratio);
      else System.out.println("Split time:" + chronological_split_time);
    }

    System.out.print(
        Extensions.statistics(training_data, test_data, user_attributes, item_attributes, false));

    if (find_iter != 0) {
      if (!(recommender instanceof IIterativeModel))
        usage("Only iterative recommenders (interface IIterativeModel) support --find-iter=N.");

      System.out.println("Recommender: " + recommender.toString());

      if (cross_validation > 1) {
        RatingsCrossValidation.doIterativeCrossValidation(
            recommender, cross_validation, max_iter, find_iter);
      } else {
        IIterativeModel iterative_recommender = (IIterativeModel) recommender;

        if (load_model_file == null) recommender.train();

        if (compute_fit)
          System.out.println(
              "Fit "
                  + Ratings.evaluate(recommender, training_data)
                  + " iteration "
                  + iterative_recommender.getNumIter());

        System.out.println(
            Ratings.evaluate(recommender, test_data)
                + " iteration "
                + iterative_recommender.getNumIter());

        for (int it = iterative_recommender.getNumIter() + 1; it <= max_iter; it++) {
          long start = System.currentTimeMillis();
          iterative_recommender.iterate();
          training_time_stats.add((double) (System.currentTimeMillis() - start) / 1000);

          if (it % find_iter == 0) {
            if (compute_fit) {
              start = System.currentTimeMillis();
              System.out.println(
                  "Fit " + Ratings.evaluate(recommender, training_data) + " iteration " + it);
              fit_time_stats.add((double) (System.currentTimeMillis() - start) / 1000);
            }

            start = System.currentTimeMillis();
            HashMap<String, Double> results = Ratings.evaluate(recommender, test_data);
            eval_time_stats.add((double) (System.currentTimeMillis() - start) / 1000);
            rmse_eval_stats.add(results.get("RMSE"));
            System.out.println(results + " iteration " + it);

            Model.save(recommender, save_model_file, it);
            if (prediction_file != null)
              org.mymedialite.ratingprediction.Extensions.writePredictions(
                  recommender,
                  test_data,
                  prediction_file + "-it-" + it,
                  user_mapping,
                  item_mapping,
                  prediction_line);

            // Stop once the RMSE has risen more than epsilon above the best RMSE seen so far.
            if (epsilon > 0.0 && results.get("RMSE") - Collections.min(rmse_eval_stats) > epsilon) {
              System.out.println(results.get("RMSE") + " >> " + Collections.min(rmse_eval_stats));
              System.out.println(
                  "Reached convergence on training/validation data after " + it + " iterations.");
              break;
            }
            if (results.get("RMSE") > rmse_cutoff || results.get("MAE") > mae_cutoff) {
              System.out.println("Reached cutoff after " + it + " iterations.");
              break;
            }
          }
        } // for
      }
    } else {
      long start = System.currentTimeMillis();

      System.out.println("Recommender: " + recommender);

      if (load_model_file == null) {
        if (cross_validation > 1) {
          RatingPredictionEvaluationResults results =
              RatingsCrossValidation.doCrossValidation(
                  recommender, cross_validation, compute_fit, show_fold_results);
          System.out.println(results);
          no_eval = true;
        } else {
          if (search_hp) {
            double result = NelderMead.findMinimum("RMSE", recommender);
            System.out.println("Estimated quality (on split): " + result);
          }

          recommender.train();
          System.out.println(
              "Training time: "
                  + (double) (System.currentTimeMillis() - start) / 1000
                  + " seconds");
        }
      }

      if (!no_eval) {
        start = System.currentTimeMillis();
        if (online_eval) System.out.println(RatingsOnline.evaluateOnline(recommender, test_data));
        else System.out.println(Ratings.evaluate(recommender, test_data));

        System.out.println(
            "Testing time: " + (double) (System.currentTimeMillis() - start) / 1000 + " seconds");

        if (compute_fit) {
          System.out.print("Fit:");
          start = System.currentTimeMillis();
          System.out.print(Ratings.evaluate(recommender, training_data));
          System.out.println(
              " fit time: " + (double) (System.currentTimeMillis() - start) / 1000);
        }

        if (prediction_file != null) {
          System.out.print("Predict:");
          start = System.currentTimeMillis();
          org.mymedialite.ratingprediction.Extensions.writePredictions(
              recommender, test_data, prediction_file, user_mapping, item_mapping, prediction_line);
          System.out.println(
              " prediction time " + (double) (System.currentTimeMillis() - start) / 1000);
        }
      }
    }
    Model.save(recommender, save_model_file);
  }