コード例 #1
0
ファイル: KbPraDriver.java プロジェクト: bhushank/pra-oda
 /**
  * Reads the split for {@code relation} from {@code splitsDirectory} and populates the data fields
  * of {@code builder}.
  *
  * <p>If {@code splitsDirectory} contains an entry named after the relation (with '/' replaced by
  * '_'), its {@code training.tsv} and {@code testing.tsv} are loaded as a fixed split. Otherwise all
  * instances of the relation are loaded from the {@code relations} directory of
  * {@code kbDirectory}, and the first value of {@code percent_training.tsv} is set as the training
  * percentage so that the caller can cross-validate.
  *
  * @return {@code true} if cross validation should be done, {@code false} if a fixed split was
  *     loaded
  * @throws IOException propagated from the underlying file readers
  */
 public boolean initializeSplit(
     String splitsDirectory,
     String kbDirectory,
     String relation,
     PraConfig.Builder builder,
     DatasetFactory datasetFactory,
     FileUtil fileUtil)
     throws IOException {
   String relationName = relation.replace("/", "_");
   String splitLocation = splitsDirectory + relationName;

   if (!fileUtil.fileExists(splitLocation)) {
     // No fixed split: load every known instance for cross validation. Relation (and category)
     // file names use '_' where the relation name has ':'.
     String relationFile =
         kbDirectory + "relations" + File.separator + relationName.replace(":", "_");
     builder.setAllData(datasetFactory.fromFile(relationFile, builder.nodeDict));
     String percentFile = splitsDirectory + "percent_training.tsv";
     builder.setPercentTraining(fileUtil.readDoubleListFromFile(percentFile).get(0));
     return true;
   }

   String splitPrefix = splitLocation + File.separator;
   builder.setTrainingData(
       datasetFactory.fromFile(splitPrefix + "training.tsv", builder.nodeDict));
   builder.setTestingData(
       datasetFactory.fromFile(splitPrefix + "testing.tsv", builder.nodeDict));
   return false;
 }
コード例 #2
0
ファイル: KbPraDriver.java プロジェクト: bhushank/pra-oda
 /**
  * Loads graph metadata from {@code directory} into {@code builder}: the path of the GraphChi edge
  * file, the shard count stored in {@code num_shards.tsv}, and the node and edge dictionaries from
  * {@code node_dict.tsv} and {@code edge_dict.tsv}.
  *
  * @param directory graph directory; a trailing separator is added if it is missing
  * @param builder config builder to populate
  * @throws IOException if {@code num_shards.tsv} cannot be read or is empty, or if a dictionary
  *     file cannot be read
  */
 public void parseGraphFiles(String directory, PraConfig.Builder builder) throws IOException {
   directory = fileUtil.addDirectorySeparatorIfNecessary(directory);
   builder.setGraph(directory + "graph_chi" + File.separator + "edges.tsv");
   System.out.println("Loading node and edge dictionaries from graph directory: " + directory);

   String numShardsFile = directory + "num_shards.tsv";
   // try-with-resources so the reader is closed even when the file is empty or malformed.
   try (BufferedReader reader = new BufferedReader(new FileReader(numShardsFile))) {
     String firstLine = reader.readLine();
     if (firstLine == null) {
       throw new IOException("Shard count file is empty: " + numShardsFile);
     }
     // trim() tolerates trailing whitespace / '\r' left by other tools.
     builder.setNumShards(Integer.parseInt(firstLine.trim()));
   }

   Dictionary nodeDict = new Dictionary();
   nodeDict.setFromFile(directory + "node_dict.tsv");
   builder.setNodeDictionary(nodeDict);
   Dictionary edgeDict = new Dictionary();
   edgeDict.setFromFile(directory + "edge_dict.tsv");
   builder.setEdgeDictionary(edgeDict);
 }
