/**
   * Gets the current settings of the Classifier.
   *
   * <p>The result holds one {@code -P <parameter>} pair for every entry in {@code m_CVParams},
   * then {@code -X <folds>}, followed by the options reported by {@code super.getOptions()}.
   * When {@code m_InitOptions} is set, the base classifier is switched to those options while
   * the superclass options are collected and is afterwards switched back to
   * {@code m_BestClassifierOptions} (if those are available).
   *
   * @return an array of strings suitable for passing to setOptions
   * @throws RuntimeException if the base classifier rejects one of the temporary option
   *     changes; the original exception is attached as the cause
   */
  public String[] getOptions() {

    String[] superOptions;

    if (m_InitOptions != null) {
      try {
        // setOptions may consume the array it is given, hence the clones.
        ((OptionHandler) m_Classifier).setOptions((String[]) m_InitOptions.clone());
        superOptions = super.getOptions();
        // The best options may be unset (e.g. a search that did not finish); in that case
        // there is nothing to restore and cloning would only raise a NullPointerException.
        if (m_BestClassifierOptions != null) {
          ((OptionHandler) m_Classifier).setOptions((String[]) m_BestClassifierOptions.clone());
        }
      } catch (Exception e) {
        // Keep the cause so the caller can see which option was actually rejected.
        throw new RuntimeException(
            "CVParameterSelection: could not set options " + "in getOptions().", e);
      }
    } else {
      superOptions = super.getOptions();
    }

    // Two slots per CV parameter ("-P" + spec), two for "-X <folds>", rest for the superclass.
    String[] options = new String[superOptions.length + m_CVParams.size() * 2 + 2];

    int current = 0;
    for (int i = 0; i < m_CVParams.size(); i++) {
      options[current++] = "-P";
      options[current++] = "" + getCVParameter(i);
    }
    options[current++] = "-X";
    options[current++] = "" + getNumFolds();

    System.arraycopy(superOptions, 0, options, current, superOptions.length);

    return options;
  }
예제 #2
0
 /**
  * Parses a given list of options by forwarding them to the delegate.
  *
  * <p>If no delegate exists yet, {@code init()} is called first.
  *
  * <p>
  * <!-- options-start -->
  * <!-- options-end -->
  *
  * @param options the options, handed to the delegate's {@code setOptions} as is
  * @throws Exception if setting the options fails
  */
 @Override
 public void setOptions(String[] options) throws Exception {
   boolean delegateMissing = (m_delegate == null);
   if (delegateMissing) {
     init();
   }

   OptionHandler handler = (OptionHandler) m_delegate;
   handler.setOptions(options);
 }
  /**
   * Finds the best parameter combination (recursive, one level per parameter being optimised).
   *
   * <p>While {@code depth} still indexes an entry of {@code m_CVParams}, that parameter is
   * stepped from its lower bound to its upper bound and the method recurses for every value.
   * Once all parameters carry a value, the base classifier is configured with
   * {@code createOptions()}, evaluated over {@code m_NumFolds} folds, and the options are stored
   * in {@code m_BestClassifierOptions} if the error rate beats {@code m_BestPerformance}.
   *
   * @param depth the index of the parameter to be optimised at this level
   * @param trainData the data the search is based on
   * @param random a random number generator; only handed on to the recursive calls here
   * @throws Exception if an error occurs
   */
  protected void findParamsByCrossValidation(int depth, Instances trainData, Random random)
      throws Exception {

    if (depth >= m_CVParams.size()) {
      // Every parameter has a value: evaluate this combination.
      Evaluation eval = new Evaluation(trainData);

      // Set the classifier options
      String[] candidate = createOptions();
      if (m_Debug) {
        System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":");
        for (String opt : candidate) {
          System.err.print(" " + opt);
        }
        System.err.println("");
      }
      ((OptionHandler) m_Classifier).setOptions(candidate);

      int fold = 0;
      while (fold < m_NumFolds) {
        // We want to randomize the data the same way for every
        // learning scheme.
        Instances trainFold = trainData.trainCV(m_NumFolds, fold, new Random(1));
        Instances testFold = trainData.testCV(m_NumFolds, fold);
        m_Classifier.buildClassifier(trainFold);
        eval.setPriors(trainFold);
        eval.evaluateModel(m_Classifier, testFold);
        fold++;
      }

      double errorRate = eval.errorRate();
      if (m_Debug) {
        System.err.println("Cross-validated error rate: " + Utils.doubleToString(errorRate, 6, 4));
      }

      // -99 marks "no result recorded yet".
      boolean nothingRecorded = (m_BestPerformance == -99);
      if (nothingRecorded || (errorRate < m_BestPerformance)) {
        m_BestPerformance = errorRate;
        m_BestClassifierOptions = createOptions();
      }
      return;
    }

    CVParameter param = (CVParameter) m_CVParams.elementAt(depth);

    // The rounded difference lower - upper selects where the upper bound comes from:
    // 1 -> number of attributes, 2 -> training fold size, anything else -> the stored bound.
    int boundSelector = (int) (param.m_Lower - param.m_Upper + 0.5);
    double upper;
    if (boundSelector == 1) {
      upper = m_NumAttributes;
    } else if (boundSelector == 2) {
      upper = m_TrainFoldSize;
    } else {
      upper = param.m_Upper;
    }

    double step = (upper - param.m_Lower) / (param.m_Steps - 1);
    param.m_ParamValue = param.m_Lower;
    while (param.m_ParamValue <= upper) {
      findParamsByCrossValidation(depth + 1, trainData, random);
      param.m_ParamValue += step;
    }
  }
  /**
   * Tests the commandline operation of the saver: the options returned by
   * {@code getCommandlineOptions(false)} are applied to {@code m_Saver}, and the test fails if
   * doing so raises an exception.
   */
  public void testSaverCommandlineArgs() {
    final String[] cmdline = getCommandlineOptions(false);

    try {
      OptionHandler saver = (OptionHandler) m_Saver;
      saver.setOptions(cmdline);
    } catch (Exception ex) {
      ex.printStackTrace();
      String rendered = Utils.arrayToString(cmdline);
      fail("Command line test failed ('" + rendered + "'): " + ex.toString());
    }
  }
