/**
 * Builds a {@link LinearClassifier} from a flat weight vector, using the enclosing instance's
 * {@code featureIndex} and {@code labelIndex}.
 *
 * <p>The vector is reshaped with {@code objective.to2D} when {@code objective} is set; otherwise
 * {@code ArrayUtils.to2D} is called with the sizes of the feature and label indices.
 *
 * @param weights flat weight vector to reshape
 * @return a new classifier constructed from the reshaped weights and the two indices
 */
public LinearClassifier createLinearClassifier(double[] weights) {
  final double[][] reshaped =
      (objective == null)
          ? ArrayUtils.to2D(weights, featureIndex.size(), labelIndex.size())
          : objective.to2D(weights);
  return new LinearClassifier<L, F>(reshaped, featureIndex, labelIndex);
}
/** * Given the path to a file representing the text based serialization of a Linear Classifier, * reconstitutes and returns that LinearClassifier. * * <p>TODO: Leverage Index */ public static LinearClassifier<String, String> loadFromFilename(String file) { try { BufferedReader in = IOUtils.readerFromString(file); // Format: read indices first, weights, then thresholds Index<String> labelIndex = HashIndex.loadFromReader(in); Index<String> featureIndex = HashIndex.loadFromReader(in); double[][] weights = new double[featureIndex.size()][labelIndex.size()]; int currLine = 1; String line = in.readLine(); while (line != null && line.length() > 0) { String[] tuples = line.split(LinearClassifier.TEXT_SERIALIZATION_DELIMITER); if (tuples.length != 3) { throw new Exception( "Error: incorrect number of tokens in weight specifier, line=" + currLine + " in file " + file); } currLine++; int feature = Integer.valueOf(tuples[0]); int label = Integer.valueOf(tuples[1]); double value = Double.valueOf(tuples[2]); weights[feature][label] = value; line = in.readLine(); } // First line in thresholds is the number of thresholds int numThresholds = Integer.valueOf(in.readLine()); double[] thresholds = new double[numThresholds]; int curr = 0; while ((line = in.readLine()) != null) { double tval = Double.valueOf(line.trim()); thresholds[curr++] = tval; } in.close(); LinearClassifier<String, String> classifier = new LinearClassifier<String, String>(weights, featureIndex, labelIndex); return classifier; } catch (Exception e) { System.err.println("Error in LinearClassifierFactory, loading from file=" + file); e.printStackTrace(); return null; } }
public Classifier<L, F> trainClassifier(Iterable<Datum<L, F>> dataIterable) { Minimizer<DiffFunction> minimizer = getMinimizer(); Index<F> featureIndex = Generics.newIndex(); Index<L> labelIndex = Generics.newIndex(); for (Datum<L, F> d : dataIterable) { labelIndex.add(d.label()); featureIndex.addAll(d.asFeatures()); // If there are duplicates, it doesn't add them again. } System.err.println( String.format( "Training linear classifier with %d features and %d labels", featureIndex.size(), labelIndex.size())); LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<L, F>(dataIterable, logPrior, featureIndex, labelIndex); objective.setPrior(new LogPrior(LogPrior.LogPriorType.QUADRATIC)); double[] initial = objective.initial(); double[] weights = minimizer.minimize(objective, TOL, initial); LinearClassifier<L, F> classifier = new LinearClassifier<L, F>(objective.to2D(weights), featureIndex, labelIndex); return classifier; }