예제 #1
0
  /**
   * Loads a previously saved ensemble.
   *
   * <p>Member network {@code i} (for {@code 0 <= i < Nnetworks}) is loaded from the file {@code
   * file_name + "_net_" + i}. The combination weights are then read from {@code file_name} itself
   * as {@code Noutputs * Nnetworks} values in the binary format of {@code
   * DataInputStream.readDouble}. They are stored into {@code weights[i][j]} with output index
   * {@code i} in the outer loop and network index {@code j} in the inner loop.
   *
   * <p>If the weights file cannot be opened or read, an error message is printed to standard
   * error and the JVM is terminated with exit status 1.
   *
   * @param file_name base file name of the ensemble
   */
  public void LoadEnsemble(String file_name) {
    // Load the member networks, one file per network
    for (int i = 0; i < Nnetworks; i++) {
      nets[i].LoadNetwork(file_name + "_net_" + i);
    }

    // Load the weights. try-with-resources closes the stream on every exit path, including
    // unchecked exceptions thrown inside the read loop, which previously leaked the handle.
    try (DataInputStream dataIn = new DataInputStream(new FileInputStream(file_name))) {
      for (int i = 0; i < Noutputs; i++) {
        for (int j = 0; j < Nnetworks; j++) {
          weights[i][j] = dataIn.readDouble();
        }
      }
    } catch (FileNotFoundException ex) {
      System.err.println("Unable to open ensemble files");
      System.exit(1);
    } catch (IOException ex) {
      System.err.println("IO exception");
      System.exit(1);
    }
  }
예제 #2
0
  /**
   * Trains the ensemble with the sampling strategy named by {@code global.sampling}.
   *
   * <p>The name is matched case-insensitively against {@code "None"}, {@code "Bagging"}, {@code
   * "Arcing"} and {@code "Ada"}. Training is delegated to {@code TrainEnsembleNoSampling}, {@code
   * TrainEnsembleBagging}, {@code TrainEnsembleArcing} or {@code TrainEnsembleAda} respectively.
   * Any other name prints an error to standard error and terminates the JVM with exit status 1.
   *
   * @param global global definition parameters; its {@code sampling} field selects the strategy
   * @param data input data
   */
  public void TrainEnsemble(EnsembleParameters global, Data data) {
    final String samplingMethod = global.sampling;

    if (samplingMethod.equalsIgnoreCase("None")) {
      TrainEnsembleNoSampling(global, data);
      return;
    }
    if (samplingMethod.equalsIgnoreCase("Bagging")) {
      TrainEnsembleBagging(global, data);
      return;
    }
    if (samplingMethod.equalsIgnoreCase("Arcing")) {
      TrainEnsembleArcing(global, data);
      return;
    }
    if (samplingMethod.equalsIgnoreCase("Ada")) {
      TrainEnsembleAda(global, data);
      return;
    }

    // None of the supported strategies matched
    System.err.println("Invalid sampling method");
    System.exit(1);
  }
예제 #3
0
  /**
   * Writes the ensemble results for a set of patterns to an output file.
   *
   * <p>The file starts with a header built from {@code Attributes}: the {@code @relation} line,
   * the input and output attribute declarations, the input and output headers, and {@code @data}.
   * It is followed by one line per pattern:
   *
   * <ul>
   *   <li>Classification ({@code problem} equals {@code "Classification"}, ignoring case): the
   *       nominal value of the expected class, then a space, then the nominal value of the class
   *       returned by {@code EnsembleGetClassOfPattern}. The expected class is the index of the
   *       largest of the {@code Noutputs} target columns.
   *   <li>Otherwise (regression): the {@code Noutputs} target values, then the {@code Noutputs}
   *       values computed by {@code EnsembleOutput}, each followed by a space. When both {@code a}
   *       and {@code b} are non-null, every value {@code v} of output {@code j} is written as
   *       {@code (v - b[j]) / a[j]}. Otherwise values are written unchanged.
   * </ul>
   *
   * <p>On an I/O failure an error message is printed to standard error and the JVM is terminated.
   * The method itself throws no checked exception.
   *
   * @param file_name output file name
   * @param data patterns; target values are read from columns {@code Ninputs} to {@code Ninputs +
   *     Noutputs - 1} of each row (presumably preceded by the input values)
   * @param n number of patterns (rows of {@code data}) to write
   * @param problem type of problem (CLASSIFICATION | REGRESSION)
   * @param a per-output divisors for regression values, or {@code null} to write raw values
   * @param b per-output offsets for regression values, or {@code null} to write raw values
   */
  public void SaveOutputFile(
      String file_name, double data[][], int n, String problem, double[] a, double[] b) {
    double outputs[] = new double[Noutputs];

    // try-with-resources flushes and closes the writer (and the underlying file stream) on every
    // exit path, including unchecked exceptions thrown while formatting a pattern.
    try (BufferedWriter f =
        new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file_name)))) {
      // File header
      f.write("@relation " + Attributes.getRelationName() + "\n");
      f.write(Attributes.getInputAttributesHeader());
      f.write(Attributes.getOutputAttributesHeader());
      f.write(Attributes.getInputHeader() + "\n");
      f.write(Attributes.getOutputHeader() + "\n");
      f.write("@data\n");

      // Loop-invariant choices, evaluated once instead of per pattern
      final boolean classification = problem.compareToIgnoreCase("Classification") == 0;
      final boolean rescale = a != null && b != null;

      for (int i = 0; i < n; i++) {
        if (classification) {
          // Expected class: index of the largest target column
          int expectedClass = 0;
          for (int j = 1; j < Noutputs; j++) {
            if (data[i][expectedClass + Ninputs] < data[i][j + Ninputs]) {
              expectedClass = j;
            }
          }
          f.write(Attributes.getOutputAttributes()[0].getNominalValue(expectedClass) + " ");
          f.write(
              Attributes.getOutputAttributes()[0].getNominalValue(
                  EnsembleGetClassOfPattern(data[i])));
          f.newLine();
        } else {
          // Regression: target values first, then the ensemble outputs
          for (int j = 0; j < Noutputs; j++) {
            double target = data[i][Ninputs + j];
            f.write(Double.toString(rescale ? (target - b[j]) / a[j] : target) + " ");
          }
          EnsembleOutput(data[i], outputs);
          for (int j = 0; j < Noutputs; j++) {
            double predicted = outputs[j];
            f.write(Double.toString(rescale ? (predicted - b[j]) / a[j] : predicted) + " ");
          }
          f.newLine();
        }
      }
    } catch (FileNotFoundException e) {
      // Raised when the output file cannot be created or opened for writing
      System.err.println("Unable to create output file " + file_name);
      System.exit(1);
    } catch (IOException e) {
      e.printStackTrace();
      System.exit(-1);
    }
  }