예제 #5
0
  /**
   * Tries to initialize the groovy object and set its options.
   *
   * <p>The object is only created when {@code m_GroovyModule} is a file; otherwise it is set to
   * {@code null}. Any exception is printed and leaves {@code m_GroovyObject} as {@code null}.
   */
  protected void initGroovyObject() {
    try {
      if (!m_GroovyModule.isFile()) {
        m_GroovyObject = null;
      } else {
        m_GroovyObject = (Classifier) Groovy.newInstance(m_GroovyModule, Classifier.class);
      }

      if (m_GroovyObject == null) {
        return;
      }
      OptionHandler handler = (OptionHandler) m_GroovyObject;
      handler.setOptions(m_GroovyOptions.clone());
    } catch (Exception ex) {
      m_GroovyObject = null;
      ex.printStackTrace();
    }
  }
  /**
   * Generates the classifier.
   *
   * <p>Works on a copy of the data without missing-class instances. If there are no parameters
   * to optimise, the base classifier is built directly with its current options; otherwise the
   * best option set found by {@code findParamsByCrossValidation} is applied before the final
   * build.
   *
   * @param instances set of instances serving as training data
   * @throws IllegalArgumentException if the base classifier is not an {@code OptionHandler}
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();

    boolean configurable = m_Classifier instanceof OptionHandler;
    if (!configurable) {
      throw new IllegalArgumentException("Base classifier should be OptionHandler.");
    }
    OptionHandler base = (OptionHandler) m_Classifier;

    m_InitOptions = base.getOptions();
    m_BestPerformance = -99;
    m_NumAttributes = trainData.numAttributes();
    Random random = new Random(m_Seed);
    trainData.randomize(random);
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

    // Check whether there are any parameters to optimize
    if (m_CVParams.size() == 0) {
      m_Classifier.buildClassifier(trainData);
      m_BestClassifierOptions = m_InitOptions;
      return;
    }

    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
    m_BestClassifierOptions = null;

    // Set up m_ClassifierOptions -- take getOptions() and remove
    // those being optimised.
    m_ClassifierOptions = base.getOptions();
    for (int p = 0; p < m_CVParams.size(); p++) {
      CVParameter param = (CVParameter) m_CVParams.elementAt(p);
      Utils.getOption(param.m_ParamChar, m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData, random);

    // Apply a copy of the winning options, then train on all of the data.
    String[] best = (String[]) m_BestClassifierOptions.clone();
    base.setOptions(best);
    m_Classifier.buildClassifier(trainData);
  }
예제 #7
0
  /**
   * This initializes the WEKA filter from class parameters to be used for the given data.
   *
   * <p>The class named by {@code this.filterName} is instantiated through its no-argument
   * constructor. If it is an {@code OptionHandler}, the options derived from
   * {@code this.parameters} are applied before the input format is set.
   *
   * @param data the data to be filtered (later)
   * @return the initialized filter object
   * @throws TaskException with {@code ERR_INVALID_PARAMS} if the filter cannot be instantiated,
   *     configured, or given the input format
   */
  private Filter initFilter(Instances data) throws TaskException {

    Filter filter;

    try {
      // Wildcard types instead of raw Class/Constructor; the cast to Filter below is the
      // only place where the concrete type is assumed.
      Class<?> filterClass = Class.forName(this.filterName);
      Constructor<?> filterConstructor = filterClass.getConstructor();
      filter = (Filter) filterConstructor.newInstance();

      if (filter instanceof OptionHandler) {
        ((OptionHandler) filter).setOptions(StringUtils.getWekaOptions(this.parameters));
      }
      filter.setInputFormat(data); // first, parameters must be set, then the input format
    } catch (Exception e) {
      // The stack trace is printed here because only a message is passed on to TaskException.
      e.printStackTrace();
      throw new TaskException(
          TaskException.ERR_INVALID_PARAMS,
          this.id,
          "Filter class not found or" + " invalid: " + this.filterName);
    }

    return filter;
  }
  /**
   * Parses a given list of options.
   *
   * <p>
   * <!-- options-start -->
   * Valid options are:
   *
   * <p>
   *
   * <pre>
   * -W
   *  Use word frequencies instead of binary bag of words.
   * </pre>
   *
   * <pre>
   * -P &lt;# instances&gt;
   *  How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
   * </pre>
   *
   * <pre>
   * -M &lt;double&gt;
   *  Minimum word frequency. Words with less than this frequency are ignored.
   *  If periodic pruning is turned on then this is also used to determine which
   *  words to remove from the dictionary (default = 3).
   * </pre>
   *
   * <pre>
   * -normalize
   *  Normalize document length (use in conjunction with -norm and -lnorm)
   * </pre>
   *
   * <pre>
   * -norm &lt;num&gt;
   *  Specify the norm that each instance must have (default 1.0)
   * </pre>
   *
   * <pre>
   * -lnorm &lt;num&gt;
   *  Specify L-norm to use (default 2.0)
   * </pre>
   *
   * <pre>
   * -lowercase
   *  Convert all tokens to lowercase before adding to the dictionary.
   * </pre>
   *
   * <pre>
   * -stoplist
   *  Ignore words that are in the stoplist.
   * </pre>
   *
   * <pre>
   * -stopwords &lt;file&gt;
   *  A file containing stopwords to override the default ones.
   *  Using this option automatically sets the flag ('-stoplist') to use the
   *  stoplist if the file exists.
   *  Format: one stopword per line, lines starting with '#'
   *  are interpreted as comments and ignored.
   * </pre>
   *
   * <pre>
   * -tokenizer &lt;spec&gt;
   *  The tokenizing algorithm (classname plus parameters) to use.
   *  (default: weka.core.tokenizers.WordTokenizer)
   * </pre>
   *
   * <pre>
   * -stemmer &lt;spec&gt;
   *  The stemming algorithm (classname plus parameters) to use.
   * </pre>
   *
   * <!-- options-end -->
   *
   * <p>The object is reset and the superclass options are processed before the options listed
   * above. Numeric options that are absent leave the corresponding setting untouched; an absent
   * tokenizer selects a {@code WordTokenizer}, and an absent stemmer or stopwords file sets
   * {@code null}.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  @Override
  public void setOptions(String[] options) throws Exception {
    reset();

    super.setOptions(options);

    setUseWordFrequencies(Utils.getFlag("W", options));

    String pruneSpec = Utils.getOption("P", options);
    if (!pruneSpec.isEmpty()) {
      setPeriodicPruning(Integer.parseInt(pruneSpec));
    }
    String minFreqSpec = Utils.getOption("M", options);
    if (!minFreqSpec.isEmpty()) {
      setMinWordFrequency(Double.parseDouble(minFreqSpec));
    }

    setNormalizeDocLength(Utils.getFlag("normalize", options));

    String normSpec = Utils.getOption("norm", options);
    if (!normSpec.isEmpty()) {
      setNorm(Double.parseDouble(normSpec));
    }
    String lnormSpec = Utils.getOption("lnorm", options);
    if (!lnormSpec.isEmpty()) {
      setLNorm(Double.parseDouble(lnormSpec));
    }

    setLowercaseTokens(Utils.getFlag("lowercase", options));
    setUseStopList(Utils.getFlag("stoplist", options));

    String stopwordsPath = Utils.getOption("stopwords", options);
    if (stopwordsPath.isEmpty()) {
      setStopwords(null);
    } else {
      setStopwords(new File(stopwordsPath));
    }

    String tokenizerOption = Utils.getOption("tokenizer", options);
    if (!tokenizerOption.isEmpty()) {
      String[] tokenizerArgs = Utils.splitOptions(tokenizerOption);
      if (tokenizerArgs.length == 0) {
        throw new Exception("Invalid tokenizer specification string");
      }
      // First element is the class name; blank it so only the real options remain.
      String tokenizerClass = tokenizerArgs[0];
      tokenizerArgs[0] = "";
      Tokenizer tokenizer = (Tokenizer) Class.forName(tokenizerClass).newInstance();
      if (tokenizer instanceof OptionHandler) {
        ((OptionHandler) tokenizer).setOptions(tokenizerArgs);
      }
      setTokenizer(tokenizer);
    } else {
      setTokenizer(new WordTokenizer());
    }

    String stemmerOption = Utils.getOption("stemmer", options);
    if (!stemmerOption.isEmpty()) {
      String[] stemmerArgs = Utils.splitOptions(stemmerOption);
      if (stemmerArgs.length == 0) {
        throw new Exception("Invalid stemmer specification string");
      }
      // Same convention as for the tokenizer: class name first, options after it.
      String stemmerClass = stemmerArgs[0];
      stemmerArgs[0] = "";
      Stemmer stemmer = (Stemmer) Class.forName(stemmerClass).newInstance();
      if (stemmer instanceof OptionHandler) {
        ((OptionHandler) stemmer).setOptions(stemmerArgs);
      }
      setStemmer(stemmer);
    } else {
      setStemmer(null);
    }

    Utils.checkForRemainingOptions(options);
  }
예제 #9
0
 /**
  * Sets the filter to use.
  *
  * <p>The named class is instantiated through its no-argument constructor and stored in
  * {@code m_Filter}; the options are applied only if the filter is an {@code OptionHandler}.
  *
  * @param name the classname of the filter
  * @param options the options for the filter
  * @throws Exception if the class cannot be loaded or instantiated, or setting the options fails
  */
 public void setFilter(String name, String[] options) throws Exception {
   Object instance = Class.forName(name).newInstance();
   m_Filter = (Filter) instance;

   if (!(m_Filter instanceof OptionHandler)) {
     return;
   }
   OptionHandler handler = (OptionHandler) m_Filter;
   handler.setOptions(options);
 }
