예제 #1
0
파일: GRULM.java 프로젝트: chagge/JRNN
  /**
   * Trains {@code gru} on the character sequences of {@code ctext} for a fixed number of passes.
   *
   * <p>For every sequence of at least 3 characters, each character except the last is fed through
   * {@code gru.active} / {@code gru.decode}, the prediction is compared with the following
   * character, and {@code gru.bptt} is then invoked once for the whole sequence. The first
   * character and every predicted character are echoed to standard output, followed by a summary
   * line per pass with the mean loss per prediction and the elapsed time in seconds.
   *
   * <p>If no sequence is long enough to be processed, the prediction count stays at zero and the
   * reported error is {@code NaN}.
   *
   * @param ctext source of the index-to-character map, the character-to-vector map and the
   *     training sequences
   * @param lr value forwarded unchanged to {@code gru.bptt} (presumably the learning rate —
   *     confirm against the GRU implementation)
   */
  private void train(CharText ctext, double lr) {
    final int maxIterations = 100;
    Map<Integer, String> indexChar = ctext.getIndexChar();
    Map<String, DoubleMatrix> charVector = ctext.getCharVector();
    List<String> sequence = ctext.getSequence();

    for (int i = 0; i < maxIterations; i++) {
      double error = 0;
      // Number of predictions whose loss was added to `error` during this pass.
      double num = 0;
      long start = System.currentTimeMillis();

      for (String seq : sequence) {
        if (seq.length() < 3) {
          continue;
        }

        // Last time step that has a successor character to predict.
        int lastT = seq.length() - 2;
        Map<String, DoubleMatrix> acts = new HashMap<>();

        // forward pass
        System.out.print(String.valueOf(seq.charAt(0)));
        for (int t = 0; t <= lastT; t++) {
          DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t)));
          acts.put("x" + t, xt);

          // The "h" + t entry read below is not written in this method; it is expected to be
          // stored into `acts` by gru.active.
          gru.active(t, acts);

          DoubleMatrix predictedYt = gru.decode(acts.get("h" + t));
          acts.put("py" + t, predictedYt);
          DoubleMatrix trueYt = charVector.get(String.valueOf(seq.charAt(t + 1)));
          acts.put("y" + t, trueYt);

          System.out.print(indexChar.get(predictedYt.argmax()));
          error += LossFunction.getMeanCategoricalCrossEntropy(predictedYt, trueYt);
        }

        System.out.println();

        // bptt
        gru.bptt(acts, lastT, lr);

        // One loss term was accumulated per predicted character, i.e. length - 1 per sequence;
        // counting the full length would understate the reported mean error.
        num += seq.length() - 1;
      }

      double elapsedSeconds = (System.currentTimeMillis() - start) / 1000.0;
      System.out.println(
          "Iter = " + i + ", error = " + error / num + ", time = " + elapsedSeconds + "s");
    }
  }