@Override public void updateReward(User user, Article a, boolean clicked) { String aId = a.getId(); // Collect Variables RealMatrix xta = MatrixUtils.createColumnRealMatrix(a.getFeatures()); RealMatrix zta = makeZta(MatrixUtils.createColumnRealMatrix(user.getFeatures()), xta); RealMatrix Aa = AMap.get(aId); RealMatrix ba = bMap.get(aId); RealMatrix Ba = BMap.get(aId); // Find common transpose/inverse to save computation RealMatrix AaInverse = MatrixUtils.inverse(Aa); RealMatrix BaTranspose = Ba.transpose(); RealMatrix xtaTranspose = xta.transpose(); RealMatrix ztaTranspose = zta.transpose(); // Update A0 = A0.add(BaTranspose.multiply(AaInverse).multiply(Ba)); b0 = b0.add(BaTranspose.multiply(AaInverse).multiply(ba)); Aa = Aa.add(xta.multiply(xtaTranspose)); AMap.put(aId, Aa); Ba = Ba.add(xta.multiply(ztaTranspose)); BMap.put(aId, Ba); if (clicked) { ba = ba.add(xta); bMap.put(aId, ba); } // Update A0 and b0 with the new values A0 = A0.add(zta.multiply(ztaTranspose)) .subtract(Ba.transpose().multiply(MatrixUtils.inverse(Aa).multiply(Ba))); b0 = b0.subtract(Ba.transpose().multiply(MatrixUtils.inverse(Aa)).multiply(ba)); if (clicked) { b0 = b0.add(zta); } }
@Override public Article chooseArm(User user, List<Article> articles) { Article bestA = null; double bestArmP = Double.MIN_VALUE; RealMatrix Aa; RealMatrix Ba; RealMatrix ba; for (Article a : articles) { String aId = a.getId(); if (!AMap.containsKey(aId)) { Aa = MatrixUtils.createRealIdentityMatrix(6); AMap.put(aId, Aa); // set as identity for now and we will update // in reward double[] zeros = {0, 0, 0, 0, 0, 0}; ba = MatrixUtils.createColumnRealMatrix(zeros); bMap.put(aId, ba); double[][] BMapZeros = new double[6][36]; for (double[] row : BMapZeros) { Arrays.fill(row, 0.0); } Ba = MatrixUtils.createRealMatrix(BMapZeros); BMap.put(aId, Ba); } else { Aa = AMap.get(aId); ba = bMap.get(aId); Ba = BMap.get(aId); } // Make column vector out of features RealMatrix xta = MatrixUtils.createColumnRealMatrix(a.getFeatures()); RealMatrix zta = makeZta(MatrixUtils.createColumnRealMatrix(user.getFeatures()), xta); // Set up common variables RealMatrix A0Inverse = MatrixUtils.inverse(A0); RealMatrix AaInverse = MatrixUtils.inverse(Aa); RealMatrix ztaTranspose = zta.transpose(); RealMatrix BaTranspose = Ba.transpose(); RealMatrix xtaTranspose = xta.transpose(); // Find theta RealMatrix theta = AaInverse.multiply(ba.subtract(Ba.multiply(BetaHat))); // Find sta RealMatrix staMatrix = ztaTranspose.multiply(A0Inverse).multiply(zta); staMatrix = staMatrix.subtract( ztaTranspose .multiply(A0Inverse) .multiply(BaTranspose) .multiply(AaInverse) .multiply(xta) .scalarMultiply(2)); staMatrix = staMatrix.add(xtaTranspose.multiply(AaInverse).multiply(xta)); staMatrix = staMatrix.add( xtaTranspose .multiply(AaInverse) .multiply(Ba) .multiply(A0Inverse) .multiply(BaTranspose) .multiply(AaInverse) .multiply(xta)); // Find pta for arm RealMatrix ptaMatrix = ztaTranspose.multiply(BetaHat); ptaMatrix = ptaMatrix.add(xtaTranspose.multiply(theta)); double ptaVal = ptaMatrix.getData()[0][0]; double staVal = staMatrix.getData()[0][0]; ptaVal = ptaVal + alpha * Math.sqrt(staVal); // Update argmax if 
(ptaVal > bestArmP) { bestArmP = ptaVal; bestA = a; } } return bestA; }
@Override protected Object execute(Object[] data) { if (data[0] == null) { throw new ExecutionPlanRuntimeException( "Invalid input given to kf:kalmanFilter() " + "function. First argument should be a double"); } if (data[1] == null) { throw new ExecutionPlanRuntimeException( "Invalid input given to kf:kalmanFilter() " + "function. Second argument should be a double"); } if (data.length == 2) { double measuredValue = (Double) data[0]; // to remain as the initial state if (prevEstimatedValue == 0) { transition = 1; variance = 1000; measurementNoiseSD = (Double) data[1]; prevEstimatedValue = measuredValue; } prevEstimatedValue = transition * prevEstimatedValue; double kalmanGain = variance / (variance + measurementNoiseSD); prevEstimatedValue = prevEstimatedValue + kalmanGain * (measuredValue - prevEstimatedValue); variance = (1 - kalmanGain) * variance; return prevEstimatedValue; } else { if (data[2] == null) { throw new ExecutionPlanRuntimeException( "Invalid input given to kf:kalmanFilter() " + "function. Third argument should be a double"); } if (data[3] == null) { throw new ExecutionPlanRuntimeException( "Invalid input given to kf:kalmanFilter() " + "function. 
Fourth argument should be a long"); } double measuredXValue = (Double) data[0]; double measuredChangingRate = (Double) data[1]; double measurementNoiseSD = (Double) data[2]; long timestamp = (Long) data[3]; long timestampDiff; double[][] measuredValues = {{measuredXValue}, {measuredChangingRate}}; if (measurementMatrixH == null) { timestampDiff = 1; double[][] varianceValues = {{1000, 0}, {0, 1000}}; double[][] measurementValues = {{1, 0}, {0, 1}}; measurementMatrixH = MatrixUtils.createRealMatrix(measurementValues); varianceMatrixP = MatrixUtils.createRealMatrix(varianceValues); prevMeasuredMatrix = MatrixUtils.createRealMatrix(measuredValues); } else { timestampDiff = (timestamp - prevTimestamp); } double[][] Rvalues = {{measurementNoiseSD, 0}, {0, measurementNoiseSD}}; RealMatrix rMatrix = MatrixUtils.createRealMatrix(Rvalues); double[][] transitionValues = {{1d, timestampDiff}, {0d, 1d}}; RealMatrix transitionMatrixA = MatrixUtils.createRealMatrix(transitionValues); RealMatrix measuredMatrixX = MatrixUtils.createRealMatrix(measuredValues); // Xk = (A * Xk-1) prevMeasuredMatrix = transitionMatrixA.multiply(prevMeasuredMatrix); // Pk = (A * P * AT) + Q varianceMatrixP = (transitionMatrixA.multiply(varianceMatrixP)).multiply(transitionMatrixA.transpose()); // S = (H * P * HT) + R RealMatrix S = ((measurementMatrixH.multiply(varianceMatrixP)).multiply(measurementMatrixH.transpose())) .add(rMatrix); RealMatrix S_1 = new LUDecomposition(S).getSolver().getInverse(); // P * HT * S-1 RealMatrix kalmanGainMatrix = (varianceMatrixP.multiply(measurementMatrixH.transpose())).multiply(S_1); // Xk = Xk + kalmanGainMatrix (Zk - HkXk ) prevMeasuredMatrix = prevMeasuredMatrix.add( kalmanGainMatrix.multiply( (measuredMatrixX.subtract(measurementMatrixH.multiply(prevMeasuredMatrix))))); // Pk = Pk - K.Hk.Pk varianceMatrixP = varianceMatrixP.subtract( (kalmanGainMatrix.multiply(measurementMatrixH)).multiply(varianceMatrixP)); prevTimestamp = timestamp; return 
prevMeasuredMatrix.getRow(0)[0]; } }
/**
 * Computes {@code sqrt(1 - lambda)} for the encoded inputs grouped by the value of one encoded
 * output entry, where {@code lambda} is derived from two scatter matrices:
 *
 * <ul>
 *   <li>{@code T = sum over groups y of (s_y * s_y^T) / n_y}, with {@code s_y} the sum of the
 *       inputs in group {@code y} and {@code n_y} the group size;</li>
 *   <li>{@code H = T - (s * s^T) / N}, with {@code s} the sum of all inputs and {@code N} the
 *       number of samples;</li>
 *   <li>{@code E = sum over samples of (x * x^T) - T}.</li>
 * </ul>
 *
 * Columns reported by {@code findZeroColumns(E)} are removed from both matrices. Then either
 * {@code lambda = product of 1 / (1 + eig_i)} over the largest eigenvalues of
 * {@code E^-1 * H} (when {@code useEigenvalues} is set), or
 * {@code lambda = det(E) / det(E + H)}.
 *
 * @param it source of samples; at most {@code maxSamples} are consumed
 * @param inputDim length of each encoded input vector
 * @param out index of the encoded output entry whose value defines the groups
 * @return {@code sqrt(1 - lambda)}, or {@code -1} when E (or E + H) is rank deficient or an
 *     expected non-null eigenvalue is below {@code zeroThreshold} in magnitude
 */
private double generalizedCorrelationRatio(SampleIterator it, int inputDim, int out) {
    Map<Double, Integer> n_y = new HashMap<>();
    Map<Double, MultivariateSummaryStatistics> stat_y = new HashMap<>();
    List<RealMatrix> x = new ArrayList<>();
    MultivariateSummaryStatistics stat = new MultivariateSummaryStatistics(inputDim, unbiased);

    // Pass over the samples: count and accumulate per output value and overall.
    for (int i = 0; i < maxSamples && it.hasNext(); i++) {
        Sample sample = it.next();
        double[] input = sample.getEncodedInput().toArray();
        double output = sample.getEncodedOutput().getEntry(out);
        if (!n_y.containsKey(output)) {
            n_y.put(output, 0);
            stat_y.put(output, new MultivariateSummaryStatistics(inputDim, unbiased));
        }
        // presumably perturbs `input` in place (its return value, if any, is unused) - verify
        injectNoise(input);
        n_y.put(output, n_y.get(output) + 1);
        stat_y.get(output).addValue(input);
        x.add(new Array2DRowRealMatrix(input));
        stat.addValue(input);
    }

    // Column vectors holding the overall sum and the per-group sums of the inputs.
    RealMatrix x_sum = new Array2DRowRealMatrix(stat.getSum());
    Map<Double, RealMatrix> x_y_sum = new HashMap<>();
    for (Entry<Double, MultivariateSummaryStatistics> entry : stat_y.entrySet()) {
        x_y_sum.put(entry.getKey(), new Array2DRowRealMatrix(entry.getValue().getSum()));
    }

    // temp = sum_y (s_y s_y^T) / n_y, accumulated in n_y's iteration order.
    RealMatrix temp = new Array2DRowRealMatrix(inputDim, inputDim);
    for (Entry<Double, Integer> group : n_y.entrySet()) {
        RealMatrix groupSum = x_y_sum.get(group.getKey());
        temp = temp.add(
                groupSum.multiply(groupSum.transpose()).scalarMultiply(1.0 / group.getValue()));
    }

    RealMatrix H =
            temp.subtract(x_sum.multiply(x_sum.transpose()).scalarMultiply(1.0 / x.size()));

    RealMatrix E = new Array2DRowRealMatrix(inputDim, inputDim);
    for (RealMatrix m : x) {
        E = E.add(m.multiply(m.transpose()));
    }
    E = E.subtract(temp);

    // Drop the columns flagged on E from both matrices so they stay the same shape.
    List<Integer> zeroColumns = findZeroColumns(E);
    E = removeZeroColumns(E, zeroColumns);
    H = removeZeroColumns(H, zeroColumns);

    Matrix JE = new Matrix(E.getData());
    Matrix JH = new Matrix(H.getData());

    if (JE.rank() < JE.getRowDimension()) {
        Log.write(this, "Some error occurred (E matrix is singular)");
        return -1;
    }

    double lambda;
    if (useEigenvalues) {
        Matrix L = JE.inverse().times(JH);
        double[] eigs = L.eig().getRealEigenvalues();
        Arrays.sort(eigs);
        lambda = 1;
        // Only the largest (groups - 1) eigenvalues are expected to be non-null, but there
        // cannot be more of them than there are eigenvalues. Without the clamp the start
        // index below goes negative when groups - 1 exceeds the retained dimension.
        int nonNullEigs = Math.min(eigs.length, n_y.size() - 1);
        for (int i = eigs.length - nonNullEigs; i < eigs.length; i++) {
            if (Math.abs(eigs[i]) < zeroThreshold) {
                Log.write(this, "Some error occurred (E matrix has too many null eigenvalues)");
                return -1;
            }
            lambda *= 1.0 / (1.0 + eigs[i]);
        }
    } else {
        Matrix sum = JE.plus(JH);
        if (sum.rank() < sum.getRowDimension()) {
            Log.write(this, "Some error occourred (E+H is singular");
            return -1;
        }
        lambda = JE.det() / sum.det();
    }
    return Math.sqrt(1 - lambda);
}