private void train(CharText ctext, double lr) { Map<Integer, String> indexChar = ctext.getIndexChar(); Map<String, DoubleMatrix> charVector = ctext.getCharVector(); List<String> sequence = ctext.getSequence(); for (int i = 0; i < 100; i++) { double error = 0; double num = 0; double start = System.currentTimeMillis(); for (int s = 0; s < sequence.size(); s++) { String seq = sequence.get(s); if (seq.length() < 3) { continue; } Map<String, DoubleMatrix> acts = new HashMap<>(); // forward pass System.out.print(String.valueOf(seq.charAt(0))); for (int t = 0; t < seq.length() - 1; t++) { DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t))); acts.put("x" + t, xt); gru.active(t, acts); DoubleMatrix predcitYt = gru.decode(acts.get("h" + t)); acts.put("py" + t, predcitYt); DoubleMatrix trueYt = charVector.get(String.valueOf(seq.charAt(t + 1))); acts.put("y" + t, trueYt); System.out.print(indexChar.get(predcitYt.argmax())); error += LossFunction.getMeanCategoricalCrossEntropy(predcitYt, trueYt); } System.out.println(); // bptt gru.bptt(acts, seq.length() - 2, lr); num += seq.length(); } System.out.println( "Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s"); } }