public static List<DrugUser> make_from_NEP(int target_count, int rng_seed) throws Exception { HashMap<String, Object> population_params = load_defaults(null); // RawLoader rl = new RawLoader(population_params); // Instances hcv_learning_data = buildLearningInstances(rl.getLearningData()); // System.out.println(hcv_learning_data.toSummaryString()); // System.out.println(hcv_learning_data.toString()); // System.exit(1); // HashMap <String, Classifier> classifiers = train_classifiers(hcv_learning_data); PersonGenerator pg = PersonGenerator.make_NEP_generator( population_params, (Double) population_params.get("idu_maturity_threshold"), rng_seed); HashMap<String, Object> generator_params = new HashMap<String, Object>(); generator_params.put("early_idus_only", (Boolean) false); ArrayList<DrugUser> pop = new ArrayList<DrugUser>(); for (int idu_num = 0; idu_num < target_count; ++idu_num) { try { pop.add(pg.generate(generator_params)); System.out.print("."); } catch (Exception e) { System.out.println("x"); } } System.out.println("Synthetic population:" + pop.size()); if (pop.size() > 0) { int num_infected = 0; int num_abpos = 0; for (DrugUser idu : pop) { num_abpos += idu.isHcvABpos() ? 1 : 0; } System.out.println( System.lineSeparator() + "Initial HCV prevalence (AB): " + num_abpos / (1.0 * pop.size())); } return pop; }
public static void analyze_accuracy_NHBS(int rng_seed) throws Exception { HashMap<String, Object> population_params = load_defaults(null); RawLoader rl = new RawLoader(population_params, true, false, rng_seed); List<DrugUser> learningData = rl.getLearningData(); Instances nhbs_data = new Instances("learning_instances", DrugUser.getAttInfo(), learningData.size()); for (DrugUser du : learningData) { nhbs_data.add(du.getInstance()); } System.out.println(nhbs_data.toSummaryString()); nhbs_data.setClass(DrugUser.getAttribMap().get("hcv_state")); // wishlist: remove infrequent values // weka.filters.unsupervised.instance.RemoveFrequentValues() Filter f1 = new RemoveUseless(); f1.setInputFormat(nhbs_data); nhbs_data = Filter.useFilter(nhbs_data, f1); System.out.println("NHBS IDU 2009 Dataset"); System.out.println("Summary of input:"); // System.out.printlnnhbs_data.toSummaryString()); System.out.println(" Num of classes: " + nhbs_data.numClasses()); System.out.println(" Num of attributes: " + nhbs_data.numAttributes()); for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) { Attribute attr = nhbs_data.attribute(idx); System.out.println("" + idx + ": " + attr.toString()); System.out.println(" distinct values:" + nhbs_data.numDistinctValues(idx)); // System.out.println("" + attr.enumerateValues()); } ArrayList<String> options = new ArrayList<String>(); options.add("-Q"); options.add("" + rng_seed); // System.exit(0); // nhbs_data.deleteAttributeAt(0); //response ID // nhbs_data.deleteAttributeAt(16); //zip // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00 // ROC=0.60 // Classifier classifier = new MINND(); // Classifier classifier = new CitationKNN(); // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7% // Classifier classifier = new SMOreg(); Classifier classifier = new Logistic(); // ROC=0.686 // Classifier classifier = new LinearNNSearch(); // LinearRegression: Cannot handle multi-valued nominal class! // Classifier classifier = new LinearRegression(); // Classifier classifier = new RandomForest(); // String[] options = {"-I", "100", "-K", "4"}; //-I trees, -K features per tree. generally, // might want to optimize (or not // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests) // options.add("-I"); options.add("100"); options.add("-K"); options.add("4"); // ROC=0.673 // KStar classifier = new KStar(); // classifier.setGlobalBlend(20); //the amount of not greedy, in percent // ROC=0.633 // Classifier classifier = new AdaBoostM1(); // ROC=0.66 // Classifier classifier = new MultiBoostAB(); // ROC=0.67 // Classifier classifier = new Stacking(); // ROC=0.495 // J48 classifier = new J48(); // new instance of tree //building a C45 tree classifier // ROC=0.585 // String[] options = new String[1]; // options[0] = "-U"; // unpruned tree // classifier.setOptions(options); // set the options classifier.setOptions((String[]) options.toArray(new String[0])); // not needed before CV: http://weka.wikispaces.com/Use+WEKA+in+your+Java+code // classifier.buildClassifier(nhbs_data); // build classifier // evaluation Evaluation eval = new Evaluation(nhbs_data); eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1)); // 10-fold cross validation System.out.println(eval.toSummaryString("\nResults\n\n", false)); System.out.println(eval.toClassDetailsString()); // System.out.println(eval.toCumulativeMarginDistributionString()); }