예제 #1
0
  /**
   * Builds a synthetic population of drug users with the NEP person generator.
   *
   * <p>Exactly {@code target_count} generation attempts are made. An attempt that throws is
   * reported and skipped rather than retried, so the returned list can hold fewer than
   * {@code target_count} entries. Progress is echoed to standard output ("." for each success,
   * "x" for each failure), followed by the population size and, for a non-empty population, the
   * fraction of members for which {@code isHcvABpos()} is true.
   *
   * @param target_count number of generation attempts; a non-positive value yields an empty list
   * @param rng_seed seed handed to the person generator
   * @return the successfully generated users, possibly empty
   * @throws Exception if the defaults cannot be loaded or the generator cannot be created
   */
  public static List<DrugUser> make_from_NEP(int target_count, int rng_seed) throws Exception {
    HashMap<String, Object> population_params = load_defaults(null);
    PersonGenerator pg =
        PersonGenerator.make_NEP_generator(
            population_params, (Double) population_params.get("idu_maturity_threshold"), rng_seed);

    HashMap<String, Object> generator_params = new HashMap<String, Object>();
    generator_params.put("early_idus_only", Boolean.FALSE);

    List<DrugUser> pop = new ArrayList<DrugUser>();
    for (int idu_num = 0; idu_num < target_count; ++idu_num) {
      try {
        pop.add(pg.generate(generator_params));
        System.out.print(".");
      } catch (Exception e) {
        // Best effort: a failed attempt is skipped, but the cause is kept visible on stderr
        // instead of being discarded, so systematic generator failures can be diagnosed.
        System.out.println("x");
        System.err.println("Failed to generate IDU #" + idu_num + ": " + e);
      }
    }
    System.out.println("Synthetic population:" + pop.size());

    if (!pop.isEmpty()) {
      int num_abpos = 0;
      for (DrugUser idu : pop) {
        if (idu.isHcvABpos()) {
          ++num_abpos;
        }
      }
      System.out.println(
          System.lineSeparator()
              + "Initial HCV prevalence (AB):  "
              + num_abpos / (1.0 * pop.size()));
    }
    return pop;
  }
예제 #2
0
  /**
   * Cross-validates a classifier for {@code hcv_state} on the NHBS learning data and prints the
   * results to standard output.
   *
   * <p>Steps: load the learning records via {@code RawLoader}, convert them to a Weka
   * {@code Instances} set with {@code hcv_state} as the class attribute, run the set through
   * {@code RemoveUseless}, list every remaining attribute with its distinct-value count, then
   * evaluate a {@code Logistic} classifier with 10-fold cross validation.
   *
   * @param rng_seed seed passed to {@code RawLoader} and to the classifier's {@code -Q} option
   * @throws Exception if loading, filtering, option handling or evaluation fails
   */
  public static void analyze_accuracy_NHBS(int rng_seed) throws Exception {
    HashMap<String, Object> defaults = load_defaults(null);
    RawLoader loader = new RawLoader(defaults, true, false, rng_seed);
    List<DrugUser> records = loader.getLearningData();

    // Copy every learning record into a Weka dataset.
    Instances dataset = new Instances("learning_instances", DrugUser.getAttInfo(), records.size());
    for (DrugUser record : records) {
      dataset.add(record.getInstance());
    }
    System.out.println(dataset.toSummaryString());
    dataset.setClass(DrugUser.getAttribMap().get("hcv_state"));

    // Wishlist: also drop infrequent values
    // (weka.filters.unsupervised.instance.RemoveFrequentValues).
    Filter uselessRemover = new RemoveUseless();
    uselessRemover.setInputFormat(dataset);
    dataset = Filter.useFilter(dataset, uselessRemover);

    System.out.println("NHBS IDU 2009 Dataset");
    System.out.println("Summary of input:");
    System.out.println("  Num of classes: " + dataset.numClasses());
    final int attributeCount = dataset.numAttributes();
    System.out.println("  Num of attributes: " + attributeCount);
    int position = 0;
    while (position < attributeCount) {
      Attribute current = dataset.attribute(position);
      System.out.println(position + ": " + current.toString());
      System.out.println("     distinct values:" + dataset.numDistinctValues(position));
      position++;
    }

    // Notes from earlier experiments with alternative classifiers (figures as recorded then):
    //   NNge               best nearest-neighbor classifier: 40.00, ROC=0.60
    //   MINND, CitationKNN, SMOreg, LinearNNSearch
    //                      listed as candidates, no result recorded
    //   LibSVM             requires LibSVM classes; only gets 37.7%
    //   Logistic           ROC=0.686 (the classifier in use below)
    //   LinearRegression   cannot handle multi-valued nominal class
    //   RandomForest       options -I 100 -K 4 (-I trees, -K features per tree), ROC=0.673;
    //                      might want to optimize, see
    //                      https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests
    //   KStar              setGlobalBlend(20) (the amount of not greedy, in percent), ROC=0.633
    //   AdaBoostM1         ROC=0.66
    //   MultiBoostAB       ROC=0.67
    //   Stacking           ROC=0.495
    //   J48 (C4.5 tree)    ROC=0.585; an unpruned variant ("-U") was also sketched
    // Dropping the response ID (attribute 0) and zip (attribute 16) was also considered.
    //
    // NOTE(review): "-Q <seed>" is handed to Logistic unchanged; confirm that this classifier
    // actually honors (or at least tolerates) the option.
    String[] classifierOptions = {"-Q", String.valueOf(rng_seed)};
    Classifier model = new Logistic();
    model.setOptions(classifierOptions);

    // Building the classifier up front is not needed before cross validation:
    // http://weka.wikispaces.com/Use+WEKA+in+your+Java+code
    // NOTE(review): the fold shuffling uses a fixed Random(1), independent of rng_seed.
    Evaluation evaluation = new Evaluation(dataset);
    evaluation.crossValidateModel(model, dataset, 10, new Random(1));
    System.out.println(evaluation.toSummaryString("\nResults\n\n", false));
    System.out.println(evaluation.toClassDetailsString());
  }