示例#1
0
  /**
   * Runs the self-organizing network (SONN) experiment described by a parameters file.
   *
   * <p>The steps are:
   *
   * <ul>
   *   <li>Load the parameters named by {@code args[0]}.
   *   <li>Read the training dataset and, when enabled, the test and validation datasets.
   *   <li>Convert the datasets into numeric matrices.
   *   <li>Optionally standardize the input columns using training-set statistics.
   *   <li>Build and save the network.
   *   <li>Print the train/test results and write the train/test output files.
   * </ul>
   *
   * @param args command-line arguments; {@code args[0]} is the path of the parameters file
   * @throws FileNotFoundException propagated from the parameter/dataset loading or output helpers
   *     when a file they need cannot be found
   * @throws IOException propagated from the loading or saving helpers when file I/O fails
   */
  public static void main(String[] args) throws FileNotFoundException, IOException {
    if (args.length <= 0) {
      System.err.println("No parameters file");
      System.exit(1);
    }

    SetupParameters global = new SetupParameters();
    global.LoadParameters(args[0]);

    // Load the datasets; the test and validation sets are optional.
    OpenDataset train = new OpenDataset();
    OpenDataset test = null;
    OpenDataset validation = null;

    train.processClassifierDataset(global.train_file, true);
    global.n_train_patterns = train.getndatos();

    global.n_test_patterns = 0;
    if (global.test_data) {
      test = new OpenDataset();
      test.processClassifierDataset(global.test_file, false);
      global.n_test_patterns = test.getndatos();
    }

    global.n_val_patterns = 0;
    if (global.val_data) {
      validation = new OpenDataset();
      validation.processClassifierDataset(global.val_file, false);
      global.n_val_patterns = validation.getndatos();
    }

    // Network dimensions derived from the training dataset's attribute descriptions.
    global.Ninputs = countInputs(train);
    global.Noutputs = countOutputs(train);

    // The last argument sizes data.validation. It was previously hard-coded to 0, so
    // DatasetToArray below had no rows to fill when validation data was loaded.
    // n_val_patterns is 0 when there is no validation set, so that case is unchanged.
    // NOTE(review): assumes the 4th Data constructor argument is the validation row count.
    Data data =
        new Data(
            global.Ninputs + global.Noutputs,
            global.n_train_patterns,
            global.n_test_patterns,
            global.n_val_patterns);

    Genesis.DatasetToArray(data.train, train);
    if (global.test_data) {
      Genesis.DatasetToArray(data.test, test);
    }
    if (global.val_data) {
      Genesis.DatasetToArray(data.validation, validation);
    }

    if (global.tipify_inputs) {
      tipifyInputs(global, data);
    }

    sonn SelfOrganizingNetwork = new sonn(global, data);

    SelfOrganizingNetwork.SaveNetwork("SONN_Network", global.seed, false);

    if (global.problem.compareToIgnoreCase("Classification") == 0) {
      double result =
          SelfOrganizingNetwork.TestSONNInClassification(
              global, data.train, global.n_train_patterns);
      System.out.print("Train accuracy: " + result + "\t");
      result =
          SelfOrganizingNetwork.TestSONNInClassification(global, data.test, global.n_test_patterns);
      System.out.println("Test accuracy: " + result);
    } else {
      double result =
          SelfOrganizingNetwork.TestSONNInRegression(global, data.train, global.n_train_patterns);
      System.out.print("Train accuracy: " + result + "\t");
      result =
          SelfOrganizingNetwork.TestSONNInRegression(global, data.test, global.n_test_patterns);
      System.out.println("Test accuracy: " + result);
    }

    SelfOrganizingNetwork.SaveOutputFile(
        global.train_output, data.train, global.n_train_patterns, global);
    SelfOrganizingNetwork.SaveOutputFile(
        global.test_output, data.test, global.n_test_patterns, global);
  }

  /**
   * Counts the network inputs required by the input attributes of {@code dataset}.
   *
   * <p>An attribute whose type code is 0 contributes one input per element returned by
   * {@code getRangosVar}. Every other attribute contributes a single input.
   */
  private static int countInputs(OpenDataset dataset) {
    int nInputs = 0;
    for (int i = 0; i < dataset.getnentradas(); i++) {
      // NOTE(review): type code 0 presumably marks a nominal attribute that is expanded to one
      // input per category. Confirm against OpenDataset / Genesis.DatasetToArray.
      nInputs += (dataset.getTiposAt(i) == 0) ? dataset.getRangosVar(i).size() : 1;
    }
    return nInputs;
  }

  /**
   * Counts the network outputs for the output attribute, which is located right after the inputs.
   *
   * <p>When that attribute's type code is not 0, the dataset's output count is used. Otherwise the
   * result is the number of elements in the attribute's value range.
   */
  private static int countOutputs(OpenDataset dataset) {
    int outputIndex = dataset.getnentradas();
    if (dataset.getTiposAt(outputIndex) != 0) {
      return dataset.getnsalidas();
    }
    return dataset.getRangosVar(outputIndex).size();
  }

  /**
   * Standardizes every input column: z = (x - mean) / std. dev.
   *
   * <p>The mean and the population standard deviation are estimated on the training rows only. The
   * same transform is then applied to the training, test and validation rows, so all three share
   * one scale. Columns whose standard deviation is not above 1e-6 are left untouched.
   */
  private static void tipifyInputs(SetupParameters global, Data data) {
    for (int col = 0; col < global.Ninputs; col++) {
      double sum = 0.0;
      double sqSum = 0.0;
      for (int row = 0; row < global.n_train_patterns; row++) {
        double x = data.train[row][col];
        sum += x;
        sqSum += x * x;
      }

      double mean = sum / global.n_train_patterns;
      double sigma = Math.sqrt(sqSum / global.n_train_patterns - mean * mean);

      // A (numerically) constant column would otherwise be divided by ~0. A NaN sigma (no
      // training rows, or a negative variance from round-off) also fails this test.
      if (sigma > 0.000001) {
        standardizeColumn(data.train, global.n_train_patterns, col, mean, sigma);
        standardizeColumn(data.test, global.n_test_patterns, col, mean, sigma);
        standardizeColumn(data.validation, global.n_val_patterns, col, mean, sigma);
      }
    }
  }

  /** Applies (x - mean) / sigma in place to column {@code col} of the first {@code nRows} rows. */
  private static void standardizeColumn(
      double[][] rows, int nRows, int col, double mean, double sigma) {
    for (int row = 0; row < nRows; row++) {
      rows[row][col] = (rows[row][col] - mean) / sigma;
    }
  }
