/**
   * Folds one observed reward for arm {@code a} into the hybrid LinUCB statistics.
   *
   * <p>The update runs in three steps:
   *
   * <ol>
   *   <li>The arm's current contribution ({@code Ba^T Aa^-1 Ba} and {@code Ba^T Aa^-1 ba}) is
   *       added back to the shared {@code A0} and {@code b0}.
   *   <li>The arm's own {@code Aa}, {@code Ba} and, on a click, {@code ba} are updated and stored
   *       back into their maps.
   *   <li>The shared statistics receive {@code z z^T} (plus {@code z} on a click), minus the
   *       arm's refreshed contribution.
   * </ol>
   *
   * <p>A click therefore acts as reward 1 and no click as reward 0.
   *
   * <p>NOTE(review): the per-arm maps must already contain entries for {@code a.getId()}; this
   * method never creates them.
   *
   * @param user the user the article was shown to
   * @param a the arm that was played
   * @param clicked whether the user clicked
   */
  @Override
  public void updateReward(User user, Article a, boolean clicked) {
    final String armId = a.getId();

    // Article context x and combined user/article context z.
    final RealMatrix x = MatrixUtils.createColumnRealMatrix(a.getFeatures());
    final RealMatrix z = makeZta(MatrixUtils.createColumnRealMatrix(user.getFeatures()), x);
    final RealMatrix zT = z.transpose();

    final RealMatrix armA = AMap.get(armId);
    RealMatrix rewardSum = bMap.get(armId);
    final RealMatrix armCross = BMap.get(armId);

    // Step 1: put the arm's old contribution back into the shared statistics.
    final RealMatrix armAInverse = MatrixUtils.inverse(armA);
    final RealMatrix oldProjection = armCross.transpose().multiply(armAInverse);
    A0 = A0.add(oldProjection.multiply(armCross));
    b0 = b0.add(oldProjection.multiply(rewardSum));

    // Step 2: record this observation in the arm-specific statistics.
    final RealMatrix nextA = armA.add(x.multiply(x.transpose()));
    AMap.put(armId, nextA);
    final RealMatrix nextCross = armCross.add(x.multiply(zT));
    BMap.put(armId, nextCross);
    if (clicked) {
      rewardSum = rewardSum.add(x);
      bMap.put(armId, rewardSum);
    }

    // Step 3: add the shared-feature terms and remove the arm's refreshed contribution.
    final RealMatrix nextCrossT = nextCross.transpose();
    final RealMatrix nextAInverse = MatrixUtils.inverse(nextA);
    A0 = A0.add(z.multiply(zT)).subtract(nextCrossT.multiply(nextAInverse.multiply(nextCross)));
    b0 = b0.subtract(nextCrossT.multiply(nextAInverse).multiply(rewardSum));
    if (clicked) {
      b0 = b0.add(z);
    }
  }
  /**
   * Picks the arm with the highest hybrid LinUCB upper confidence bound for {@code user}.
   *
   * <p>An arm seen for the first time is registered with starting statistics before it is scored:
   * {@code Aa} = identity, {@code ba} = zero column and {@code Ba} = zero matrix. These are sized
   * from the article feature vector ({@code d} rows) and the combined vector returned by
   * {@code makeZta} ({@code k} columns).
   *
   * <p>The score of an arm is {@code z^T betaHat + x^T theta + alpha * sqrt(s)}, where
   * {@code theta = Aa^-1 (ba - Ba betaHat)}.
   *
   * @param user the user being served
   * @param articles the candidate arms
   * @return the best-scoring article, or {@code null} if {@code articles} is empty
   */
  @Override
  public Article chooseArm(User user, List<Article> articles) {
    if (articles.isEmpty()) {
      return null;
    }

    // A0 is shared by every arm, so invert it once rather than once per candidate.
    RealMatrix A0Inverse = MatrixUtils.inverse(A0);

    Article bestA = null;
    // Not Double.MIN_VALUE: that is the smallest *positive* double. With it, arms scoring <= 0
    // could never win, and a non-empty list could yield null.
    double bestArmP = Double.NEGATIVE_INFINITY;

    for (Article a : articles) {
      String aId = a.getId();

      // Column vectors for the article features (x) and the combined features (z).
      RealMatrix xta = MatrixUtils.createColumnRealMatrix(a.getFeatures());
      RealMatrix zta = makeZta(MatrixUtils.createColumnRealMatrix(user.getFeatures()), xta);

      RealMatrix Aa;
      RealMatrix ba;
      RealMatrix Ba;
      if (AMap.containsKey(aId)) {
        Aa = AMap.get(aId);
        ba = bMap.get(aId);
        Ba = BMap.get(aId);
      } else {
        // Unseen arm: start from identity / zeros; updateReward refines these on feedback.
        // Sizes come from the feature vectors instead of being hard-coded.
        int d = xta.getRowDimension();
        int k = zta.getRowDimension();
        Aa = MatrixUtils.createRealIdentityMatrix(d);
        ba = MatrixUtils.createColumnRealMatrix(new double[d]); // new arrays are zero-filled
        Ba = MatrixUtils.createRealMatrix(new double[d][k]);
        AMap.put(aId, Aa);
        bMap.put(aId, ba);
        BMap.put(aId, Ba);
      }

      RealMatrix AaInverse = MatrixUtils.inverse(Aa);
      RealMatrix xtaTranspose = xta.transpose();
      RealMatrix ztaTranspose = zta.transpose();
      RealMatrix BaTranspose = Ba.transpose();
      // Prefixes shared by several terms below (same left-to-right association as written out).
      RealMatrix zA0Inverse = ztaTranspose.multiply(A0Inverse);
      RealMatrix xAaInverse = xtaTranspose.multiply(AaInverse);

      // theta = Aa^-1 (ba - Ba * betaHat)
      RealMatrix theta = AaInverse.multiply(ba.subtract(Ba.multiply(BetaHat)));

      // s = z^T A0^-1 z - 2 z^T A0^-1 Ba^T Aa^-1 x + x^T Aa^-1 x
      //     + x^T Aa^-1 Ba A0^-1 Ba^T Aa^-1 x
      double zTerm = zA0Inverse.multiply(zta).getEntry(0, 0);
      double crossTerm =
          zA0Inverse.multiply(BaTranspose).multiply(AaInverse).multiply(xta).getEntry(0, 0);
      double xTerm = xAaInverse.multiply(xta).getEntry(0, 0);
      double sharedTerm =
          xAaInverse
              .multiply(Ba)
              .multiply(A0Inverse)
              .multiply(BaTranspose)
              .multiply(AaInverse)
              .multiply(xta)
              .getEntry(0, 0);
      double staVal = zTerm - 2 * crossTerm + xTerm + sharedTerm;

      // p = z^T betaHat + x^T theta + alpha * sqrt(s)
      double ptaVal =
          ztaTranspose.multiply(BetaHat).getEntry(0, 0)
              + xtaTranspose.multiply(theta).getEntry(0, 0)
              + alpha * Math.sqrt(staVal);

      // Update argmax
      if (ptaVal > bestArmP) {
        bestArmP = ptaVal;
        bestA = a;
      }
    }
    return bestA;
  }
