@Override
  public void updateReward(User user, Article a, boolean clicked) {
    // Folds one observed (user, article, clicked) interaction into the per-article matrices kept
    // in AMap / BMap / bMap and into the shared accumulators A0 / b0. The article is expected to
    // already have entries in all three maps; nothing is initialised here.
    final String articleId = a.getId();

    // Column vectors: the article features, and the combined vector built by makeZta.
    final RealMatrix xta = MatrixUtils.createColumnRealMatrix(a.getFeatures());
    final RealMatrix userColumn = MatrixUtils.createColumnRealMatrix(user.getFeatures());
    final RealMatrix zta = makeZta(userColumn, xta);

    final RealMatrix previousA = AMap.get(articleId);
    final RealMatrix previousb = bMap.get(articleId);
    final RealMatrix previousB = BMap.get(articleId);

    // Step 1: add the article's current contribution (B^T A^-1 B and B^T A^-1 b) to the shared
    // accumulators. The product B^T A^-1 is shared by both updates, so it is formed once.
    final RealMatrix previousAInverse = MatrixUtils.inverse(previousA);
    final RealMatrix previousBtAInverse = previousB.transpose().multiply(previousAInverse);
    A0 = A0.add(previousBtAInverse.multiply(previousB));
    b0 = b0.add(previousBtAInverse.multiply(previousb));

    // Step 2: update the per-article statistics with the new observation.
    final RealMatrix updatedA = previousA.add(xta.multiply(xta.transpose()));
    AMap.put(articleId, updatedA);

    final RealMatrix updatedB = previousB.add(xta.multiply(zta.transpose()));
    BMap.put(articleId, updatedB);

    // The reward vector only changes (and is only stored back) when the article was clicked.
    RealMatrix currentb = previousb;
    if (clicked) {
      currentb = previousb.add(xta);
      bMap.put(articleId, currentb);
    }

    // Step 3: add z z^T and remove the article's contribution again, now computed from the
    // updated per-article matrices.
    final RealMatrix ztaOuterProduct = zta.multiply(zta.transpose());
    final RealMatrix updatedBTranspose = updatedB.transpose();
    final RealMatrix updatedAInverse = MatrixUtils.inverse(updatedA);
    A0 =
        A0.add(ztaOuterProduct)
            .subtract(updatedBTranspose.multiply(updatedAInverse.multiply(updatedB)));
    b0 = b0.subtract(updatedBTranspose.multiply(updatedAInverse).multiply(currentb));
    if (clicked) {
      b0 = b0.add(zta);
    }
  }
  /**
   * Scores every candidate article for the given user and returns the one with the highest score
   * {@code z^T * BetaHat + x^T * theta + alpha * sqrt(s)}.
   *
   * <p>Side effect: an article whose id is not yet present in {@code AMap} is registered with a
   * 6x6 identity matrix in {@code AMap}, a 6x1 zero vector in {@code bMap} and a 6x36 zero matrix
   * in {@code BMap}.
   *
   * @param user the user whose features are combined with each article's features
   * @param articles the candidate articles
   * @return the highest-scoring article, or {@code null} if {@code articles} is empty or no score
   *     compares greater than negative infinity (for example when every score is NaN)
   */
  @Override
  public Article chooseArm(User user, List<Article> articles) {
    if (articles.isEmpty()) {
      return null;
    }

    Article bestA = null;
    // Double.MIN_VALUE is the smallest POSITIVE double, so using it as the initial maximum would
    // reject every article whose score is zero or negative. Start from negative infinity instead.
    double bestArmP = Double.NEGATIVE_INFINITY;

    // A0 is not modified inside the loop, so its inverse is computed once.
    final RealMatrix a0Inverse = MatrixUtils.inverse(A0);

    for (Article a : articles) {
      final String aId = a.getId();
      final RealMatrix aa;
      final RealMatrix ba;
      final RealMatrix bigBa;
      if (!AMap.containsKey(aId)) {
        // First time this article is seen: identity for now, updated when rewards arrive.
        aa = MatrixUtils.createRealIdentityMatrix(6);
        AMap.put(aId, aa);

        // Freshly allocated Java arrays are already zero-filled.
        ba = MatrixUtils.createColumnRealMatrix(new double[6]);
        bMap.put(aId, ba);

        bigBa = MatrixUtils.createRealMatrix(new double[6][36]);
        BMap.put(aId, bigBa);
      } else {
        aa = AMap.get(aId);
        ba = bMap.get(aId);
        bigBa = BMap.get(aId);
      }

      // Make column vectors out of the features.
      final RealMatrix xta = MatrixUtils.createColumnRealMatrix(a.getFeatures());
      final RealMatrix zta = makeZta(MatrixUtils.createColumnRealMatrix(user.getFeatures()), xta);

      // Products that appear more than once below.
      final RealMatrix aaInverse = MatrixUtils.inverse(aa);
      final RealMatrix bigBaTranspose = bigBa.transpose();
      final RealMatrix ztaTa0Inverse = zta.transpose().multiply(a0Inverse);
      final RealMatrix xtaTaaInverse = xta.transpose().multiply(aaInverse);

      // theta = Aa^-1 (ba - Ba * BetaHat)
      final RealMatrix theta = aaInverse.multiply(ba.subtract(bigBa.multiply(BetaHat)));

      // s = z^T A0^-1 z - 2 z^T A0^-1 Ba^T Aa^-1 x + x^T Aa^-1 x
      //     + x^T Aa^-1 Ba A0^-1 Ba^T Aa^-1 x
      final RealMatrix staMatrix =
          ztaTa0Inverse
              .multiply(zta)
              .subtract(
                  ztaTa0Inverse
                      .multiply(bigBaTranspose)
                      .multiply(aaInverse)
                      .multiply(xta)
                      .scalarMultiply(2))
              .add(xtaTaaInverse.multiply(xta))
              .add(
                  xtaTaaInverse
                      .multiply(bigBa)
                      .multiply(a0Inverse)
                      .multiply(bigBaTranspose)
                      .multiply(aaInverse)
                      .multiply(xta));

      // p = z^T BetaHat + x^T theta + alpha * sqrt(s); both matrices are 1x1.
      final RealMatrix ptaMatrix = zta.transpose().multiply(BetaHat).add(xta.transpose().multiply(theta));
      final double ptaVal =
          ptaMatrix.getEntry(0, 0) + alpha * Math.sqrt(staMatrix.getEntry(0, 0));

      // Update argmax (strict comparison: the first of several equal scores wins).
      if (ptaVal > bestArmP) {
        bestArmP = ptaVal;
        bestA = a;
      }
    }
    return bestA;
  }
  /**
   * Tangent-normalizes a coverage profile against a panel of normals.
   *
   * <p>The case counts are first reordered so that their targets match the PoN's order and are
   * then passed through {@code composeTangentNormalizationInputMatrix}. When {@code ctx} is
   * {@code null} the beta-hats and the normalization are computed locally through {@code pon};
   * otherwise the projection is computed with Spark distributed matrices.
   *
   * <p>Notes about the Spark tangent normalization can be found in docs/PoN/
   *
   * @param pon Not {@code null}
   * @param targetFactorNormalizedCounts ReadCountCollection of counts that have already been
   *     normalized fully (typically, including the target factor normalization), i.e. a coverage
   *     profile. The column names must be intact (non-null and non-empty). Not {@code null}. See
   *     {@code TangentNormalizer::createCoverageProfile}
   * @param ctx Spark context used for the distributed computation; when {@code null} the local
   *     (non-Spark) code path is taken
   * @return never {@code null}
   */
  private static TangentNormalizationResult tangentNormalize(
      final PoN pon, final ReadCountCollection targetFactorNormalizedCounts, JavaSparkContext ctx) {

    Utils.nonNull(pon, "PoN cannot be null.");
    Utils.nonNull(targetFactorNormalizedCounts, "targetFactorNormalizedCounts cannot be null.");
    Utils.nonNull(
        targetFactorNormalizedCounts.columnNames(),
        "targetFactorNormalizedCounts column names cannot be null.");
    ParamUtils.isPositive(
        targetFactorNormalizedCounts.columnNames().size(),
        "targetFactorNormalizedCounts column names cannot be an empty list.");

    // Translates between the case's target order and the PoN's target order.
    final Case2PoNTargetMapper caseToPoN =
        new Case2PoNTargetMapper(targetFactorNormalizedCounts.targets(), pon.getPanelTargetNames());

    // Rows (targets) of the input counts, sorted to match the PoN, then prepared for the
    // normalization itself.
    final RealMatrix reorderedCounts =
        caseToPoN.fromCaseToPoNCounts(targetFactorNormalizedCounts.counts());
    final RealMatrix preparedCounts = composeTangentNormalizationInputMatrix(reorderedCounts);

    if (ctx != null) {
      /*
      Spark path: the code here is a little more complex for optimization purposes.

      Please see notes in docs/PoN ...

      Ahat^T = (C^T P^T) A^T
      Therefore, C^T is the RowMatrix

      pinv: P
      panel: A
      projection: Ahat
      cases: C
      betahat: C^T P^T
      tangentNormalizedCounts: C - Ahat
       */
      final RealMatrix reducedPInverse = pon.getReducedPanelPInverseCounts();
      final RealMatrix reducedPanel = pon.getReducedPanelCounts();

      // C^T as a distributed matrix (RowMatrix).
      final RowMatrix casesTransposedDist =
          SparkConverter.convertRealMatrixToSparkRowMatrix(
              ctx, preparedCounts.transpose(), TN_NUM_SLICES_SPARK);

      // Local Spark matrices P^T and A^T, built from the concatenated rows of each matrix.
      final Matrix pInverseTransposedLocal =
          new DenseMatrix(
                  reducedPInverse.getRowDimension(),
                  reducedPInverse.getColumnDimension(),
                  Doubles.concat(reducedPInverse.getData()),
                  true)
              .transpose();
      final Matrix panelTransposedLocal =
          new DenseMatrix(
                  reducedPanel.getRowDimension(),
                  reducedPanel.getColumnDimension(),
                  Doubles.concat(reducedPanel.getData()),
                  true)
              .transpose();

      // betahat = C^T P^T, projection^T = betahat A^T; the projection is then brought back as a
      // (non-transposed) Apache Commons matrix. The row count of C^T equals the column count of C.
      final RowMatrix betaHatsDist = casesTransposedDist.multiply(pInverseTransposedLocal);
      final RowMatrix projectionTransposedDist = betaHatsDist.multiply(panelTransposedLocal);
      final RealMatrix projectedCounts =
          SparkConverter.convertSparkRowMatrixToRealMatrix(
                  projectionTransposedDist, preparedCounts.getColumnDimension())
              .transpose();

      // Subtract the projection from the cases (C - Ahat).
      final RealMatrix sparkNormalizedCounts = preparedCounts.subtract(projectedCounts);

      // Map everything back to the case's targets and assemble the result.
      final ReadCountCollection sparkNormalized =
          caseToPoN.fromPoNtoCaseCountCollection(
              sparkNormalizedCounts, targetFactorNormalizedCounts.columnNames());
      final ReadCountCollection sparkPreNormalized =
          caseToPoN.fromPoNtoCaseCountCollection(
              preparedCounts, targetFactorNormalizedCounts.columnNames());
      final RealMatrix betaHatsTransposed =
          SparkConverter.convertSparkRowMatrixToRealMatrix(
              betaHatsDist, sparkNormalizedCounts.getColumnDimension());
      return new TangentNormalizationResult(
          sparkNormalized,
          sparkPreNormalized,
          betaHatsTransposed.transpose(),
          targetFactorNormalizedCounts);
    }

    // Local path: beta-hats for the input read count columns (samples).
    logger.info("Calculating beta hats...");
    final RealMatrix localBetaHats = pon.betaHats(preparedCounts, true, EPSILON);

    // Actual tangent normalization step.
    logger.info(
        "Performing actual tangent normalization ("
            + preparedCounts.getColumnDimension()
            + " columns)...");
    final RealMatrix localNormalizedCounts =
        pon.tangentNormalization(preparedCounts, localBetaHats, true);

    // Map the results back to the case's targets.
    logger.info("Post-processing tangent normalization results...");
    final ReadCountCollection localNormalized =
        caseToPoN.fromPoNtoCaseCountCollection(
            localNormalizedCounts, targetFactorNormalizedCounts.columnNames());
    final ReadCountCollection localPreNormalized =
        caseToPoN.fromPoNtoCaseCountCollection(
            preparedCounts, targetFactorNormalizedCounts.columnNames());

    return new TangentNormalizationResult(
        localNormalized, localPreNormalized, localBetaHats, targetFactorNormalizedCounts);
  }