示例#2
0
  /**
   * Constructor that takes only the setup parameters (NOT USED).
   *
   * <p>Reads the train matrix, and optionally the test and validation matrices, from plain-text
   * pattern files named in {@code global}.
   *
   * <p>File format: each file starts with three header lines holding the number of patterns, the
   * number of inputs and the number of outputs. These are followed by one line per pattern with
   * {@code inputs + outputs} whitespace-separated numbers.
   *
   * <p>Side effects on {@code global}:
   *
   * <ul>
   *   <li>The counts read from each file are stored back into {@code global}.
   *   <li>The test and validation headers overwrite the {@code Ninputs}/{@code Noutputs} values
   *       read from the training file.
   *   <li>The validation header also sets {@code Nhidden[Nhidden_layers]}.
   * </ul>
   *
   * <p>The test and validation files are only read when {@code global.test_data} /
   * {@code global.val_data} are set. If a file does not exist, an error message is printed and the
   * JVM exits with status -1.
   *
   * @param global Global Definition parameters; supplies file names and flags, and receives the
   *     counts read from the file headers
   * @throws FileNotFoundException declared for compatibility; a missing file terminates the JVM
   * @throws IOException if a file cannot be read, ends early, or has a malformed header or pattern
   *     line
   */
  public Data(SetupParameters global) throws FileNotFoundException, IOException {
    // Training data
    try (BufferedReader reader =
        new BufferedReader(new InputStreamReader(new FileInputStream(global.train_file)))) {
      global.n_train_patterns = readCount(reader, "training");
      global.Ninputs = readCount(reader, "training");
      global.Noutputs = readCount(reader, "training");
      train =
          readPatterns(
              reader, "training", global.n_train_patterns, global.Ninputs + global.Noutputs);
    } catch (FileNotFoundException e) {
      System.err.println("Training file does not exist");
      System.exit(-1);
    }

    // Test data
    if (global.test_data) {
      try (BufferedReader reader =
          new BufferedReader(new InputStreamReader(new FileInputStream(global.test_file)))) {
        global.n_test_patterns = readCount(reader, "test");
        global.Ninputs = readCount(reader, "test");
        global.Noutputs = readCount(reader, "test");
        test = readPatterns(reader, "test", global.n_test_patterns, global.Ninputs + global.Noutputs);
      } catch (FileNotFoundException e) {
        System.err.println("Testing file does not exist");
        System.exit(-1);
      }
    }

    // Validation data
    if (global.val_data) {
      try (BufferedReader reader =
          new BufferedReader(new InputStreamReader(new FileInputStream(global.val_file)))) {
        global.n_val_patterns = readCount(reader, "validation");
        global.Ninputs = readCount(reader, "validation");
        global.Noutputs = readCount(reader, "validation");
        // NOTE(review): only the validation section updates the output-layer size in Nhidden;
        // kept as-is. Confirm whether the training section should do the same.
        global.Nhidden[global.Nhidden_layers] = global.Noutputs;
        validation =
            readPatterns(
                reader, "validation", global.n_val_patterns, global.Ninputs + global.Noutputs);
      } catch (FileNotFoundException e) {
        System.err.println("Validation file does not exist");
        System.exit(-1);
      }
    }
  }

  /**
   * Reads the next header line of a pattern file and parses it as a non-negative count.
   *
   * @param reader source positioned at a header line
   * @param label human-readable file role ("training", "test", ...) used in error messages
   * @return the parsed count
   * @throws IOException on end of file, a non-integer line, or a negative value
   */
  private static int readCount(BufferedReader reader, String label) throws IOException {
    String line = reader.readLine();
    if (line == null) {
      throw new IOException("Unexpected end of the " + label + " file while reading its header");
    }
    int count;
    try {
      count = Integer.parseInt(line.trim());
    } catch (NumberFormatException e) {
      throw new IOException("Invalid count \"" + line + "\" in the " + label + " file header", e);
    }
    if (count < 0) {
      throw new IOException("Negative count " + count + " in the " + label + " file header");
    }
    return count;
  }

  /**
   * Reads {@code nPatterns} lines, each holding exactly {@code nColumns} whitespace-separated
   * numbers, into a new {@code nPatterns x nColumns} matrix.
   *
   * @param reader source positioned at the first pattern line
   * @param label human-readable file role used in error messages
   * @param nPatterns number of pattern lines to read
   * @param nColumns number of values expected on each line
   * @return the parsed matrix
   * @throws IOException if the file ends early or a line has the wrong number of values or a
   *     non-numeric value
   */
  private static double[][] readPatterns(
      BufferedReader reader, String label, int nPatterns, int nColumns) throws IOException {
    double[][] patterns = new double[nPatterns][nColumns];
    for (int i = 0; i < nPatterns; i++) {
      String line = reader.readLine();
      if (line == null) {
        throw new IOException(
            "Unexpected end of the " + label + " file: expected " + nPatterns
                + " patterns, found " + i);
      }
      // Tolerate tabs, repeated spaces and leading/trailing blanks between values.
      String[] tokens = line.trim().split("\\s+");
      if (tokens.length != nColumns) {
        throw new IOException(
            "Pattern " + (i + 1) + " of the " + label + " file has " + tokens.length
                + " values, expected " + nColumns);
      }
      for (int j = 0; j < nColumns; j++) {
        try {
          patterns[i][j] = Double.parseDouble(tokens[j]);
        } catch (NumberFormatException e) {
          throw new IOException(
              "Invalid value \"" + tokens[j] + "\" in pattern " + (i + 1) + " of the " + label
                  + " file",
              e);
        }
      }
    }
    return patterns;
  }