示例#3
0
  /**
   * Runs one Kalman-filter step over the incoming measurement and returns the new estimate.
   *
   * <p>Two argument layouts are accepted:
   *
   * <ul>
   *   <li>{@code (measuredValue, measurementNoise)}: a one-dimensional filter with transition 1.
   *       The noise argument is only read when the filter initialises.
   *   <li>{@code (measuredValue, measuredChangingRate, measurementNoise, timestamp)}: a
   *       two-state (value, rate) filter. Its transition matrix uses the time elapsed since the
   *       previous event; the first event uses an elapsed time of 1.
   * </ul>
   *
   * <p>Both layouts add no process noise (Q = 0) and use the noise argument directly as the
   * measurement variance. Numeric arguments may be any {@link Number}.
   *
   * @param data the function arguments
   * @return the filtered estimate of the measured value
   * @throws ExecutionPlanRuntimeException if too few arguments are given or a required one is null
   */
  @Override
  protected Object execute(Object[] data) {
    if (data.length < 2) {
      throw new ExecutionPlanRuntimeException(
          "Invalid input given to kf:kalmanFilter() "
              + "function. Expected 2 or 4 arguments but got "
              + data.length);
    }
    if (data[0] == null) {
      throw new ExecutionPlanRuntimeException(
          "Invalid input given to kf:kalmanFilter() "
              + "function. First argument should be a double");
    }
    if (data[1] == null) {
      throw new ExecutionPlanRuntimeException(
          "Invalid input given to kf:kalmanFilter() "
              + "function. Second argument should be a double");
    }

    if (data.length == 2) {
      double measuredValue = ((Number) data[0]).doubleValue();
      // NOTE(review): prevEstimatedValue == 0 doubles as the "not initialised" marker, so an
      // estimate of exactly 0 re-initialises the filter (variance and noise) on the next event.
      if (prevEstimatedValue == 0) {
        transition = 1;
        variance = 1000;
        measurementNoiseSD = ((Number) data[1]).doubleValue();
        // The first measurement becomes the initial state.
        prevEstimatedValue = measuredValue;
      }
      // Predict, then correct with the new measurement.
      prevEstimatedValue = transition * prevEstimatedValue;
      double kalmanGain = variance / (variance + measurementNoiseSD);
      prevEstimatedValue = prevEstimatedValue + kalmanGain * (measuredValue - prevEstimatedValue);
      variance = (1 - kalmanGain) * variance;
      return prevEstimatedValue;
    }

    if (data.length < 4) {
      throw new ExecutionPlanRuntimeException(
          "Invalid input given to kf:kalmanFilter() "
              + "function. Expected 2 or 4 arguments but got "
              + data.length);
    }
    if (data[2] == null) {
      throw new ExecutionPlanRuntimeException(
          "Invalid input given to kf:kalmanFilter() "
              + "function. Third argument should be a double");
    }
    if (data[3] == null) {
      throw new ExecutionPlanRuntimeException(
          "Invalid input given to kf:kalmanFilter() "
              + "function. Fourth argument should be a long");
    }

    double measuredXValue = ((Number) data[0]).doubleValue();
    double measuredChangingRate = ((Number) data[1]).doubleValue();
    // Named to avoid shadowing the measurementNoiseSD field used by the 2-argument form.
    double measurementNoise = ((Number) data[2]).doubleValue();
    long timestamp = ((Number) data[3]).longValue();
    double[][] measuredValues = {{measuredXValue}, {measuredChangingRate}};

    long timestampDiff;
    if (measurementMatrixH == null) {
      // First event: initialise the model and take the first measurement as the initial state.
      timestampDiff = 1;
      measurementMatrixH = MatrixUtils.createRealMatrix(new double[][] {{1, 0}, {0, 1}});
      varianceMatrixP = MatrixUtils.createRealMatrix(new double[][] {{1000, 0}, {0, 1000}});
      prevMeasuredMatrix = MatrixUtils.createRealMatrix(measuredValues);
    } else {
      timestampDiff = timestamp - prevTimestamp;
    }

    RealMatrix rMatrix =
        MatrixUtils.createRealMatrix(
            new double[][] {{measurementNoise, 0}, {0, measurementNoise}});
    RealMatrix transitionMatrixA =
        MatrixUtils.createRealMatrix(new double[][] {{1d, timestampDiff}, {0d, 1d}});
    RealMatrix measuredMatrixX = MatrixUtils.createRealMatrix(measuredValues);
    RealMatrix hTranspose = measurementMatrixH.transpose();

    // Predict state: Xk = A * Xk-1
    prevMeasuredMatrix = transitionMatrixA.multiply(prevMeasuredMatrix);

    // Predict covariance: Pk = A * P * AT (no process-noise term Q is added)
    varianceMatrixP =
        transitionMatrixA.multiply(varianceMatrixP).multiply(transitionMatrixA.transpose());

    // Innovation covariance: S = H * P * HT + R
    RealMatrix innovationCov =
        measurementMatrixH.multiply(varianceMatrixP).multiply(hTranspose).add(rMatrix);
    RealMatrix innovationCovInverse = new LUDecomposition(innovationCov).getSolver().getInverse();

    // Kalman gain: K = P * HT * S^-1
    RealMatrix kalmanGainMatrix =
        varianceMatrixP.multiply(hTranspose).multiply(innovationCovInverse);

    // Correct state: Xk = Xk + K * (Zk - H * Xk)
    prevMeasuredMatrix =
        prevMeasuredMatrix.add(
            kalmanGainMatrix.multiply(
                measuredMatrixX.subtract(measurementMatrixH.multiply(prevMeasuredMatrix))));

    // Correct covariance: Pk = Pk - K * H * Pk
    varianceMatrixP =
        varianceMatrixP.subtract(
            kalmanGainMatrix.multiply(measurementMatrixH).multiply(varianceMatrixP));

    prevTimestamp = timestamp;
    return prevMeasuredMatrix.getEntry(0, 0);
  }
