Ejemplo n.º 1
0
 /**
  * Returns the loss of this predictor: its RMSE on {@code ratings}, plus {@code regU} times the
  * squared Euclidean norm of {@code userBiases}, plus {@code regI} times the squared Euclidean norm
  * of {@code itemBiases}.
  *
  * <p>NOTE(review): the data term is the RMSE (a root-mean), not a sum of squared errors, so its
  * scale does not grow with the number of ratings while the penalty terms do not shrink with it.
  *
  * @return data term plus both bias penalties
  */
 @Override
 public double computeLoss() {
   double dataTerm = Ratings.evaluate(this, ratings).get("RMSE");
   double userBiasPenalty =
       regU * Math.pow(org.mymedialite.datatype.VectorExtensions.euclideanNorm(userBiases), 2);
   double itemBiasPenalty =
       regI * Math.pow(org.mymedialite.datatype.VectorExtensions.euclideanNorm(itemBiases), 2);
   // Summed left to right, same as the single-expression form, to keep identical rounding.
   return dataTerm + userBiasPenalty + itemBiasPenalty;
 }
Ejemplo n.º 2
0
  /**
   * Command-line entry point for rating prediction.
   *
   * <p>Parses {@code --name=value} options and bare {@code --flag} switches, creates (or loads) a
   * {@link RatingPredictor}, loads the data, optionally splits off a test set (random ratio or
   * chronological), then either trains once and evaluates, or - with {@code --find-iter=N} -
   * iterates an {@code IIterativeModel} up to {@code --max-iter}, evaluating every N iterations
   * and stopping early on the epsilon / RMSE / MAE criteria. Finally the model is saved.
   *
   * <p>Many options are written to variables not declared in this method (e.g. {@code
   * training_file}, {@code test_ratio}, {@code find_iter}, {@code recommender}); presumably these
   * are static fields of the enclosing class. Unrecognized arguments are silently ignored.
   *
   * @param args command-line arguments
   * @throws Exception any checked exception raised by data loading, training, evaluation or model
   *     I/O performed by the helpers called here
   */
  public static void main(String[] args) throws Exception {

    // Handlers for uncaught exceptions and interrupts; statistics are printed on JVM shutdown.
    Thread.setDefaultUncaughtExceptionHandler(new Handlers());
    Runtime.getRuntime()
        .addShutdownHook(
            new Thread() {
              @Override
              public void run() {
                displayStats();
              }
            });

    // Recommender arguments
    String method = null;
    String recommender_options = "";

    // Help/version
    boolean show_help = false;
    boolean show_version = false;

    // Arguments for iteration search
    int max_iter = 100;
    double epsilon = 0;
    double rmse_cutoff = Double.MAX_VALUE;
    double mae_cutoff = Double.MAX_VALUE;

    // Data arguments
    String data_dir = "";

    // Other arguments
    boolean search_hp = false;
    int random_seed = -1;
    String prediction_line = "{0}\t{1}\t{2}";

    for (String arg : args) {
      // "--name=value" is split so that name keeps the trailing '='; bare flags have no value.
      int div = arg.indexOf("=") + 1;
      String name;
      String value;
      if (div > 0) {
        name = arg.substring(0, div);
        value = arg.substring(div);
      } else {
        name = arg;
        value = null;
      }

      // String-valued options
      if (name.equals("--training-file=")) training_file = value;
      else if (name.equals("--test-file=")) test_file = value;
      else if (name.equals("--recommender=")) method = value;
      else if (name.equals("--recommender-options=")) recommender_options += " " + value;
      else if (name.equals("--data-dir=")) data_dir = value;
      else if (name.equals("--user-attributes=")) user_attributes_file = value;
      else if (name.equals("--item-attributes=")) item_attributes_file = value;
      else if (name.equals("--user-relations=")) user_relations_file = value;
      else if (name.equals("--item-relations=")) item_relations_file = value;
      else if (name.equals("--save-model=")) save_model_file = value;
      else if (name.equals("--load-model=")) load_model_file = value;
      else if (name.equals("--prediction-file=")) prediction_file = value;
      else if (name.equals("--prediction-line=")) prediction_line = value;
      else if (name.equals("--chronological-split=")) chronological_split = value;

      // Integer-valued options
      else if (name.equals("--find-iter=")) find_iter = Integer.parseInt(value);
      else if (name.equals("--max-iter=")) max_iter = Integer.parseInt(value);
      else if (name.equals("--random-seed=")) random_seed = Integer.parseInt(value);
      else if (name.equals("--cross-validation=")) cross_validation = Integer.parseInt(value);

      // Double-valued options
      else if (name.equals("--epsilon=")) epsilon = Double.parseDouble(value);
      else if (name.equals("--rmse-cutoff=")) rmse_cutoff = Double.parseDouble(value);
      else if (name.equals("--mae-cutoff=")) mae_cutoff = Double.parseDouble(value);
      else if (name.equals("--test-ratio=")) test_ratio = Double.parseDouble(value);

      // Enum options
      else if (name.equals("--rating-type=")) rating_type = RatingType.valueOf(value);
      else if (name.equals("--file-format=")) file_format = RatingFileFormat.valueOf(value);

      // Boolean options
      else if (name.equals("--compute-fit")) compute_fit = true;
      else if (name.equals("--online-evaluation")) online_eval = true;
      else if (name.equals("--show-fold-results")) show_fold_results = true;
      else if (name.equals("--search-hp")) search_hp = true;
      else if (name.equals("--help")) show_help = true;
      else if (name.equals("--version")) show_version = true;
    }

    // Evaluation is skipped unless some source of test data (ratio, file or split) was given.
    boolean no_eval = !(test_ratio > 0 || test_file != null || chronological_split != null);

    if (show_version) showVersion();

    if (show_help) usage(0);

    if (random_seed != -1) org.mymedialite.util.Random.initInstance(random_seed);

    // Set up recommender: a saved model takes precedence over --recommender; the default is BMF.
    if (load_model_file != null) recommender = (RatingPredictor) Model.load(load_model_file);
    else if (method != null) recommender = Recommender.createRatingPredictor(method);
    else recommender = Recommender.createRatingPredictor("BiasedMatrixFactorization");

    // In case something went wrong, report the failure of the path that was actually taken above
    // (loading wins over --recommender, so a failed load must not be reported as a bad method).
    if (recommender == null) {
      if (load_model_file != null) usage("Could not load model from file " + load_model_file);
      else if (method != null) usage("Unknown rating prediction method: " + method);
      else usage("Unknown rating prediction method: BiasedMatrixFactorization");
    }

    checkParameters();

    try {
      recommender = Recommender.configure(recommender, recommender_options, new ErrorHandler());
    } catch (IllegalAccessException e) {
      // Fatal: report the cause and exit with a failure status (0 would signal success).
      System.err.println("Unable to instantiate recommender: " + recommender + " (" + e + ")");
      System.exit(1);
    }

    // ID mapping objects
    if (file_format == RatingFileFormat.KDDCUP_2011) {
      user_mapping = new IdentityMapping();
      item_mapping = new IdentityMapping();
    }

    // Load all the data
    loadData(
        data_dir,
        user_attributes_file,
        item_attributes_file,
        user_relations_file,
        item_relations_file,
        !online_eval);

    System.out.println(
        "Ratings range: " + recommender.getMinRating() + ", " + recommender.getMaxRating());

    if (test_ratio > 0) {
      RatingsSimpleSplit split = new RatingsSimpleSplit(training_data, test_ratio);
      // TODO check: only the first train/test fold of the split is used.
      training_data = split.train().get(0);
      recommender.setRatings(training_data);
      test_data = split.test().get(0);
      System.out.println("Test ratio: " + test_ratio);
    }

    if (chronological_split != null) {
      RatingsChronologicalSplit split =
          chronological_split_ratio != -1
              ? new RatingsChronologicalSplit(
                  (ITimedRatings) training_data, chronological_split_ratio)
              : new RatingsChronologicalSplit(
                  (ITimedRatings) training_data, chronological_split_time);
      training_data = split.train().get(0);
      recommender.setRatings(training_data);
      test_data = split.test().get(0);
      // Report whichever criterion the split above was actually built from.
      if (chronological_split_ratio != -1)
        System.out.println("Test ratio (chronological): " + chronological_split_ratio);
      else System.out.println("Split time:" + chronological_split_time);
    }

    System.out.print(
        Extensions.statistics(training_data, test_data, user_attributes, item_attributes, false));

    if (find_iter != 0) {
      if (!(recommender instanceof IIterativeModel))
        usage("Only iterative recommenders (interface IIterativeModel) support --find-iter=N.");

      System.out.println("Recommender: " + recommender.toString());

      if (cross_validation > 1) {
        RatingsCrossValidation.doIterativeCrossValidation(
            recommender, cross_validation, max_iter, find_iter);
      } else {
        IIterativeModel iterative_recommender = (IIterativeModel) recommender;

        if (load_model_file == null) recommender.train();

        if (compute_fit)
          System.out.println(
              "Fit "
                  + Ratings.evaluate(recommender, training_data)
                  + " iteration "
                  + iterative_recommender.getNumIter());

        System.out.println(
            Ratings.evaluate(recommender, test_data)
                + " iteration "
                + iterative_recommender.getNumIter());

        for (int it = iterative_recommender.getNumIter() + 1; it <= max_iter; it++) {
          long start = System.currentTimeMillis();
          iterative_recommender.iterate();
          training_time_stats.add(secondsSince(start));

          // Evaluate (and save / write predictions) only every find_iter iterations.
          if (it % find_iter == 0) {
            if (compute_fit) {
              start = System.currentTimeMillis();
              System.out.println(
                  "Fit " + Ratings.evaluate(recommender, training_data) + " iteration " + it);
              fit_time_stats.add(secondsSince(start));
            }

            start = System.currentTimeMillis();
            HashMap<String, Double> results = Ratings.evaluate(recommender, test_data);
            eval_time_stats.add(secondsSince(start));
            rmse_eval_stats.add(results.get("RMSE"));
            System.out.println(results + " iteration " + it);

            Model.save(recommender, save_model_file, it);
            if (prediction_file != null)
              org.mymedialite.ratingprediction.Extensions.writePredictions(
                  recommender,
                  test_data,
                  prediction_file + "-it-" + it,
                  user_mapping,
                  item_mapping,
                  prediction_line);

            // Early stop once the current RMSE is more than epsilon worse than the best seen.
            if (epsilon > 0.0 && results.get("RMSE") - Collections.min(rmse_eval_stats) > epsilon) {
              System.out.println(results.get("RMSE") + " >> " + Collections.min(rmse_eval_stats));
              System.out.println(
                  "Reached convergence on training/validation data after " + it + " iterations.");
              break;
            }
            if (results.get("RMSE") > rmse_cutoff || results.get("MAE") > mae_cutoff) {
              System.out.println("Reached cutoff after " + it + " iterations.");
              break;
            }
          }
        } // for
      }
    } else {
      long start = System.currentTimeMillis();

      System.out.println("Recommender: " + recommender);

      if (load_model_file == null) {
        if (cross_validation > 1) {
          RatingPredictionEvaluationResults results =
              RatingsCrossValidation.doCrossValidation(
                  recommender, cross_validation, compute_fit, show_fold_results);
          System.out.println(results);
          // Cross-validation already evaluated; skip the separate test-set evaluation below.
          no_eval = true;
        } else {
          if (search_hp) {
            double result = NelderMead.findMinimum("RMSE", recommender);
            System.out.println("Estimated quality (on split): " + result);
          }

          recommender.train();
          System.out.println("Training time: " + secondsSince(start) + " seconds");
        }
      }

      if (!no_eval) {
        start = System.currentTimeMillis();
        if (online_eval) System.out.println(RatingsOnline.evaluateOnline(recommender, test_data));
        else System.out.println(Ratings.evaluate(recommender, test_data));

        System.out.println("Testing time: " + secondsSince(start) + " seconds");

        if (compute_fit) {
          System.out.print("Fit:");
          start = System.currentTimeMillis();
          System.out.print(Ratings.evaluate(recommender, training_data));
          System.out.println(" fit time: " + secondsSince(start));
        }

        if (prediction_file != null) {
          System.out.print("Predict:");
          start = System.currentTimeMillis();
          org.mymedialite.ratingprediction.Extensions.writePredictions(
              recommender, test_data, prediction_file, user_mapping, item_mapping, prediction_line);
          System.out.println(" prediction time " + secondsSince(start));
        }
      }
    }
    Model.save(recommender, save_model_file);
  }

  /**
   * Returns the wall-clock time in seconds elapsed since {@code startMillis}.
   *
   * @param startMillis a timestamp previously obtained from {@link System#currentTimeMillis()}
   * @return elapsed seconds, with millisecond resolution
   */
  private static double secondsSince(long startMillis) {
    return (System.currentTimeMillis() - startMillis) / 1000.0;
  }