示例#4
0
  /**
   * Runs one Kalman-filter step and returns the updated estimate.
   *
   * <p>With exactly two arguments a scalar filter is used: {@code data[0]} is the measured value
   * and {@code data[1]} the measurement noise, which is only read when the filter state is
   * (re)initialised, i.e. while {@code prevEstimatedValue == 0}. Otherwise a two-component matrix
   * filter is used: {@code data[0]} is the measured value, {@code data[1]} its measured changing
   * rate, {@code data[2]} the measurement noise and {@code data[3]} the timestamp.
   *
   * @param data the function arguments; elements 0-2 are cast to {@code Double} and element 3 to
   *     {@code Long}
   * @return the updated estimate of the measured value
   * @throws ExecutionPlanRuntimeException if a required argument is {@code null}
   */
  @Override
  protected Object execute(Object[] data) {
    final String invalidInput = "Invalid input given to kf:kalmanFilter() function. ";

    if (data[0] == null) {
      throw new ExecutionPlanRuntimeException(invalidInput + "First argument should be a double");
    }
    if (data[1] == null) {
      throw new ExecutionPlanRuntimeException(invalidInput + "Second argument should be a double");
    }

    if (data.length == 2) {
      // Scalar filter.
      final double observed = (Double) data[0];
      if (prevEstimatedValue == 0) {
        // A zero estimate is treated as "not initialised yet": seed the state from this call.
        transition = 1;
        variance = 1000;
        measurementNoiseSD = (Double) data[1];
        prevEstimatedValue = observed;
      }
      // Predict, then correct with the gain, then shrink the variance.
      prevEstimatedValue = transition * prevEstimatedValue;
      final double gain = variance / (variance + measurementNoiseSD);
      prevEstimatedValue = prevEstimatedValue + gain * (observed - prevEstimatedValue);
      variance = (1 - gain) * variance;
      return prevEstimatedValue;
    }

    // Matrix filter: two more arguments are required.
    if (data[2] == null) {
      throw new ExecutionPlanRuntimeException(invalidInput + "Third argument should be a double");
    }
    if (data[3] == null) {
      throw new ExecutionPlanRuntimeException(invalidInput + "Fourth argument should be a long");
    }

    final double observedValue = (Double) data[0];
    final double observedRate = (Double) data[1];
    final double noise = (Double) data[2];
    final long currentTimestamp = (Long) data[3];
    final double[][] observation = {{observedValue}, {observedRate}};

    final long elapsed;
    if (measurementMatrixH != null) {
      elapsed = currentTimestamp - prevTimestamp;
    } else {
      // First call: H is the 2x2 identity, P starts at 1000 on the diagonal and the state is
      // seeded with the current measurement.
      elapsed = 1;
      final double[][] identity = {{1, 0}, {0, 1}};
      final double[][] initialVariance = {{1000, 0}, {0, 1000}};
      measurementMatrixH = MatrixUtils.createRealMatrix(identity);
      varianceMatrixP = MatrixUtils.createRealMatrix(initialVariance);
      prevMeasuredMatrix = MatrixUtils.createRealMatrix(observation);
    }

    final double[][] noiseDiagonal = {{noise, 0}, {0, noise}};
    final RealMatrix noiseR = MatrixUtils.createRealMatrix(noiseDiagonal);
    final double[][] transitionEntries = {{1d, elapsed}, {0d, 1d}};
    final RealMatrix transitionA = MatrixUtils.createRealMatrix(transitionEntries);
    final RealMatrix observationZ = MatrixUtils.createRealMatrix(observation);

    // Xk = A * Xk-1
    prevMeasuredMatrix = transitionA.multiply(prevMeasuredMatrix);

    // Pk = A * P * AT (no process-noise term is added)
    final RealMatrix transitionTransposed = transitionA.transpose();
    varianceMatrixP = transitionA.multiply(varianceMatrixP).multiply(transitionTransposed);

    // S = (H * P * HT) + R, and its inverse
    final RealMatrix innovationCovariance =
        measurementMatrixH
            .multiply(varianceMatrixP)
            .multiply(measurementMatrixH.transpose())
            .add(noiseR);
    final RealMatrix innovationCovarianceInverse =
        new LUDecomposition(innovationCovariance).getSolver().getInverse();

    // K = P * HT * S^-1
    final RealMatrix gainK =
        varianceMatrixP
            .multiply(measurementMatrixH.transpose())
            .multiply(innovationCovarianceInverse);

    // Xk = Xk + K * (Zk - H * Xk)
    final RealMatrix residual =
        observationZ.subtract(measurementMatrixH.multiply(prevMeasuredMatrix));
    prevMeasuredMatrix = prevMeasuredMatrix.add(gainK.multiply(residual));

    // Pk = Pk - K * H * Pk
    final RealMatrix correction = gainK.multiply(measurementMatrixH).multiply(varianceMatrixP);
    varianceMatrixP = varianceMatrixP.subtract(correction);

    prevTimestamp = currentTimestamp;
    return prevMeasuredMatrix.getRow(0)[0];
  }