コード例 #3
0
ファイル: KbPraDriver.java プロジェクト: bhushank/pra-oda
  /**
   * Sets up the PraConfig items that depend on the input KB files. In particular, this decides
   * which relations are known to be inverses of each other and which edges must be ignored because
   * using them to predict new instances of {@code relation} would constitute cheating. When
   * {@code ranges.tsv} lists a range for the relation, it also restricts the allowed prediction
   * targets to the instances of that range category.
   *
   * <p>If the relations have been embedded into a latent space ({@code embeddings.tsv} exists), that
   * mapping is also used when deciding which edges to ignore. This means each embedding of a KB
   * graph has to have a different directory.
   *
   * <p>When no range is available (no ranges file, or no entry for {@code relation}), a warning is
   * appended to {@code settings.txt} under {@code outputBase} and printed, and no target
   * restriction is set.
   *
   * @throws IOException if a KB file cannot be read or the settings file cannot be written
   */
  public void parseKbFiles(
      String directory,
      String relation,
      PraConfig.Builder builder,
      String outputBase,
      FileUtil fileUtil)
      throws IOException {
    // TODO(matt): allow this to be left unspecified.
    Map<Integer, Integer> inverses = createInverses(directory + "inverses.tsv", builder.edgeDict);
    builder.setRelationInverses(inverses);

    String embeddingsFile = directory + "embeddings.tsv";
    Map<String, List<String>> embeddings = null;
    if (fileUtil.fileExists(embeddingsFile)) {
      embeddings = fileUtil.readMapListFromTsvFile(embeddingsFile);
    }
    List<Integer> unallowedEdges =
        createUnallowedEdges(relation, inverses, embeddings, builder.edgeDict);
    builder.setUnallowedEdges(unallowedEdges);

    String rangesFile = directory + "ranges.tsv";
    if (!fileUtil.fileExists(rangesFile)) {
      reportMissingRange(outputBase, fileUtil, "No range file found!");
      return;
    }
    Map<String, String> ranges = fileUtil.readMapFromTsvFile(rangesFile);
    String range = ranges.get(relation);
    if (range == null) {
      // Previously this crashed with a NullPointerException; treat it like a missing range file.
      reportMissingRange(outputBase, fileUtil, "No range found for relation " + relation + "!");
      return;
    }
    // Category file names use '_' in place of '/' and ':'. Only the category name is rewritten so
    // that a ':' in the directory path itself (e.g. a drive letter) is left intact.
    String category = range.replace("/", "_").replace(":", "_");
    String categoryFile = directory + "category_instances" + File.separator + category;
    Set<Integer> allowedTargets = fileUtil.readIntegerSetFromFile(categoryFile, builder.nodeDict);
    builder.setAllowedTargets(allowedTargets);
  }

  /**
   * Appends a "no range" warning to {@code settings.txt} under {@code outputBase} and echoes
   * {@code message} to stdout.
   */
  private void reportMissingRange(String outputBase, FileUtil fileUtil, String message)
      throws IOException {
    // true -> append, so earlier settings entries are kept.
    try (FileWriter writer = fileUtil.getFileWriter(outputBase + "settings.txt", true)) {
      writer.write(message + " I hope your accept policy is as you want it...\n");
    }
    System.out.println(message);
  }
