  /**
   * Process one input sample. This method is called by outer loop code outside the nupic-engine. We
   * use this instead of the nupic-engine compute() because our inputs and outputs aren't fixed-size
   * vectors of reals.
   *
   * @param recordNum Record number of this input pattern. Record numbers should normally increase
   *     sequentially by 1 each time unless there are missing records in the dataset. Knowing this
   *     information ensures that we don't get confused by missing records.
   * @param classification {@link Map} of the classification information: "bucketIdx" (index of the
   *     encoder bucket) and "actValue" (actual value going into the encoder)
   * @param patternNZ array of the active indices from the output below
   * @param learn if true, learn this sample
   * @param infer if true, perform inference
   * @return {@link Classification} containing the inference results. There is one entry for each
   *     step in steps, where the key is the number of steps and the value is an array containing
   *     the relative likelihood for each bucketIdx, starting from bucketIdx 0.
   *     <p>There is also an entry containing the average actual value to use for each bucket. The
   *     key is 'actualValues'.
   *     <p>for example: { 1 : [0.1, 0.3, 0.2, 0.7], 4 : [0.2, 0.4, 0.3, 0.5], 'actualValues': [1.5,
   *     3.5, 5.5, 7.6], }
   */
  @SuppressWarnings("unchecked")
  public <T> Classification<T> compute(
      int recordNum,
      Map<String, Object> classification,
      int[] patternNZ,
      boolean learn,
      boolean infer) {
    Classification<T> retVal = new Classification<T>();
    List<T> actualValues = (List<T>) this.actualValues;

    // Save the offset between recordNum and learnIteration if this is the first
    // compute
    if (recordNumMinusLearnIteration == -1) {
      recordNumMinusLearnIteration = recordNum - learnIteration;
    }

    // Update the learn iteration
    learnIteration = recordNum - recordNumMinusLearnIteration;

    if (verbosity >= 1) {
      System.out.println(String.format("\n%s: compute ", g_debugPrefix));
      System.out.println(" recordNum: " + recordNum);
      System.out.println(" learnIteration: " + learnIteration);
      System.out.println(String.format(" patternNZ(%d): ", patternNZ.length, patternNZ));
      System.out.println(" classificationIn: " + classification);
    }

    patternNZHistory.append(new Tuple(learnIteration, patternNZ));

    // ------------------------------------------------------------------------
    // Inference:
    // For each active bit in the activationPattern, get the classification
    // votes
    //
    // Return value: for buckets which we don't have an actual value for yet,
    // just plug in any valid actual value. It doesn't matter what we use,
    // because that bucket won't have a non-zero likelihood anyway.
    if (infer) {
      // NOTE: If doing 0-step prediction, we shouldn't use any knowledge
      //       of the classification input during inference.
      Object defaultValue = null;
      if (steps.get(0) == 0) {
        defaultValue = 0;
      } else {
        defaultValue = classification.get("actValue");
      }

      T[] actValues = (T[]) new Object[this.actualValues.size()];
      for (int i = 0; i < actualValues.size(); i++) {
        actValues[i] = (T) (actualValues.get(i) == null ? defaultValue : actualValues.get(i));
      }

      retVal.setActualValues(actValues);

      // For each n-step prediction...
      for (int nSteps : steps.toArray()) {
        // Accumulate bucket index votes and actValues into these arrays
        double[] sumVotes = new double[maxBucketIdx + 1];
        double[] bitVotes = new double[maxBucketIdx + 1];

        for (int bit : patternNZ) {
          Tuple key = new Tuple(bit, nSteps);
          BitHistory history = activeBitHistory.get(key);
          if (history == null) continue;

          // Clear the per-bit votes before collecting this bit's contribution so that values
          // left over from the previous bit don't leak into the sum
          Arrays.fill(bitVotes, 0.0);
          history.infer(learnIteration, bitVotes);

          sumVotes = ArrayUtils.d_add(sumVotes, bitVotes);
        }

        // Return the votes for each bucket, normalized
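        // Illustrative numbers (not from the source): raw votes of [2.0, 1.0, 1.0] total 4.0
        // and normalize to likelihoods [0.5, 0.25, 0.25] below.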
        double total = ArrayUtils.sum(sumVotes);
        if (total > 0) {
          sumVotes = ArrayUtils.divide(sumVotes, total);
        } else {
          // If all buckets have zero probability then simply make all of the
          // buckets equally likely. There is no actual prediction for this
          // timestep so any of the possible predictions are just as good.
          if (sumVotes.length > 0) {
            Arrays.fill(sumVotes, 1.0 / (double) sumVotes.length);
          }
        }

        retVal.setStats(nSteps, sumVotes);
      }
    }

    // ------------------------------------------------------------------------
    // Learning:
    // For each active bit in the activationPattern, store the classification
    // info. If the bucketIdx is null, we can't learn. This can happen when the
    // field is missing in a specific record.
    if (learn && classification.get("bucketIdx") != null) {
      // Get classification info
      int bucketIdx = (int) classification.get("bucketIdx");
      Object actValue = classification.get("actValue");

      // Update maxBucketIndex
      maxBucketIdx = (int) Math.max(maxBucketIdx, bucketIdx);

      // Update rolling average of actual values if it's a scalar. If it's
      // not, it must be a category, in which case each bucket only ever
      // sees one category so we don't need a running average.
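      // Illustrative numbers (not from the source): with actValueAlpha = 0.3, a stored average
      // of 4.0 and a new actValue of 6.0, the update below yields 0.7 * 4.0 + 0.3 * 6.0 = 4.6.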
      while (maxBucketIdx > actualValues.size() - 1) {
        actualValues.add(null);
      }
      if (actualValues.get(bucketIdx) == null) {
        actualValues.set(bucketIdx, (T) actValue);
      } else {
        if (Number.class.isAssignableFrom(actValue.getClass())) {
          Double val =
              ((1.0 - actValueAlpha) * ((Number) actualValues.get(bucketIdx)).doubleValue()
                  + actValueAlpha * ((Number) actValue).doubleValue());
          actualValues.set(bucketIdx, (T) val);
        } else {
          actualValues.set(bucketIdx, (T) actValue);
        }
      }

      // Train each pattern in our history that aligns with one of the
      // configured prediction steps
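      // Illustrative example (not from the source): if learnIteration is 100 and nSteps is 1,
      // we look for the pattern stored at iteration 99 and train it on the current bucketIdx.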
      int nSteps = -1;
      int iteration = 0;
      int[] learnPatternNZ = null;
      for (int n : steps.toArray()) {
        nSteps = n;
        // Do we have the pattern that should be assigned to this classification
        // in our pattern history? If not, skip it
        boolean found = false;
        for (Tuple t : patternNZHistory) {
          iteration = (int) t.get(0);
          learnPatternNZ = (int[]) t.get(1);
          if (iteration == learnIteration - nSteps) {
            found = true;
            break;
          }
        }
        if (!found) continue;

        // Store classification info for each active bit from the pattern
        // that we got nSteps time steps ago.
        for (int bit : learnPatternNZ) {
          // Get the history structure for this bit and step
          Tuple key = new Tuple(bit, nSteps);
          BitHistory history = activeBitHistory.get(key);
          if (history == null) {
            activeBitHistory.put(key, history = new BitHistory(this, bit, nSteps));
          }
          history.store(learnIteration, bucketIdx);
        }
      }
    }

    if (infer && verbosity >= 1) {
      System.out.println(" inference: combined bucket likelihoods:");
      System.out.println(
          "   actual bucket values: " + Arrays.toString((T[]) retVal.getActualValues()));

      for (int key : retVal.stepSet()) {
        if (retVal.getActualValue(key) == null) continue;

        Object[] actual = new Object[] {(T) retVal.getActualValue(key)};
        System.out.println(String.format("  %d steps: ", key, pFormatArray(actual)));
        int bestBucketIdx = retVal.getMostProbableBucketIndex(key);
        System.out.println(
            String.format(
                "   most likely bucket idx: %d, value: %s ",
                bestBucketIdx, retVal.getActualValue(bestBucketIdx)));
      }
    }

    return retVal;
  }
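
  /*
   * Illustrative usage sketch (added for this example; not part of the original source). It shows
   * how an outer loop might feed one sample through compute() and read back the predictions. The
   * record number, active indices and actual value are made-up data, and this helper method itself
   * is hypothetical.
   */
  @SuppressWarnings("unused")
  private void exampleComputeUsage() {
    // Classification input for this record: which encoder bucket fired and the raw value
    Map<String, Object> classificationIn = new HashMap<String, Object>();
    classificationIn.put("bucketIdx", 4);
    classificationIn.put("actValue", 34.7);

    // Learn from and infer on one pattern (the indices of the active bits)
    Classification<Double> result =
        this.<Double>compute(0, classificationIn, new int[] {1, 5, 9}, true, true);

    // For each configured prediction horizon, look up the most likely bucket and its value
    for (int step : result.stepSet()) {
      int bestBucket = result.getMostProbableBucketIndex(step);
      System.out.println(step + "-step prediction: " + result.getActualValue(bestBucket));
    }
  }
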
  @Override
  public CLAClassifier deserialize(JsonParser jp, DeserializationContext ctxt)
      throws IOException, JsonProcessingException {

    ObjectCodec oc = jp.getCodec();
    JsonNode node = oc.readTree(jp);

    CLAClassifier retVal = new CLAClassifier();
    retVal.alpha = node.get("alpha").asDouble();
    retVal.actValueAlpha = node.get("actValueAlpha").asDouble();
    retVal.learnIteration = node.get("learnIteration").asInt();
    retVal.recordNumMinusLearnIteration = node.get("recordNumMinusLearnIteration").asInt();
    retVal.maxBucketIdx = node.get("maxBucketIdx").asInt();

    String[] steps = node.get("steps").asText().split(",");
    TIntList t = new TIntArrayList();
    for (String step : steps) {
      t.add(Integer.parseInt(step));
    }
    retVal.steps = t;

    String[] tupleStrs = node.get("patternNZHistory").asText().split(";");
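    // Illustrative serialized form, inferred from the parsing below (not confirmed by the
    // source): "12-[0, 5, 9];13-[1, 6, 10]" -- one "iteration-[active indices]" entry per ';'.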
    Deque<Tuple> patterns = new Deque<Tuple>(tupleStrs.length);
    for (String tupleStr : tupleStrs) {
      String[] tupleParts = tupleStr.split("-");
      int iteration = Integer.parseInt(tupleParts[0]);
      String pattern = tupleParts[1].substring(1, tupleParts[1].indexOf("]")).trim();
      String[] indexes = pattern.split(",");
      int[] indices = new int[indexes.length];
      for (int i = 0; i < indices.length; i++) {
        indices[i] = Integer.parseInt(indexes[i].trim());
      }
      Tuple tup = new Tuple(iteration, indices);
      patterns.append(tup);
    }
    retVal.patternNZHistory = patterns;

    Map<Tuple, BitHistory> bitHistoryMap = new HashMap<Tuple, BitHistory>();
    String[] bithists = node.get("activeBitHistory").asText().split(";");
    for (String bh : bithists) {
      String[] parts = bh.split("-");

      String[] left = parts[0].split(",");
      Tuple key =
          new Tuple(Integer.parseInt(left[0].trim()), Integer.parseInt(left[1].trim()));

      BitHistory bitHistory = new BitHistory();
      String[] right = parts[1].split("=");
      bitHistory.id = right[0].trim();

      TDoubleList dubs = new TDoubleArrayList();
      String[] stats = right[1].substring(1, right[1].indexOf("}")).trim().split(",");
      for (int i = 0; i < stats.length; i++) {
        dubs.add(Double.parseDouble(stats[i].trim()));
      }
      bitHistory.stats = dubs;

      bitHistory.lastTotalUpdate = Integer.parseInt(right[2].trim());

      bitHistoryMap.put(key, bitHistory);
    }
    retVal.activeBitHistory = bitHistoryMap;

    ArrayNode jn = (ArrayNode) node.get("actualValues");
    List<Object> l = new ArrayList<Object>();
    for (int i = 0; i < jn.size(); i++) {
      JsonNode n = jn.get(i);
      try {
        double d = Double.parseDouble(n.asText().trim());
        l.add(d);
      } catch (NumberFormatException e) {
        l.add(n.asText().trim());
      }
    }
    retVal.actualValues = l;

    // Go back and set the classifier on the BitHistory objects
    for (Tuple tuple : bitHistoryMap.keySet()) {
      bitHistoryMap.get(tuple).classifier = retVal;
    }

    return retVal;
  }
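
  /*
   * Illustrative wiring sketch (added for this example; not part of the original source). It
   * assumes this deserialize() logic is exposed through a Jackson JsonDeserializer; the
   * CLAClassifierDeserializer name below is hypothetical:
   *
   *   SimpleModule module = new SimpleModule();
   *   module.addDeserializer(CLAClassifier.class, new CLAClassifierDeserializer());
   *   ObjectMapper mapper = new ObjectMapper();
   *   mapper.registerModule(module);
   *   CLAClassifier restored = mapper.readValue(json, CLAClassifier.class);
   */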