示例#5
0
  /**
   * Computes {@code sqrt(1 - lambda)} for one output component, where {@code lambda} compares the
   * scatter of the inputs within each distinct output value (matrix E) with the scatter between
   * the output values (matrix H).
   *
   * <p>At most {@code maxSamples} samples are read from the iterator; noise is injected into each
   * input vector through {@code injectNoise} before it is accumulated. Columns reported by
   * {@code findZeroColumns(E)} are removed from both E and H. {@code lambda} is then either the
   * product of {@code 1 / (1 + eig)} over the largest eigenvalues of {@code E^-1 H} (when
   * {@code useEigenvalues} is set) or {@code det(E) / det(E + H)}.
   *
   * @param it source of samples
   * @param inputDim length of each encoded input vector
   * @param out index of the encoded output component whose distinct values define the groups
   * @return {@code sqrt(1 - lambda)}, or {@code -1} if E (or E + H) is singular or one of the
   *     eigenvalues used is below {@code zeroThreshold} in absolute value
   */
  private double generalizedCorrelationRatio(SampleIterator it, int inputDim, int out) {
    Map<Double, Integer> n_y = new HashMap<>();
    Map<Double, MultivariateSummaryStatistics> stat_y = new HashMap<>();
    List<RealMatrix> x = new ArrayList<>();
    MultivariateSummaryStatistics stat = new MultivariateSummaryStatistics(inputDim, unbiased);

    // Accumulate per-group and overall statistics, keeping every (noisy) input column vector.
    for (int i = 0; i < maxSamples && it.hasNext(); i++) {
      Sample sample = it.next();
      double[] input = sample.getEncodedInput().toArray();
      double output = sample.getEncodedOutput().getEntry(out);
      if (!n_y.containsKey(output)) {
        n_y.put(output, 0);
        stat_y.put(output, new MultivariateSummaryStatistics(inputDim, unbiased));
      }

      injectNoise(input);
      n_y.put(output, n_y.get(output) + 1);
      stat_y.get(output).addValue(input);
      x.add(new Array2DRowRealMatrix(input));
      stat.addValue(input);
    }

    // Column vectors holding the overall sum and the per-group sums of the inputs.
    RealMatrix x_sum = new Array2DRowRealMatrix(stat.getSum());
    Map<Double, RealMatrix> x_y_sum = new HashMap<>();
    for (Entry<Double, MultivariateSummaryStatistics> entry : stat_y.entrySet()) {
      x_y_sum.put(entry.getKey(), new Array2DRowRealMatrix(entry.getValue().getSum()));
    }

    // temp = sum over groups of (s_y * s_y^T) / n_y
    RealMatrix temp = new Array2DRowRealMatrix(inputDim, inputDim);
    for (Entry<Double, Integer> entry : n_y.entrySet()) {
      RealMatrix groupSum = x_y_sum.get(entry.getKey());
      temp =
          temp.add(
              groupSum.multiply(groupSum.transpose()).scalarMultiply(1.0 / entry.getValue()));
    }

    // H = temp - (s * s^T) / N
    RealMatrix H = temp.subtract(x_sum.multiply(x_sum.transpose()).scalarMultiply(1.0 / x.size()));

    // E = sum over samples of (x * x^T) - temp
    RealMatrix E = new Array2DRowRealMatrix(inputDim, inputDim);
    for (RealMatrix m : x) {
      E = E.add(m.multiply(m.transpose()));
    }
    E = E.subtract(temp);

    List<Integer> zeroColumns = findZeroColumns(E);
    E = removeZeroColumns(E, zeroColumns);
    H = removeZeroColumns(H, zeroColumns);

    Matrix JE = new Matrix(E.getData());
    Matrix JH = new Matrix(H.getData());

    if (JE.rank() < JE.getRowDimension()) {
      Log.write(this, "Some error occurred (E matrix is singular)");
      return -1;
    }

    double lambda;
    if (useEigenvalues) {
      Matrix L = JE.inverse().times(JH);
      double[] eigs = L.eig().getRealEigenvalues();
      Arrays.sort(eigs);

      lambda = 1;
      // Use the (number of groups - 1) largest eigenvalues, but never more than exist: with more
      // groups than retained dimensions the start index would otherwise be negative and the
      // array access below would throw ArrayIndexOutOfBoundsException.
      int nonNullEigs = Math.min(n_y.size() - 1, eigs.length);
      for (int i = eigs.length - nonNullEigs; i < eigs.length; i++) {
        if (Math.abs(eigs[i]) < zeroThreshold) {
          Log.write(this, "Some error occurred (E matrix has too many null eigenvalues)");
          return -1;
        }
        lambda *= 1.0 / (1.0 + eigs[i]);
      }
    } else {
      Matrix sum = JE.plus(JH);
      if (sum.rank() < sum.getRowDimension()) {
        Log.write(this, "Some error occurred (E+H is singular)");
        return -1;
      }
      lambda = JE.det() / sum.det();
    }

    return Math.sqrt(1 - lambda);
  }