/** * Process one input sample. This method is called by outer loop code outside the nupic-engine. We * use this instead of the nupic engine compute() because our inputs and outputs aren't fixed size * vectors of reals. * * @param recordNum Record number of this input pattern. Record numbers should normally increase * sequentially by 1 each time unless there are missing records in the dataset. Knowing this * information insures that we don't get confused by missing records. * @param classification {@link Map} of the classification information: bucketIdx: index of the * encoder bucket actValue: actual value going into the encoder * @param patternNZ list of the active indices from the output below * @param learn if true, learn this sample * @param infer if true, perform inference * @return dict containing inference results, there is one entry for each step in steps, where the * key is the number of steps, and the value is an array containing the relative likelihood * for each bucketIdx starting from bucketIdx 0. * <p>There is also an entry containing the average actual value to use for each bucket. The * key is 'actualValues'. * <p>for example: { 1 : [0.1, 0.3, 0.2, 0.7], 4 : [0.2, 0.4, 0.3, 0.5], 'actualValues': [1.5, * 3,5, 5,5, 7.6], } */ @SuppressWarnings("unchecked") public <T> Classification<T> compute( int recordNum, Map<String, Object> classification, int[] patternNZ, boolean learn, boolean infer) { Classification<T> retVal = new Classification<T>(); List<T> actualValues = (List<T>) this.actualValues; // Save the offset between recordNum and learnIteration if this is the first // compute if (recordNumMinusLearnIteration == -1) { recordNumMinusLearnIteration = recordNum - learnIteration; } // Update the learn iteration learnIteration = recordNum - recordNumMinusLearnIteration; if (verbosity >= 1) { System.out.println(String.format("\n%s: compute ", g_debugPrefix)); System.out.println(" recordNum: " + recordNum); System.out.println(" learnIteration: " + learnIteration); System.out.println(String.format(" patternNZ(%d): ", patternNZ.length, patternNZ)); System.out.println(" classificationIn: " + classification); } patternNZHistory.append(new Tuple(learnIteration, patternNZ)); // ------------------------------------------------------------------------ // Inference: // For each active bit in the activationPattern, get the classification // votes // // Return value dict. For buckets which we don't have an actual value // for yet, just plug in any valid actual value. It doesn't matter what // we use because that bucket won't have non-zero likelihood anyways. if (infer) { // NOTE: If doing 0-step prediction, we shouldn't use any knowledge // of the classification input during inference. Object defaultValue = null; if (steps.get(0) == 0) { defaultValue = 0; } else { defaultValue = classification.get("actValue"); } T[] actValues = (T[]) new Object[this.actualValues.size()]; for (int i = 0; i < actualValues.size(); i++) { actValues[i] = (T) (actualValues.get(i) == null ? defaultValue : actualValues.get(i)); } retVal.setActualValues(actValues); // For each n-step prediction... for (int nSteps : steps.toArray()) { // Accumulate bucket index votes and actValues into these arrays double[] sumVotes = new double[maxBucketIdx + 1]; double[] bitVotes = new double[maxBucketIdx + 1]; for (int bit : patternNZ) { Tuple key = new Tuple(bit, nSteps); BitHistory history = activeBitHistory.get(key); if (history == null) continue; history.infer(learnIteration, bitVotes); sumVotes = ArrayUtils.d_add(sumVotes, bitVotes); } // Return the votes for each bucket, normalized double total = ArrayUtils.sum(sumVotes); if (total > 0) { sumVotes = ArrayUtils.divide(sumVotes, total); } else { // If all buckets have zero probability then simply make all of the // buckets equally likely. There is no actual prediction for this // timestep so any of the possible predictions are just as good. if (sumVotes.length > 0) { Arrays.fill(sumVotes, 1.0 / (double) sumVotes.length); } } retVal.setStats(nSteps, sumVotes); } } // ------------------------------------------------------------------------ // Learning: // For each active bit in the activationPattern, store the classification // info. If the bucketIdx is None, we can't learn. This can happen when the // field is missing in a specific record. if (learn && classification.get("bucketIdx") != null) { // Get classification info int bucketIdx = (int) classification.get("bucketIdx"); Object actValue = classification.get("actValue"); // Update maxBucketIndex maxBucketIdx = (int) Math.max(maxBucketIdx, bucketIdx); // Update rolling average of actual values if it's a scalar. If it's // not, it must be a category, in which case each bucket only ever // sees one category so we don't need a running average. while (maxBucketIdx > actualValues.size() - 1) { actualValues.add(null); } if (actualValues.get(bucketIdx) == null) { actualValues.set(bucketIdx, (T) actValue); } else { if (Number.class.isAssignableFrom(actValue.getClass())) { Double val = ((1.0 - actValueAlpha) * ((Number) actualValues.get(bucketIdx)).doubleValue() + actValueAlpha * ((Number) actValue).doubleValue()); actualValues.set(bucketIdx, (T) val); } else { actualValues.set(bucketIdx, (T) actValue); } } // Train each pattern that we have in our history that aligns with the // steps we have in steps int nSteps = -1; int iteration = 0; int[] learnPatternNZ = null; for (int n : steps.toArray()) { nSteps = n; // Do we have the pattern that should be assigned to this classification // in our pattern history? If not, skip it boolean found = false; for (Tuple t : patternNZHistory) { iteration = (int) t.get(0); learnPatternNZ = (int[]) t.get(1); if (iteration == learnIteration - nSteps) { found = true; break; } iteration++; } if (!found) continue; // Store classification info for each active bit from the pattern // that we got nSteps time steps ago. for (int bit : learnPatternNZ) { // Get the history structure for this bit and step Tuple key = new Tuple(bit, nSteps); BitHistory history = activeBitHistory.get(key); if (history == null) { activeBitHistory.put(key, history = new BitHistory(this, bit, nSteps)); } history.store(learnIteration, bucketIdx); } } } if (infer && verbosity >= 1) { System.out.println(" inference: combined bucket likelihoods:"); System.out.println( " actual bucket values: " + Arrays.toString((T[]) retVal.getActualValues())); for (int key : retVal.stepSet()) { if (retVal.getActualValue(key) == null) continue; Object[] actual = new Object[] {(T) retVal.getActualValue(key)}; System.out.println(String.format(" %d steps: ", key, pFormatArray(actual))); int bestBucketIdx = retVal.getMostProbableBucketIndex(key); System.out.println( String.format( " most likely bucket idx: %d, value: %s ", bestBucketIdx, retVal.getActualValue(bestBucketIdx))); } } return retVal; }
@Override public CLAClassifier deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException, JsonProcessingException { ObjectCodec oc = jp.getCodec(); JsonNode node = oc.readTree(jp); CLAClassifier retVal = new CLAClassifier(); retVal.alpha = node.get("alpha").asDouble(); retVal.actValueAlpha = node.get("actValueAlpha").asDouble(); retVal.learnIteration = node.get("learnIteration").asInt(); retVal.recordNumMinusLearnIteration = node.get("recordNumMinusLearnIteration").asInt(); retVal.maxBucketIdx = node.get("maxBucketIdx").asInt(); String[] steps = node.get("steps").asText().split(","); TIntList t = new TIntArrayList(); for (String step : steps) { t.add(Integer.parseInt(step)); } retVal.steps = t; String[] tupleStrs = node.get("patternNZHistory").asText().split(";"); Deque<Tuple> patterns = new Deque<Tuple>(tupleStrs.length); for (String tupleStr : tupleStrs) { String[] tupleParts = tupleStr.split("-"); int iteration = Integer.parseInt(tupleParts[0]); String pattern = tupleParts[1].substring(1, tupleParts[1].indexOf("]")).trim(); String[] indexes = pattern.split(","); int[] indices = new int[indexes.length]; for (int i = 0; i < indices.length; i++) { indices[i] = Integer.parseInt(indexes[i].trim()); } Tuple tup = new Tuple(iteration, indices); patterns.append(tup); } retVal.patternNZHistory = patterns; Map<Tuple, BitHistory> bitHistoryMap = new HashMap<Tuple, BitHistory>(); String[] bithists = node.get("activeBitHistory").asText().split(";"); for (String bh : bithists) { String[] parts = bh.split("-"); String[] left = parts[0].split(","); Tuple iteration = new Tuple(Integer.parseInt(left[0].trim()), Integer.parseInt(left[1].trim())); BitHistory bitHistory = new BitHistory(); String[] right = parts[1].split("="); bitHistory.id = right[0].trim(); TDoubleList dubs = new TDoubleArrayList(); String[] stats = right[1].substring(1, right[1].indexOf("}")).trim().split(","); for (int i = 0; i < stats.length; i++) { dubs.add(Double.parseDouble(stats[i].trim())); } bitHistory.stats = dubs; bitHistory.lastTotalUpdate = Integer.parseInt(right[2].trim()); bitHistoryMap.put(iteration, bitHistory); } retVal.activeBitHistory = bitHistoryMap; ArrayNode jn = (ArrayNode) node.get("actualValues"); List<Object> l = new ArrayList<Object>(); for (int i = 0; i < jn.size(); i++) { JsonNode n = jn.get(i); try { double d = Double.parseDouble(n.asText().trim()); l.add(d); } catch (Exception e) { l.add(n.asText().trim()); } } retVal.actualValues = l; // Go back and set the classifier on the BitHistory objects for (Tuple tuple : bitHistoryMap.keySet()) { bitHistoryMap.get(tuple).classifier = retVal; } return retVal; }