예제 #10
0
  /**
   * Method for testing filters.
   *
   * <p>Reads instances from the input file (or stdin), passes them through the filter and
   * prints the output format and the filtered instances to the output file (or stdout). An
   * output file opened here is closed on every exit path; when writing to stdout the writer is
   * only flushed, so {@code System.out} stays usable afterwards (it is needed for the
   * {@code -z} source code output).
   *
   * @param filter the filter to use
   * @param options should contain the following arguments: <br>
   *     -i input_file <br>
   *     -o output_file <br>
   *     -c class_index <br>
   *     -z classname (for filters implementing weka.filters.Sourcable) <br>
   *     or -h for help on options
   * @throws Exception if something goes wrong or the user requests help on command options; for
   *     option/setup problems the message lists the valid options and the original exception is
   *     attached as the cause
   */
  public static void filterFile(Filter filter, String[] options) throws Exception {

    boolean debug = false;
    Instances data = null;
    DataSource input = null;
    PrintWriter output = null;
    // True only when output wraps a file opened by this method. A writer around System.out
    // must never be closed, because that would close System.out itself.
    boolean ownsOutput = false;
    boolean helpRequest;
    String sourceCode = "";

    try {
      helpRequest = Utils.getFlag('h', options);

      if (Utils.getFlag('d', options)) {
        debug = true;
      }
      String infileName = Utils.getOption('i', options);
      String outfileName = Utils.getOption('o', options);
      String classIndex = Utils.getOption('c', options);
      if (filter instanceof Sourcable) {
        sourceCode = Utils.getOption('z', options);
      }

      // The general options have been extracted above; what is left belongs to the filter.
      if (filter instanceof OptionHandler) {
        ((OptionHandler) filter).setOptions(options);
      }

      Utils.checkForRemainingOptions(options);
      if (helpRequest) {
        throw new Exception("Help requested.\n");
      }
      if (infileName.length() != 0) {
        input = new DataSource(infileName);
      } else {
        input = new DataSource(System.in);
      }
      if (outfileName.length() != 0) {
        output = new PrintWriter(new FileOutputStream(outfileName));
        ownsOutput = true;
      } else {
        output = new PrintWriter(System.out);
      }

      data = input.getStructure();
      if (classIndex.length() != 0) {
        if (classIndex.equals("first")) {
          data.setClassIndex(0);
        } else if (classIndex.equals("last")) {
          data.setClassIndex(data.numAttributes() - 1);
        } else {
          // The command line index is 1-based.
          data.setClassIndex(Integer.parseInt(classIndex) - 1);
        }
      }
    } catch (Exception ex) {
      // Setup failed after the output file may already have been opened: do not leak it.
      if (ownsOutput) {
        output.close();
      }

      String filterOptions = "";
      // Output the error and also the valid options
      if (filter instanceof OptionHandler) {
        filterOptions += "\nFilter options:\n\n";
        Enumeration enu = ((OptionHandler) filter).listOptions();
        while (enu.hasMoreElements()) {
          Option option = (Option) enu.nextElement();
          filterOptions += option.synopsis() + '\n' + option.description() + "\n";
        }
      }

      String genericOptions =
          "\nGeneral options:\n\n"
              + "-h\n"
              + "\tGet help on available options.\n"
              + "\t(use -b -h for help on batch mode.)\n"
              + "-i <file>\n"
              + "\tThe name of the file containing input instances.\n"
              + "\tIf not supplied then instances will be read from stdin.\n"
              + "-o <file>\n"
              + "\tThe name of the file output instances will be written to.\n"
              + "\tIf not supplied then instances will be written to stdout.\n"
              + "-c <class index>\n"
              + "\tThe number of the attribute to use as the class.\n"
              + "\t\"first\" and \"last\" are also valid entries.\n"
              + "\tIf not supplied then no class is assigned.\n";

      if (filter instanceof Sourcable) {
        genericOptions +=
            "-z <class name>\n" + "\tOutputs the source code representing the trained filter.\n";
      }

      // Same message as before, but with the original exception chained as the cause.
      throw new Exception('\n' + ex.getMessage() + filterOptions + genericOptions, ex);
    }

    try {
      if (debug) {
        System.err.println("Setting input format");
      }
      boolean printedHeader = false;
      if (filter.setInputFormat(data)) {
        if (debug) {
          System.err.println("Getting output format");
        }
        output.println(filter.getOutputFormat().toString());
        printedHeader = true;
      }

      // Pass all the instances to the filter
      Instance inst;
      while (input.hasMoreElements(data)) {
        inst = input.nextElement(data);
        if (debug) {
          System.err.println("Input instance to filter");
        }
        if (filter.input(inst)) {
          if (debug) {
            System.err.println("Filter said collect immediately");
          }
          if (!printedHeader) {
            throw new Error("Filter didn't return true from setInputFormat() " + "earlier!");
          }
          if (debug) {
            System.err.println("Getting output instance");
          }
          output.println(filter.output().toString());
        }
      }

      // Say that input has finished, and print any pending output instances
      if (debug) {
        System.err.println("Setting end of batch");
      }
      if (filter.batchFinished()) {
        if (debug) {
          System.err.println("Filter said collect output");
        }
        if (!printedHeader) {
          if (debug) {
            System.err.println("Getting output format");
          }
          output.println(filter.getOutputFormat().toString());
        }
        if (debug) {
          System.err.println("Getting output instance");
        }
        while (filter.numPendingOutput() > 0) {
          output.println(filter.output().toString());
          if (debug) {
            System.err.println("Getting output instance");
          }
        }
      }
      if (debug) {
        System.err.println("Done");
      }
    } finally {
      // Release the file even if filtering failed; for stdout just push out what was written.
      if (ownsOutput) {
        output.close();
      } else {
        output.flush();
      }
    }

    // Printed last, and to System.out, which is still open because it was never closed above.
    if (sourceCode.length() != 0) {
      System.out.println(
          wekaStaticWrapper((Sourcable) filter, sourceCode, data, filter.getOutputFormat()));
    }
  }