コード例 #1
0
    /**
     * Averages the per-mapper contributions for one key and forwards every hidden chain.
     *
     * <p>Each input list is read as: index 0 = weight contribution, 1 = hidden-bias contribution,
     * 2 = hidden chain, 3 = visible-bias contribution. A single record is written for {@code key},
     * laid out as {@code [mean weights, mean hbias, mean vbias, chain_1, ..., chain_n]}, with the
     * chains in the order the inputs were iterated.
     *
     * @param key reduce key, written back unchanged
     * @param inputs contributions emitted for {@code key}
     * @param context receives the single output record
     * @throws IllegalStateException if {@code inputs} yields no elements, since there is nothing
     *     to average
     */
    @Override
    public void reduce(
        Text key, Iterable<jBLASArrayWritable> inputs, WeightContributions.Context context)
        throws IOException, InterruptedException {
      DoubleMatrix weights = null;
      DoubleMatrix hbias = null;
      DoubleMatrix vbias = null;

      // save list of all hidden chains for updates to batch files in phase 3
      ArrayList<DoubleMatrix> chainList = new ArrayList<DoubleMatrix>();
      int count = 0;

      for (jBLASArrayWritable input : inputs) {
        ArrayList<DoubleMatrix> data = input.getData();
        DoubleMatrix wCont = data.get(0);
        DoubleMatrix hbCont = data.get(1);
        DoubleMatrix vbCont = data.get(3);

        // Deep-copy the chain: it is kept beyond this iteration, and Hadoop presumably reuses the
        // value object across iterations of `inputs`. The contributions are consumed immediately
        // by addi below, so they need no copy.
        chainList.add(new DoubleMatrix(data.get(2).toArray2()));

        if (weights == null) {
          // size the accumulators from the first contribution
          weights = DoubleMatrix.zeros(wCont.rows, wCont.columns);
          hbias = DoubleMatrix.zeros(hbCont.rows, hbCont.columns);
          vbias = DoubleMatrix.zeros(vbCont.rows, vbCont.columns);
        }

        // sum weight contributions
        weights.addi(wCont);
        hbias.addi(hbCont);
        vbias.addi(vbCont);
        count++;
      }

      if (count == 0) {
        throw new IllegalStateException("no weight contributions to average for key " + key);
      }

      // The accumulators are local to this call, so divide in place rather than allocating.
      ArrayList<DoubleMatrix> output_array = new ArrayList<DoubleMatrix>(3 + chainList.size());
      output_array.add(weights.divi(count));
      output_array.add(hbias.divi(count));
      output_array.add(vbias.divi(count));
      output_array.addAll(chainList);

      context.write(key, new jBLASArrayWritable(output_array));
    }
コード例 #2
0
 /**
  * Samples hidden units conditioned on the visible units {@code v0}, also exposing the
  * pre-sampling activations.
  *
  * @param v0 visible-unit values to condition on
  * @param phase overwritten with a copy of {@code propup(v0)}
  * @return the result of {@code MatrixMath.binom} applied to {@code propup(v0)}
  */
 public DoubleMatrix sample_h_from_v(DoubleMatrix v0, DoubleMatrix phase) {
   final DoubleMatrix upward = propup(v0);
   // Snapshot before sampling, in case binom modifies its argument in place.
   phase.copy(upward);
   final DoubleMatrix sampled = MatrixMath.binom(upward);
   return sampled;
 }
コード例 #3
0
    /**
     * Runs one persistent contrastive divergence step for a single batch and emits the result.
     *
     * <p>Input layout ({@code input.getData()}): 0 = weights, 1 = hidden bias, 2 = persistent
     * hidden chain (null or empty when no chain exists yet), 3 = visible bias, 4 = label,
     * 5 = visible data. Optionally, entries from index 6 onward are {@code (weights, hbias, vbias)}
     * triples for the layers beneath this one; the data is propagated up through them to form
     * this layer's input.
     *
     * <p>The same list is written back with entries 0, 1 and 3 overwritten by {@code w1},
     * {@code hb1} and {@code vb1} (presumably filled in by {@code weight_contribution}), and entry
     * 2 holding the chain's hidden state after {@code gibbsSteps} Gibbs steps.
     *
     * @param key batch key, written back unchanged
     * @param input batch data and current model parameters, laid out as above
     * @param context receives the single output record
     */
    // TODO: DOUBLECHECK EVERYTHING
    @Override
    public void map(Text key, jBLASArrayWritable input, GibbsSampling.Context context)
        throws IOException, InterruptedException {
      long start_time = System.nanoTime();
      ArrayList<DoubleMatrix> data = input.getData();
      label = data.get(4);
      v_data = data.get(5);

      // Propagate the observed data up through any lower layers, one (weights, hbias, vbias)
      // triple at a time. When no complete triple is present, v_data is left unchanged.
      int prelayer = (data.size() - 6) / 3;
      DoubleMatrix below = v_data;
      for (int i = 0; i < prelayer; i++) {
        int base = 6 + i * 3;
        weights = data.get(base);
        hbias = data.get(base + 1);
        vbias = data.get(base + 2);
        below = sample_h_from_v(below);
      }
      v_data = below;

      weights = data.get(0);
      hbias = data.get(1);
      hiddenChain = data.get(2);
      vbias = data.get(3);

      // check if we need to attach labels to the observed variables
      if (vbias.columns != v_data.columns) {
        // one-hot encode the class label and append it to the visible units
        DoubleMatrix labels = DoubleMatrix.zeros(1, classCount);
        labels.put((int) label.get(0), 1.0);
        v_data = DoubleMatrix.concatHorizontally(v_data, labels);
      }

      w1 = DoubleMatrix.zeros(weights.rows, weights.columns);
      hb1 = DoubleMatrix.zeros(hbias.rows, hbias.columns);
      vb1 = DoubleMatrix.zeros(vbias.rows, vbias.columns);

      // Positive phase: hidden sample conditioned on the current data. It is needed again after
      // the Gibbs chain runs, so take a private copy in case the sampler hands back a reused
      // buffer (its implementation is not visible here).
      DoubleMatrix phaseSample = sample_h_from_v(v_data).dup();
      h1_data = new DoubleMatrix();
      v1_data = new DoubleMatrix();

      // PCD: continue the persistent chain, or start it from the positive phase if none exists.
      if (hiddenChain == null || hiddenChain.isEmpty()) {
        hiddenChain = phaseSample.dup();
        data.set(2, hiddenChain);
      }
      h1_data.copy(hiddenChain);

      // run Gibbs chain for k steps
      // NOTE(review): if gibbsSteps == 0, v1_data is still empty when propup runs below;
      // confirm gibbsSteps >= 1 is guaranteed.
      for (int j = 0; j < gibbsSteps; j++) {
        v1_data.copy(sample_v_from_h(h1_data));
        h1_data.copy(sample_h_from_v(v1_data));
      }
      DoubleMatrix hprob = propup(v1_data);
      // The positive statistics must pair v_data with its own hidden sample; the persistent chain
      // supplies only the negative statistics. Passing hiddenChain here instead would use the
      // previous step's chain state as the positive phase whenever a chain already existed.
      weight_contribution(phaseSample, v_data, hprob, v1_data);
      // persist the chain's latest hidden state for the next PCD step
      hiddenChain.copy(h1_data);

      data.get(0).copy(w1);
      data.get(1).copy(hb1);
      data.get(2).copy(hiddenChain);
      data.get(3).copy(vb1);

      context.write(key, new jBLASArrayWritable(data));
      log.info("Job completed in: " + (System.nanoTime() - start_time) / (1E6) + " ms");
    }