Ejemplo n.º 1
0
  /**
   * Trains the query-construction model on a QALD dataset and writes the learned model to disk.
   *
   * <p>All experiment settings are read from {@code ProjectConfiguration}:
   *
   * <ul>
   *   <li>whether MATOLL and the manual lexicon are used;
   *   <li>the dataset name and the exact question length (in space-separated tokens) to train on;
   *   <li>the number of epochs, sampling steps and top-k samples;
   *   <li>the evaluator: {@code "mix"} selects {@code MixEvaluator}, anything else (including a
   *       missing value) selects {@code QueryEvaluator};
   *   <li>the sampling method: {@code "objective"}, {@code "model"}, or anything else for a hybrid
   *       that alternates objective ranking (even epochs) and model ranking (odd epochs);
   *   <li>the learner: {@code "default"} selects {@code DefaultLearner}, anything else selects
   *       {@code ForceUpdateLearner}.
   * </ul>
   *
   * <p>After training, the model is saved with {@code model.saveModelToFile("models", "model")}
   * and its detailed representation is printed to stdout and logged.
   *
   * @param args ignored
   * @throws IOException declared for I/O performed by the project classes used here
   *     (NOTE(review): confirm which call actually throws it)
   */
  public static void main(String[] args) throws IOException {

    // initialize Search and Manual Lexicon
    boolean useMatoll = ProjectConfiguration.useMatoll();
    boolean useManualLexicon = ProjectConfiguration.useManualLexicon();

    Search.useMatoll(useMatoll);
    ManualLexicon.useManualLexicon(useManualLexicon);

    String datasetName = ProjectConfiguration.getDatasetName();

    int maxWords = ProjectConfiguration.getMaxWords();

    int numberOfEpochs = ProjectConfiguration.getNumberOfEpochs();
    int numberOfSamplingSteps = ProjectConfiguration.getNumberOfSamplingSteps();
    int numberKSamples = ProjectConfiguration.getNumberOfKSamples();

    String evaluatorName = ProjectConfiguration.getEvaluatorName();
    String samplingMethod = ProjectConfiguration.getSamplingMethod();

    log.info("START");

    log.info("\nMax tokens : " + maxWords + "\n");

    System.out.println("\nMax tokens : " + maxWords + "\n");

    CandidateRetriever retriever =
        new CandidateRetrieverOnLucene(
            false, "resourceIndex", "classIndex", "predicateIndex", "matollIndex");
    WordNetAnalyzer wordNet = new WordNetAnalyzer("src/main/resources/WordNet-3.0/dict");

    Search indexLookUp = new Search(retriever, wordNet);

    // Constant-first comparison: a missing evaluator name falls back to QueryEvaluator
    // instead of throwing a NullPointerException.
    Evaluator evaluator =
        "mix".equals(evaluatorName) ? new MixEvaluator() : new QueryEvaluator();

    // Inventory of DUDES templates, keyed by the integer ids the explorers and templates use.
    HashMap<Integer, RDFDUDES> dudes = new HashMap<>();
    ExpressionFactory expressions = new ExpressionFactory();

    RDFDUDES someIndividual = new RDFDUDES(RDFDUDES.Type.INDIVIDUAL);
    RDFDUDES someProperty = new RDFDUDES(RDFDUDES.Type.PROPERTY, "1", "2");
    RDFDUDES who =
        expressions.wh(
            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
            "http://dbpedia.org/ontology/Person");
    RDFDUDES what = expressions.what();
    RDFDUDES when =
        expressions.wh(
            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
            "http://www.w3.org/2001/XMLSchema#DateTime");
    RDFDUDES where =
        expressions.wh(
            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
            "http://dbpedia.org/ontology/Location");
    RDFDUDES which = expressions.which("1");
    // The "Empty" slot is deliberately mapped to null.
    RDFDUDES empty = null;
    RDFDUDES someClass = new RDFDUDES(RDFDUDES.Type.CLASS);
    RDFDUDES underSpecified = new RDFDUDES(RDFDUDES.Type.CLASS, "1");
    RDFDUDES did = expressions.did();
    RDFDUDES is = expressions.copula("1", "2");
    RDFDUDES how = expressions.how();

    // instantiate the class property
    someClass.instantiateProperty("http://www.w3.org/1999/02/22-rdf-syntax-ns#type");

    // Human-readable names for each DUDES id. A LinkedHashMap keeps insertion order, so keep the
    // put() order below unchanged.
    HashMap<Integer, String> assignedDUDES = new LinkedHashMap<>();

    dudes.put(0, what);
    dudes.put(1, someProperty);
    dudes.put(2, someIndividual);
    dudes.put(3, someClass);
    dudes.put(4, underSpecified);
    dudes.put(5, which);
    dudes.put(6, empty);
    dudes.put(7, did);
    dudes.put(8, when);
    dudes.put(9, where);
    dudes.put(10, who);
    dudes.put(11, is);
    dudes.put(12, how);

    assignedDUDES.put(0, "What");
    assignedDUDES.put(1, "Property");
    assignedDUDES.put(2, "Individual");
    assignedDUDES.put(3, "Class");
    assignedDUDES.put(4, "UnderSpecifiedClass");
    assignedDUDES.put(5, "Which");
    assignedDUDES.put(6, "Empty");
    assignedDUDES.put(7, "Did");
    assignedDUDES.put(8, "When");
    assignedDUDES.put(9, "Where");
    assignedDUDES.put(10, "Who");
    assignedDUDES.put(11, "Is_Copula");
    assignedDUDES.put(12, "How");

    // Optional toggle: DBpediaEndpoint.setToRemote();
    QueryConstructor.initialize(dudes);
    SPARQLObjectiveFunction objFunction = new SPARQLObjectiveFunction(evaluator);

    // explorers
    NodeExplorer nodeExplorer = new NodeExplorer(dudes, indexLookUp, assignedDUDES);
    SlotExplorer slotExplorer = new SlotExplorer(dudes);
    SwapExplorer swapExplorer = new SwapExplorer(dudes, assignedDUDES, indexLookUp);
    ExpansionExplorer expansionExplorer = new ExpansionExplorer(indexLookUp, dudes, assignedDUDES);

    QueryTypeExplorer queryTypeExplorer = new QueryTypeExplorer();

    // Optional toggle: nodeExplorer.setGreedyExplorer(true);
    // Join all explorers into one so the sampler sees a single explorer.
    JointExplorer jointExplorer =
        new JointExplorer(
            nodeExplorer,
            slotExplorer,
            expansionExplorer,
            swapExplorer,
            queryTypeExplorer,
            objFunction,
            dudes);

    List<Explorer<DeptDudeState>> explorers = new ArrayList<>();
    explorers.add(jointExplorer);

    /* load the corpus */
    QALDCorpusLoader loader = new QALDCorpusLoader();
    List<QueryDocument> documents = new ArrayList<>();

    QALDCorpus corpus = loader.load(QALDCorpusLoader.Dataset.valueOf(datasetName));

    for (QueryDocument d1 : corpus.getDocuments()) {
      String question = d1.getQuestionString();

      // Keep questions flagged aggregation=false, onlyDBO=true and hybrid=false. Constant-first
      // equals() treats a missing flag as "not matching" instead of throwing.
      if ("false".equals(d1.getQaldInstance().getAggregation())
          && "true".equals(d1.getQaldInstance().getOnlyDBO())
          && "false".equals(d1.getQaldInstance().getHybrid())) {

        int numWords = question.split(" ").length;

        // NOTE(review): despite the "Max tokens" wording, only questions with exactly maxWords
        // space-separated tokens are kept; confirm this is intended rather than <=.
        if (numWords == maxWords) {
          // Only train on questions whose gold SPARQL query the endpoint accepts.
          if (DBpediaEndpoint.isValidQuery(d1.getGoldResult())) {
            documents.add(d1);
          } else {
            System.out.println("Query doesn't run: " + d1.getGoldResult());
          }
        }
      }
    }

    // Optional toggle: Collections.shuffle(documents);
    List<QueryDocument> train = new ArrayList<>(documents);

    /*
     * Define templates that are responsible to generate
     * factors/features to score intermediate, generated states.
     */
    List<AbstractTemplate<QueryDocument, DeptDudeState, ?>> templates = new ArrayList<>();
    try {
      templates.add(new ModifiedSimilarityTemplate());
      templates.add(new DependencyTemplate(assignedDUDES));
      templates.add(new SlotFillingTemplate(dudes, assignedDUDES));
      templates.add(new ExpansionTemplate());
      templates.add(new DomainRangeTemplate(dudes));
      templates.add(new QueryTypeTemplate());
      templates.add(new SwapTemplate());
      // Optional toggle: templates.add(new NodeURITemplate(dudes));
    } catch (Exception e1) {
      // Training cannot proceed without its feature templates; abort the run.
      e1.printStackTrace();
      System.exit(1);
    }

    /*
     * Create the scorer object that computes a score from the features of a
     * factor and the weight vectors of the templates.
     */
    Scorer scorer = new LinearScorer();

    /*
     * Define a model and provide it with the necessary templates.
     */
    Model<QueryDocument, DeptDudeState> model = new Model<>(scorer, templates);
    model.setMultiThreaded(true);

    /*
     * Create an Initializer that is responsible for providing an initial
     * state for the sampling chain given a sentence.
     */
    Initializer<QueryDocument, DeptDudeState> trainInitializer =
        new EmptyDUDEInitializer(assignedDUDES);

    /*
     * Stopping criterion for objective-ranked sampling. The chain stops when:
     * - the latest state reaches an objective score of exactly 1.0; or
     * - at least 4 states in the chain (the latest one included) score at least as high as
     *   the latest state; or
     * - the step budget is exhausted.
     */
    StoppingCriterion<DeptDudeState> objectiveOneCriterion =
        new StoppingCriterion<DeptDudeState>() {

          @Override
          public boolean checkCondition(List<DeptDudeState> chain, int step) {
            if (chain.isEmpty()) {
              return false;
            }

            double lastScore = chain.get(chain.size() - 1).getObjectiveScore();

            if (lastScore == 1.0) {
              return true;
            }

            final int maxCount = 4;
            int count = 0;

            for (DeptDudeState state : chain) {
              if (state.getObjectiveScore() >= lastScore) {
                count++;
              }
            }
            return count >= maxCount || step >= numberOfSamplingSteps;
          }
        };

    /*
     * Stopping criterion for model-ranked sampling. The chain stops when:
     * - at least 2 earlier states produce a SPARQL query that a QueryEvaluator scores as
     *   identical (1.0) to the latest state's query; or
     * - the step budget is exhausted.
     */
    StoppingCriterion<DeptDudeState> modelSPARQLCriterion =
        new StoppingCriterion<DeptDudeState>() {

          @Override
          public boolean checkCondition(List<DeptDudeState> chain, int step) {
            if (chain.isEmpty()) {
              return false;
            }

            String lastQuery = QueryConstructor.getSPARQLQuery(chain.get(chain.size() - 1));

            QueryEvaluator sparqlEvaluator = new QueryEvaluator();

            final int maxCount = 2;
            int count = 0;

            for (int i = chain.size() - 2; i >= 0; i--) {
              String query = QueryConstructor.getSPARQLQuery(chain.get(i));
              double sim = sparqlEvaluator.evaluate(query, lastQuery);

              if (sim == 1.0) {
                count++;
              }
            }

            return count >= maxCount || step >= numberOfSamplingSteps;
          }
        };

    TopKSampler<QueryDocument, DeptDudeState, String> sampler =
        new TopKSampler<>(model, scorer, objFunction, explorers, objectiveOneCriterion);
    // Initial configuration, using the configured top-k like every epoch does. The epoch
    // callback below reconfigures the sampler at the start of each epoch.
    sampler.setSamplingStrategy(
        ListSamplingStrategies.topKBestObjectiveSamplingStrategy(numberKSamples));
    sampler.setAcceptStrategy(AcceptStrategies.objectiveAccept());

    /*
     * Define a learning strategy. The learner will receive state pairs
     * which can be used to update the models parameters.
     */
    final double learningRate = 0.1;
    Learner learner =
        "default".equals(ProjectConfiguration.getLearner())
            ? new DefaultLearner<DeptDudeState>(model, learningRate)
            : new ForceUpdateLearner<DeptDudeState>(model, learningRate);

    /*
     * The trainer will loop over the data and invoke sampling and learning.
     * Additionally, it can invoke predictions on new data.
     */
    Trainer trainer = new Trainer();
    trainer.addEpochCallback(
        new EpochCallback() {

          @Override
          public void onStartEpoch(
              Trainer caller, int epoch, int numberOfEpochs, int numberOfInstances) {

            // "objective" and "model" fix the ranking for every epoch. Any other value
            // (including null) selects the hybrid mode: objective on even epochs, model on
            // odd epochs.
            boolean objectivePhase;
            if ("objective".equals(samplingMethod)) {
              objectivePhase = true;
            } else if ("model".equals(samplingMethod)) {
              objectivePhase = false;
            } else {
              objectivePhase = epoch % 2 == 0;
            }

            if (objectivePhase) {
              sampler.setSamplingStrategy(
                  ListSamplingStrategies.topKBestObjectiveSamplingStrategy(numberKSamples));
              sampler.setAcceptStrategy(AcceptStrategies.objectiveAccept());
              sampler.setStoppingCriterion(objectiveOneCriterion);
            } else {
              sampler.setSamplingStrategy(
                  ListSamplingStrategies.topKModelSamplingStrategy(numberKSamples));
              sampler.setAcceptStrategy(AcceptStrategies.strictModelAccept());
              sampler.setStoppingCriterion(modelSPARQLCriterion);
            }

            System.out.println("Trained model:\n" + model.toDetailedString());
            log.info(objectivePhase ? "Switched to Objective Score" : "Switched to Model Score");
          }
        });
    trainer.train(sampler, trainInitializer, learner, train, numberOfEpochs);

    model.saveModelToFile("models", "model");
    String trainedModel = model.toDetailedString();
    System.out.println(trainedModel);
    log.info("\nModel: \n" + trainedModel + "\n");
  }