示例#4
0
  /**
   * Estimates the generalized correlation ratio between the encoded input vector and one output
   * component, treating each distinct value of that component as a class.
   *
   * <p>Up to {@code maxSamples} samples are drawn from {@code it}, with {@code injectNoise}
   * applied to each input. Two scatter matrices are built from them:
   *
   * <ul>
   *   <li>between-class scatter {@code H = sum_y S_y S_y^T / n_y - S S^T / N};
   *   <li>within-class scatter {@code E = sum_i x_i x_i^T - sum_y S_y S_y^T / n_y}.
   * </ul>
   *
   * Here {@code S_y} is the sum of the inputs in class {@code y} and {@code S} is the sum over all
   * inputs. {@code lambda} is either {@code det(E) / det(E + H)} or the product of
   * {@code 1 / (1 + eig)} over the largest eigenvalues of {@code E^-1 H}. The result is
   * {@code sqrt(1 - lambda)}.
   *
   * @param it source of samples; consumed up to {@code maxSamples} elements
   * @param inputDim dimension of each encoded input vector
   * @param out index of the output component whose values define the classes
   * @return the correlation ratio, or {@code -1} when there are no samples or a required matrix is
   *     singular (the reason is logged)
   */
  private double generalizedCorrelationRatio(SampleIterator it, int inputDim, int out) {
    // Per-class input statistics, keyed by the value of output component `out`.
    Map<Double, MultivariateSummaryStatistics> statByClass = new HashMap<>();
    // Input statistics over all samples, regardless of class.
    MultivariateSummaryStatistics stat = new MultivariateSummaryStatistics(inputDim, unbiased);
    // Running sum of x * x^T over all samples; accumulating here avoids keeping every vector.
    RealMatrix sumOfOuterProducts = new Array2DRowRealMatrix(inputDim, inputDim);

    for (int i = 0; i < maxSamples && it.hasNext(); i++) {
      Sample sample = it.next();
      double[] input = sample.getEncodedInput().toArray();
      double output = sample.getEncodedOutput().getEntry(out);

      injectNoise(input);
      statByClass
          .computeIfAbsent(output, key -> new MultivariateSummaryStatistics(inputDim, unbiased))
          .addValue(input);
      stat.addValue(input);
      RealMatrix column = new Array2DRowRealMatrix(input);
      sumOfOuterProducts = sumOfOuterProducts.add(column.multiply(column.transpose()));
    }

    long sampleCount = stat.getN();
    if (sampleCount == 0) {
      Log.write(this, "Some error occurred (no samples available)");
      return -1;
    }

    RealMatrix totalSum = new Array2DRowRealMatrix(stat.getSum());

    // sum over classes of S_y * S_y^T / n_y -- shared by H and E.
    RealMatrix classScatter = new Array2DRowRealMatrix(inputDim, inputDim);
    for (MultivariateSummaryStatistics classStat : statByClass.values()) {
      RealMatrix classSum = new Array2DRowRealMatrix(classStat.getSum());
      classScatter =
          classScatter.add(
              classSum.multiply(classSum.transpose()).scalarMultiply(1.0 / classStat.getN()));
    }

    // Between-class scatter.
    RealMatrix H =
        classScatter.subtract(
            totalSum.multiply(totalSum.transpose()).scalarMultiply(1.0 / sampleCount));
    // Within-class scatter.
    RealMatrix E = sumOfOuterProducts.subtract(classScatter);

    List<Integer> zeroColumns = findZeroColumns(E);
    E = removeZeroColumns(E, zeroColumns);
    H = removeZeroColumns(H, zeroColumns);

    Matrix JE = new Matrix(E.getData());
    Matrix JH = new Matrix(H.getData());

    if (JE.rank() < JE.getRowDimension()) {
      Log.write(this, "Some error occurred (E matrix is singular)");
      return -1;
    }

    double lambda;
    if (useEigenvalues) {
      Matrix L = JE.inverse().times(JH);
      double[] eigs = L.eig().getRealEigenvalues();
      Arrays.sort(eigs);

      // H has rank at most (classes - 1), so only that many eigenvalues can be non-null. The
      // count is capped at eigs.length so the start index below can never go negative.
      int nonNullEigs = Math.min(statByClass.size() - 1, eigs.length);
      lambda = 1;
      for (int i = eigs.length - nonNullEigs; i < eigs.length; i++) {
        if (Math.abs(eigs[i]) < zeroThreshold) {
          Log.write(this, "Some error occurred (E matrix has too many null eigenvalues)");
          return -1;
        }
        lambda *= 1.0 / (1.0 + eigs[i]);
      }
    } else {
      Matrix sum = JE.plus(JH);
      if (sum.rank() < sum.getRowDimension()) {
        Log.write(this, "Some error occurred (E+H is singular)");
        return -1;
      }
      lambda = JE.det() / sum.det();
    }

    // Round-off can push lambda marginally above 1; clamp so the result is 0 rather than NaN.
    return Math.sqrt(Math.max(0.0, 1 - lambda));
  }