コード例 #4
0
ファイル: KbPraDriver.java プロジェクト: bhushank/pra-oda
  /**
   * Runs PRA for every relation listed in {@code relations_to_run.tsv} in {@code splitsDirectory}.
   *
   * <p>If {@code isOA} is "yes" (case-insensitive), online augmentation is used. In that case the KB
   * is initialized via {@code OnlineAugment.init}, and for each relation the graph is augmented,
   * re-sharded and re-read before training. Otherwise the base config is built once and the graph
   * is re-sharded once up front.
   *
   * <p>A {@code settings.txt} describing the inputs is written to {@code outputBase}, each relation
   * gets its own output subdirectory, and the total wall-clock time is appended to
   * {@code timings.txt}.
   *
   * @param topK parsed as an int and passed to {@code OnlineAugment.init} (online augmentation only)
   * @param pathlength parsed as an int and passed to {@code OnlineAugment.init} (online augmentation
   *     only)
   * @param isOA "yes" enables online augmentation; any other value, including null, disables it
   */
  public void runPra(
      String kbDirectory,
      String graphDirectory,
      String splitsDirectory,
      String parameterFile,
      String topK,
      String pathlength,
      String isOA,
      String outputBase)
      throws IOException, InterruptedException, ClassNotFoundException, Exception {

    outputBase = fileUtil.addDirectorySeparatorIfNecessary(outputBase);
    kbDirectory = fileUtil.addDirectorySeparatorIfNecessary(kbDirectory);
    graphDirectory = fileUtil.addDirectorySeparatorIfNecessary(graphDirectory);
    splitsDirectory = fileUtil.addDirectorySeparatorIfNecessary(splitsDirectory);

    fileUtil.mkdirOrDie(outputBase);
    // Null-safe comparison: a missing flag means "no online augmentation".
    boolean isOnlineAug = "yes".equalsIgnoreCase(isOA);

    KB kb = null;
    PraConfig baseConfig = null;
    if (isOnlineAug) {
      logger.info("Initializing SVO Graph");
      long initStart = System.currentTimeMillis();
      kb =
          OnlineAugment.init(
              kbDirectory,
              graphDirectory,
              splitsDirectory,
              outputBase,
              Integer.parseInt(topK),
              Integer.parseInt(pathlength));
      long initEnd = System.currentTimeMillis();
      logger.info("Initialization took " + (initEnd - initStart) / 1000.00 + " seconds");
    } else {
      baseConfig = buildBaseConfig(graphDirectory, kbDirectory, parameterFile);
      reshardGraph(graphDirectory, outputBase);
    }

    long start = System.currentTimeMillis();
    try (FileWriter writer = fileUtil.getFileWriter(outputBase + "settings.txt");
        BufferedReader params = fileUtil.getBufferedReader(parameterFile)) {
      writer.write("KB used: " + kbDirectory + "\n");
      writer.write("Graph used: " + graphDirectory + "\n");
      writer.write("Splits used: " + splitsDirectory + "\n");
      writer.write("Parameter file used: " + parameterFile + "\n");
      writer.write("Parameters:\n");
      fileUtil.copyLines(params, writer);
      writer.write("End of parameters\n");
    }

    String relationsFile = splitsDirectory + "relations_to_run.tsv";
    try (BufferedReader reader = fileUtil.getBufferedReader(relationsFile)) {
      String line;
      while ((line = reader.readLine()) != null) {
        // A blank line (e.g. a trailing newline) is not a relation; skip it instead of running
        // PRA with an empty relation name.
        if (line.trim().isEmpty()) {
          continue;
        }
        String relation = line;
        long startTrainTime = System.currentTimeMillis();
        if (isOnlineAug) {
          // Online augmentation at training time: augment the graph for this relation, then
          // re-shard and re-read all graph files so the config reflects the augmented graph.
          logger.info("Augmenting during training time");
          kb.setOutputDir(outputBase);
          kb = Corpus.startTrainAugmentation(kb, relation, true);
          reshardGraph(graphDirectory, outputBase);
          baseConfig = buildBaseConfig(graphDirectory, kbDirectory, parameterFile);
        }

        PraConfig.Builder builder = new PraConfig.Builder(baseConfig);
        logger.info("\n\n\n\nRunning PRA for relation " + relation);

        parseKbFiles(kbDirectory, relation, builder, outputBase, fileUtil);

        String outdir = fileUtil.addDirectorySeparatorIfNecessary(outputBase + relation);
        fileUtil.mkdirs(outdir);
        builder.setOutputBase(outdir);

        initializeSplit(
            splitsDirectory, kbDirectory, relation, builder, new DatasetFactory(), fileUtil);

        PraConfig config = builder.build();
        // initializeSplit only sets allData when no fixed split exists, i.e. cross validation.
        boolean doCrossValidation = config.allData != null;
        if (doCrossValidation) {
          new PraTrainAndTester().crossValidate(config, kb, isOnlineAug, relation, startTrainTime);
        } else {
          new PraTrainAndTester().trainAndTest(config, kb, isOnlineAug, relation, startTrainTime);
        }
      }
    }

    long millis = System.currentTimeMillis() - start;
    int seconds = (int) (millis / 1000);
    int minutes = seconds / 60;
    seconds -= minutes * 60;
    String timing = "Took " + minutes + " minutes and " + seconds + " seconds";
    // outputBase already ends with a separator, so no extra '/' is needed here.
    try (BufferedWriter out =
        new BufferedWriter(new FileWriter(outputBase + "timings.txt", true))) {
      out.write(timing + "\n");
    }
    System.out.println(timing);
  }

  /**
   * Builds a fresh base PraConfig from the graph files in {@code graphDirectory}, the parameter
   * file, and the optional {@code node_names.tsv} in {@code kbDirectory}.
   *
   * <p>Declared {@code throws Exception} to mirror runPra, since the checked exceptions of the
   * project methods called here are not visible from this class.
   */
  private PraConfig buildBaseConfig(
      String graphDirectory, String kbDirectory, String parameterFile) throws Exception {
    PraConfig.Builder builder = new PraConfig.Builder();
    // parseGraphFiles installs the node and edge dictionaries that the Outputter below reads, so
    // it MUST run first.
    parseGraphFiles(graphDirectory, builder);
    try (BufferedReader params = fileUtil.getBufferedReader(parameterFile)) {
      builder.setFromParamFile(params);
    }
    Map<String, String> nodeNames = null;
    if (fileUtil.fileExists(kbDirectory + "node_names.tsv")) {
      nodeNames = fileUtil.readMapFromTsvFile(kbDirectory + "node_names.tsv", true);
    }
    builder.setOutputter(new Outputter(builder.nodeDict, builder.edgeDict, nodeNames));
    return builder.build();
  }

  /**
   * Deletes the existing shards under {@code graph_chi} and re-shards its {@code edges.tsv}.
   *
   * <p>NOTE(review): the shard count is fixed at 2 here, while parseGraphFiles reads the count from
   * {@code num_shards.tsv}; confirm the two agree.
   */
  private void reshardGraph(String graphDirectory, String outputBase) throws Exception {
    Utils.deleteShards(graphDirectory + "graph_chi");
    GraphCreator gc = new GraphCreator(outputBase, false);
    gc.shardGraph(graphDirectory + "graph_chi/edges.tsv", 2);
  }