/**
 * Image compression
 *
 * @throws IOException
 */
public static void part2() throws IOException {
  tic();
  logger.info("Loading data...\n");
  String path = COURSE_ML_PATH + "/ex7/bird_small.png";
  MatrixImage mi = ImageLoader.load(path);
  FloatMatrix X = mi.getMatrix();

  logger.info("Initializing model...\n");
  KMeans kMeans = new KMeans(X, 16);

  logger.info("Running training...\n");
  FloatMatrix centroids = kMeans.run(10);
  logger.info("Done.\nCluster centroids:\n{}\n", centroids);

  logger.info("Comparing images...\n");
  FloatMatrix indices = kMeans.findClosestCentroids();
  int[] index = indices.toIntArray();
  FloatMatrix result = centroids.getRows(index);
  ImShow.show(mi);
  mi.setMatrix(result);
  ImShow.show(mi);
  toc();
}
/**
 * Strips the dataset down to the specified labels and remaps them
 *
 * @param labels the labels to strip down to
 */
public void filterAndStrip(int[] labels) {
  FloatDataSet filtered = filterBy(labels);
  List<Integer> newLabels = new ArrayList<>();

  // map each original label to its index in the passed-in labels
  Map<Integer, Integer> labelMap = new HashMap<>();
  for (int i = 0; i < labels.length; i++) labelMap.put(labels[i], i);

  // remap the outcome of each example; guard against labels outside the filter
  // set, which would otherwise throw an opaque NullPointerException on unboxing
  for (int i = 0; i < filtered.numExamples(); i++) {
    Integer outcome = labelMap.get(filtered.get(i).outcome());
    if (outcome == null)
      throw new IllegalStateException("Example " + i + " has a label outside the filter set");
    newLabels.add(outcome);
  }

  FloatMatrix newLabelMatrix = new FloatMatrix(filtered.numExamples(), labels.length);
  if (newLabelMatrix.rows != newLabels.size())
    throw new IllegalStateException("Inconsistent label sizes");

  for (int i = 0; i < newLabelMatrix.rows; i++) {
    FloatMatrix newRow = MatrixUtil.toOutcomeVectorFloat(newLabels.get(i), labels.length);
    newLabelMatrix.putRow(i, newRow);
  }

  setFirst(filtered.getFirst());
  setSecond(newLabelMatrix);
}
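/*
 * A minimal usage sketch for filterAndStrip (hypothetical values; relies only on the
 * FloatDataSet(FloatMatrix, FloatMatrix) constructor used elsewhere in this class).
 * Two of the three examples carry the kept labels {0, 2}, which get remapped to the
 * new outcome indices 0 and 1.
 */
public static void filterAndStripExample() {
  FloatMatrix features = new FloatMatrix(new float[][] {{1f, 2f}, {3f, 4f}, {5f, 6f}});
  // one-hot labels over three classes: the examples are labeled 0, 1 and 2
  FloatMatrix labels =
      new FloatMatrix(new float[][] {{1f, 0f, 0f}, {0f, 1f, 0f}, {0f, 0f, 1f}});
  FloatDataSet data = new FloatDataSet(features, labels);
  data.filterAndStrip(new int[] {0, 2});
  // data now holds two examples and a 2-column one-hot label matrix:
  // old label 0 -> [1, 0], old label 2 -> [0, 1]
}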
@Override
public FloatMatrix getInitTheta() {
  FloatMatrix x = FloatMatrix.rand(rows, features);
  FloatMatrix theta = FloatMatrix.rand(columns, features);
  return MatrixUtil.merge(x.data, theta.data);
}
/** Returns matrices of the right size for either binary or unary (terminal) classification */
FloatMatrix randomClassificationMatrix() {
  // Leave the bias column with 0 values
  float range = 1.0f / (float) Math.sqrt(numHidden);
  FloatMatrix ret = FloatMatrix.zeros(numOuts, numHidden + 1);
  FloatMatrix insert = MatrixUtil.rand(numOuts, numHidden, -range, range, rng);
  ret.put(interval(0, numOuts), interval(0, numHidden), insert);
  return SimpleBlas.scal(scalingForInit, ret);
}
public void normalizeZeroMeanZeroUnitVariance() {
  FloatMatrix columnMeans = getFirst().columnMeans();
  FloatMatrix columnStds = MatrixUtil.columnStdDeviation(getFirst());

  setFirst(getFirst().subiRowVector(columnMeans));
  // add a small epsilon so constant columns do not divide by zero
  columnStds.addi(1e-6f);
  setFirst(getFirst().diviRowVector(columnStds));
}
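/*
 * A worked sketch of the normalization above (hypothetical values). Each column is
 * shifted by its mean and divided by its epsilon-stabilized standard deviation. For
 * the first column {1, 3} the mean is 2, so it becomes roughly {-1, 1} under a
 * population standard deviation, or about {-0.71, 0.71} with the sample estimator,
 * depending on how MatrixUtil.columnStdDeviation is defined.
 */
public static void normalizationExample() {
  FloatMatrix features = new FloatMatrix(new float[][] {{1f, 10f}, {3f, 30f}});
  FloatDataSet data = new FloatDataSet(features, FloatMatrix.zeros(2, 1));
  data.normalizeZeroMeanZeroUnitVariance();
  // each feature column of data.getFirst() is now approximately zero-mean, unit-variance
}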
public static FloatDataSet load(File path) throws IOException {
  // try-with-resources guarantees the stream is closed even if reading fails
  try (DataInputStream dis =
      new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
    FloatMatrix x = new FloatMatrix(1, 1);
    FloatMatrix y = new FloatMatrix(1, 1);
    x.in(dis);
    y.in(dis);
    return new FloatDataSet(x, y);
  }
}
FloatMatrix randomTransformMatrix() {
  FloatMatrix binary = new FloatMatrix(numHidden, numHidden * 2 + 1);
  // bias column values are initialized zero
  FloatMatrix block = randomTransformBlock();
  binary.put(interval(0, block.rows), interval(0, block.columns), block);
  binary.put(
      interval(0, block.rows),
      interval(numHidden, numHidden + block.columns),
      randomTransformBlock());
  return SimpleBlas.scal(scalingForInit, binary);
}
private FloatTensor getFloatTensorGradient(
    FloatMatrix deltaFull, FloatMatrix leftVector, FloatMatrix rightVector) {
  int size = deltaFull.length;
  FloatTensor Wt_df = new FloatTensor(size * 2, size * 2, size);
  // stack the two child column vectors into a single (2 * size) x 1 vector;
  // concatenating horizontally would yield size x 2 and break the outer product below
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  for (int slice = 0; slice < size; ++slice) {
    // scale a copy: SimpleBlas.scal works in place and would corrupt fullVector
    Wt_df.setSlice(
        slice,
        SimpleBlas.scal(deltaFull.get(slice), fullVector.dup()).mmul(fullVector.transpose()));
  }
  return Wt_df;
}
@Override
public FloatMatrix function(FloatMatrix m) {
  FloatMatrix md = m.dup();
  // operate on the duplicate's backing array directly; toArray() returns a copy,
  // so writes to it would be lost
  float[] data = md.data;
  float x;
  for (int i = 0; i < data.length; i++) {
    x = data[i] * 0.5f;
    data[i] = x / (float) Math.sqrt(x * x + 1.0) * 0.5f + 0.5f;
  }
  return md;
}
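/*
 * The loop above evaluates f(x) = 0.5 * (x/2) / sqrt((x/2)^2 + 1) + 0.5, a cheap
 * sigmoid-shaped squashing into (0, 1) that avoids exp(). A small sanity check
 * (hypothetical values):
 */
public static void fastSigmoidCheck() {
  float x = 0f;
  float half = x * 0.5f;
  float y = half / (float) Math.sqrt(half * half + 1.0) * 0.5f + 0.5f;
  // y == 0.5 exactly at x == 0; y tends to 1 as x -> +inf and to 0 as x -> -inf
}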
/**
 * Loads the Google word2vec binary model. Credit to:
 * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
 *
 * @param path path to model
 * @throws IOException
 */
public static Word2Vec loadGoogleModel(String path) throws IOException {
  Word2Vec ret = new Word2Vec();
  Index wordIndex = new Index();
  FloatMatrix wordVectors;
  try (BufferedInputStream bis =
          new BufferedInputStream(
              path.endsWith(".gz")
                  ? new GZIPInputStream(new FileInputStream(path))
                  : new FileInputStream(path));
      DataInputStream dis = new DataInputStream(bis)) {
    // header: number of words, then the vector size
    int words = Integer.parseInt(readString(dis));
    int size = Integer.parseInt(readString(dis));
    wordVectors = new FloatMatrix(words, size);
    for (int i = 0; i < words; i++) {
      String word = readString(dis);
      log.info("Loaded " + word);
      float[] vectors = new float[size];
      double len = 0;
      for (int j = 0; j < size; j++) {
        float vector = readFloat(dis);
        len += vector * vector;
        vectors[j] = vector;
      }
      // normalize each word vector to unit length
      len = Math.sqrt(len);
      for (int j = 0; j < size; j++) {
        vectors[j] /= len;
      }
      wordIndex.add(word);
      wordVectors.putRow(i, new FloatMatrix(vectors));
    }
  }
  ret.setWordIndex(wordIndex);
  ret.setSyn0(wordVectors);
  return ret;
}
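/*
 * Usage sketch (the path is hypothetical): loads and L2-normalizes every vector
 * from a word2vec binary file, gzipped or plain. Row i of the syn0 matrix is the
 * vector for the i-th word added to the index.
 */
public static void loadGoogleModelExample() throws IOException {
  Word2Vec vec = loadGoogleModel("/tmp/GoogleNews-vectors-negative300.bin.gz");
}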
/**
 * Given a sequence of Iterators over a set of matrices, fill in all of the matrices with the
 * entries in the theta vector. Errors are thrown if the theta vector does not exactly fill the
 * matrices.
 */
public void setParams(FloatMatrix theta, Iterator<? extends FloatMatrix>... matrices) {
  int index = 0;
  for (Iterator<? extends FloatMatrix> matrixIterator : matrices) {
    while (matrixIterator.hasNext()) {
      FloatMatrix matrix = matrixIterator.next();
      for (int i = 0; i < matrix.length; ++i) {
        matrix.put(i, theta.get(index));
        ++index;
      }
    }
  }
  if (index != theta.length) {
    throw new AssertionError("Did not entirely use the theta vector");
  }
}
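/*
 * Usage sketch for setParams (hypothetical shapes): a flat theta of length 6 exactly
 * fills a 2x2 matrix followed by a 1x2 matrix, in the linear entry order used by
 * FloatMatrix.put(int, float). A theta of any other length triggers the AssertionError.
 */
public void setParamsExample() {
  FloatMatrix a = new FloatMatrix(2, 2);
  FloatMatrix b = new FloatMatrix(1, 2);
  FloatMatrix theta = FloatMatrix.rand(6, 1);
  setParams(theta, Arrays.asList(a).iterator(), Arrays.asList(b).iterator());
  // a now holds theta entries 0..3 and b holds entries 4..5
}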
/**
 * Sample a dataset
 *
 * @param numSamples the number of samples to take
 * @param rng the rng to use
 * @param withReplacement whether to allow duplicates (only tracked by example row number)
 * @return the sampled dataset
 */
public FloatDataSet sample(int numSamples, RandomGenerator rng, boolean withReplacement) {
  if (numSamples >= numExamples()) return this;

  FloatMatrix examples = new FloatMatrix(numSamples, getFirst().columns);
  FloatMatrix outcomes = new FloatMatrix(numSamples, numOutcomes());
  Set<Integer> added = new HashSet<>();
  for (int i = 0; i < numSamples; i++) {
    int picked = rng.nextInt(numExamples());
    if (!withReplacement) {
      // redraw until we hit a row we have not used yet, then record it;
      // without the add, the duplicate check never triggers
      while (added.contains(picked)) {
        picked = rng.nextInt(numExamples());
      }
      added.add(picked);
    }
    examples.putRow(i, get(picked).getFirst());
    outcomes.putRow(i, get(picked).getSecond());
  }
  return new FloatDataSet(examples, outcomes);
}
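/*
 * Usage sketch for sample (hypothetical sizes; assumes a commons-math3
 * RandomGenerator such as the MersenneTwister used elsewhere in this code).
 */
public static void sampleExample(FloatDataSet data) {
  RandomGenerator rng = new MersenneTwister(42);
  FloatDataSet distinct = data.sample(10, rng, false); // 10 distinct example rows
  FloatDataSet withDuplicates = data.sample(10, rng, true); // duplicates permitted
}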
public static FloatDataSet merge(List<FloatDataSet> data) {
  if (data.isEmpty()) throw new IllegalArgumentException("Unable to merge empty dataset");
  FloatDataSet first = data.get(0);
  int numExamples = totalExamples(data);
  FloatMatrix in = new FloatMatrix(numExamples, first.getFirst().columns);
  FloatMatrix out = new FloatMatrix(numExamples, first.getSecond().columns);
  int count = 0;
  for (int i = 0; i < data.size(); i++) {
    FloatDataSet d1 = data.get(i);
    for (int j = 0; j < d1.numExamples(); j++) {
      FloatDataSet example = d1.get(j);
      in.putRow(count, example.getFirst());
      out.putRow(count, example.getSecond());
      count++;
    }
  }
  return new FloatDataSet(in, out);
}
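/*
 * Usage sketch for merge: stacking two mini-batches back into one dataset. Both
 * batches must agree on feature and outcome column counts (hypothetical inputs).
 */
public static FloatDataSet mergeExample(FloatDataSet batchA, FloatDataSet batchB) {
  FloatDataSet merged = FloatDataSet.merge(Arrays.asList(batchA, batchB));
  // merged.numExamples() == batchA.numExamples() + batchB.numExamples()
  return merged;
}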
@Override
protected void compute() {
  if (aEnd - aStart > minSize && bEnd - bStart > minSize) {
    // split the row range of a and the column range of b, then recurse on the
    // four resulting blocks
    final int aMiddle = aStart + (aEnd - aStart) / 2;
    final int bMiddle = bStart + (bEnd - bStart) / 2;
    invokeAll(
        new MulitplyPartly(a, b, result, aStart, aMiddle, bStart, bMiddle),
        new MulitplyPartly(a, b, result, aMiddle, aEnd, bMiddle, bEnd),
        new MulitplyPartly(a, b, result, aStart, aMiddle, bMiddle, bEnd),
        new MulitplyPartly(a, b, result, aMiddle, aEnd, bStart, bMiddle));
  } else {
    // base case: multiply a's row block against b's column block directly
    FloatMatrix x =
        a.get(new IntervalRange(aStart, aEnd), new IntervalRange(0, a.columns))
            .mmul(b.get(new IntervalRange(0, b.rows), new IntervalRange(bStart, bEnd)));
    result.put(x, aStart, bStart);
  }
}
public FloatMatrix normalizeRatings() {
  int[] indices;
  FloatMatrix yMean = FloatMatrix.zeros(rows, 1);
  FloatMatrix yNorm = FloatMatrix.zeros(rows, columns);
  for (int i = 0; i < rows; i++) {
    indices = r.getRow(i).eq(1).findIndices();
    yMean.put(i, y.getRow(i).get(indices).mean());
    // build the centered row first and write it back with putRow: getRow returns
    // a copy, so writing through it would never reach yNorm
    FloatMatrix row = FloatMatrix.zeros(1, columns);
    row.put(indices, y.getRow(i).get(indices).sub(yMean.get(i)));
    yNorm.putRow(i, row);
  }
  return yMean;
}
/**
 * This is the method to call for assigning labels and node vectors to the Tree. After calling
 * this, each of the non-leaf nodes will have the node vector and the predictions of their classes
 * assigned to that subtree's node.
 */
public void forwardPropagateTree(Tree tree) {
  FloatMatrix nodeVector;
  FloatMatrix classification;

  if (tree.isLeaf()) {
    // We do nothing for the leaves. The preterminals will
    // calculate the classification for this word/tag. In fact, the
    // recursion should not have gotten here (unless there are
    // degenerate trees of just one leaf)
    throw new AssertionError("We should not have reached leaves in forwardPropagate");
  } else if (tree.isPreTerminal()) {
    classification = getUnaryClassification(tree.label());
    String word = tree.children().get(0).value();
    FloatMatrix wordVector = getFeatureVector(word);
    if (wordVector == null) {
      wordVector = featureVectors.get(UNKNOWN_FEATURE);
    }
    nodeVector = activationFunction.apply(wordVector);
  } else if (tree.children().size() == 1) {
    throw new AssertionError("Non-preterminal nodes of size 1 should have already been collapsed");
  } else if (tree.children().size() == 2) {
    Tree left = tree.firstChild(), right = tree.lastChild();
    forwardPropagateTree(left);
    forwardPropagateTree(right);

    String leftCategory = tree.children().get(0).label();
    String rightCategory = tree.children().get(1).label();
    FloatMatrix W = getBinaryTransform(leftCategory, rightCategory);
    classification = getBinaryClassification(leftCategory, rightCategory);

    FloatMatrix leftVector = tree.children().get(0).vector();
    FloatMatrix rightVector = tree.children().get(1).vector();
    FloatMatrix childrenVector = appendBias(leftVector, rightVector);

    if (useFloatTensors) {
      FloatTensor floatT = getBinaryFloatTensor(leftCategory, rightCategory);
      // the tensor expects the two child vectors stacked into one column vector
      FloatMatrix floatTensorIn = FloatMatrix.concatVertically(leftVector, rightVector);
      FloatMatrix floatTensorOut = floatT.bilinearProducts(floatTensorIn);
      nodeVector = activationFunction.apply(W.mmul(childrenVector).add(floatTensorOut));
    } else {
      nodeVector = activationFunction.apply(W.mmul(childrenVector));
    }
  } else {
    throw new AssertionError("Tree not correctly binarized");
  }

  FloatMatrix inputWithBias = appendBias(nodeVector);
  FloatMatrix preAct = classification.mmul(inputWithBias);
  FloatMatrix predictions = outputActivation.apply(preAct);

  tree.setPrediction(predictions);
  tree.setVector(nodeVector);
}
private FloatMatrix computeFloatTensorDeltaDown(
    FloatMatrix deltaFull,
    FloatMatrix leftVector,
    FloatMatrix rightVector,
    FloatMatrix W,
    FloatTensor Wt) {
  FloatMatrix WTDelta = W.transpose().mmul(deltaFull);
  // WTDelta is a (2 * numHidden + 1) x 1 column vector; keep the first
  // 2 * numHidden rows and drop the bias row (the original intervals were swapped)
  FloatMatrix WTDeltaNoBias = WTDelta.get(interval(0, deltaFull.rows * 2), interval(0, 1));
  int size = deltaFull.length;
  FloatMatrix deltaFloatTensor = new FloatMatrix(size * 2, 1);
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  for (int slice = 0; slice < size; ++slice) {
    // scale a copy: SimpleBlas.scal works in place and would corrupt fullVector
    FloatMatrix scaledFullVector = SimpleBlas.scal(deltaFull.get(slice), fullVector.dup());
    deltaFloatTensor =
        deltaFloatTensor.add(
            Wt.getSlice(slice).add(Wt.getSlice(slice).transpose()).mmul(scaledFullVector));
  }
  return deltaFloatTensor.add(WTDeltaNoBias);
}
public FloatMatrix getValueGradient(int iterations) {
  // We use TreeMap for each of these so that they stay in a
  // canonical sorted order
  // TODO: factor out the initialization routines
  // binaryTD stands for Transform Derivatives
  final MultiDimensionalMap<String, String, FloatMatrix> binaryTD =
      MultiDimensionalMap.newTreeBackedMap();
  // the derivatives of the FloatTensors for the binary nodes
  final MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD =
      MultiDimensionalMap.newTreeBackedMap();
  // binaryCD stands for Classification Derivatives
  final MultiDimensionalMap<String, String, FloatMatrix> binaryCD =
      MultiDimensionalMap.newTreeBackedMap();
  // unaryCD stands for Classification Derivatives
  final Map<String, FloatMatrix> unaryCD = new TreeMap<>();
  // word vector derivatives
  final Map<String, FloatMatrix> wordVectorD = new TreeMap<>();

  for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry :
      binaryTransform.entrySet()) {
    int numRows = entry.getValue().rows;
    int numCols = entry.getValue().columns;
    binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(numRows, numCols));
  }

  if (!combineClassification) {
    for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry :
        binaryClassification.entrySet()) {
      int numRows = entry.getValue().rows;
      int numCols = entry.getValue().columns;
      binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(numRows, numCols));
    }
  }

  if (useFloatTensors) {
    for (MultiDimensionalMap.Entry<String, String, FloatTensor> entry :
        binaryFloatTensors.entrySet()) {
      int numRows = entry.getValue().rows();
      int numCols = entry.getValue().columns;
      int numSlices = entry.getValue().slices();
      binaryFloatTensorTD.put(
          entry.getFirstKey(), entry.getSecondKey(), new FloatTensor(numRows, numCols, numSlices));
    }
  }

  for (Map.Entry<String, FloatMatrix> entry : unaryClassification.entrySet()) {
    int numRows = entry.getValue().rows;
    int numCols = entry.getValue().columns;
    unaryCD.put(entry.getKey(), new FloatMatrix(numRows, numCols));
  }

  for (Map.Entry<String, FloatMatrix> entry : featureVectors.entrySet()) {
    int numRows = entry.getValue().rows;
    int numCols = entry.getValue().columns;
    wordVectorD.put(entry.getKey(), new FloatMatrix(numRows, numCols));
  }

  final List<Tree> forwardPropTrees = new CopyOnWriteArrayList<>();
  Parallelization.iterateInParallel(
      trainingTrees,
      new Parallelization.RunnableWithParams<Tree>() {
        public void run(Tree currentItem, Object[] args) {
          Tree trainingTree = new Tree(currentItem);
          trainingTree.connect(new ArrayList<>(currentItem.children()));
          // this will attach the error vectors and the node vectors
          // to each node in the tree
          forwardPropagateTree(trainingTree);
          forwardPropTrees.add(trainingTree);
        }
      },
      rnTnActorSystem);

  // TODO: we may find a big speedup by separating the derivatives and then summing
  final AtomicDouble error = new AtomicDouble(0);
  Parallelization.iterateInParallel(
      forwardPropTrees,
      new Parallelization.RunnableWithParams<Tree>() {
        public void run(Tree currentItem, Object[] args) {
          backpropDerivativesAndError(
              currentItem, binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD);
          error.addAndGet(currentItem.errorSum());
        }
      },
      new Parallelization.RunnableWithParams<Tree>() {
        public void run(Tree currentItem, Object[] args) {}
      },
      rnTnActorSystem,
      new Object[] {binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD});

  // scale the error by the number of sentences so that the
  // regularization isn't drowned out for large training batches
  float scale = 1.0f / trainingTrees.size();
  value = error.floatValue() * scale;

  value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix);
  value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification);
  value +=
      scaleAndRegularizeFloatTensor(
          binaryFloatTensorTD, binaryFloatTensors, scale, regTransformFloatTensor);
  value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification);
  value += scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector);

  FloatMatrix derivative =
      MatrixUtil.toFlattenedFloat(
          getNumParameters(),
          binaryTD.values().iterator(),
          binaryCD.values().iterator(),
          binaryFloatTensorTD.values().iterator(),
          unaryCD.values().iterator(),
          wordVectorD.values().iterator());

  if (paramAdaGrad == null) paramAdaGrad = new AdaGradFloat(1, derivative.columns);

  derivative.muli(paramAdaGrad.getLearningRates(derivative));
  return derivative;
}
/**
 * Adds a feature for each example on to the current feature vector
 *
 * @param toAdd the feature vector to add
 */
public void addFeatureVector(FloatMatrix toAdd) {
  setFirst(FloatMatrix.concatHorizontally(getFirst(), toAdd));
}
private void backpropDerivativesAndError(
    Tree tree,
    MultiDimensionalMap<String, String, FloatMatrix> binaryTD,
    MultiDimensionalMap<String, String, FloatMatrix> binaryCD,
    MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD,
    Map<String, FloatMatrix> unaryCD,
    Map<String, FloatMatrix> wordVectorD,
    FloatMatrix deltaUp) {
  if (tree.isLeaf()) {
    return;
  }

  FloatMatrix currentVector = tree.vector();
  String category = tree.label();
  category = basicCategory(category);

  // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class
  FloatMatrix goldLabel = new FloatMatrix(numOuts, 1);
  int goldClass = tree.goldLabel();
  if (goldClass >= 0) {
    goldLabel.put(goldClass, 1.0f);
  }

  Float nodeWeight = classWeights.get(goldClass);
  if (nodeWeight == null) nodeWeight = 1.0f;

  FloatMatrix predictions = tree.prediction();

  // If this is an unlabeled class, set deltaClass to 0. We could
  // make this more efficient by eliminating various of the below
  // calculations, but this would be the easiest way to handle the
  // unlabeled class
  FloatMatrix deltaClass =
      goldClass >= 0
          ? SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel))
          : new FloatMatrix(predictions.rows, predictions.columns);
  FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose());

  float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum());
  error = error * nodeWeight;
  tree.setError(error);

  if (tree.isPreTerminal()) {
    // below us is a word vector
    unaryCD.put(category, unaryCD.get(category).add(localCD));

    String word = tree.children().get(0).label();
    word = getVocabWord(word);

    // use the derivative of the activation here, matching the binary branch below
    FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector);
    FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass);
    deltaFromClass =
        deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative);
    FloatMatrix deltaFull = deltaFromClass.add(deltaUp);
    wordVectorD.put(word, wordVectorD.get(word).add(deltaFull));
  } else {
    // Otherwise, this must be a binary node
    String leftCategory = basicCategory(tree.children().get(0).label());
    String rightCategory = basicCategory(tree.children().get(1).label());
    if (combineClassification) {
      unaryCD.put("", unaryCD.get("").add(localCD));
    } else {
      binaryCD.put(
          leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD));
    }

    FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector);
    FloatMatrix deltaFromClass =
        getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);

    FloatMatrix mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1));
    deltaFromClass = mult.muli(currentVectorDerivative);
    FloatMatrix deltaFull = deltaFromClass.add(deltaUp);

    FloatMatrix leftVector = tree.children().get(0).vector();
    FloatMatrix rightVector = tree.children().get(1).vector();
    FloatMatrix childrenVector = appendBias(leftVector, rightVector);

    // the transform gradient uses the full delta (classification + upstream),
    // not just the classification part
    FloatMatrix add = binaryTD.get(leftCategory, rightCategory);
    FloatMatrix W_df = deltaFull.mmul(childrenVector.transpose());
    binaryTD.put(leftCategory, rightCategory, add.add(W_df));

    FloatMatrix deltaDown;
    if (useFloatTensors) {
      FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector);
      binaryFloatTensorTD.put(
          leftCategory,
          rightCategory,
          binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df));
      deltaDown =
          computeFloatTensorDeltaDown(
              deltaFull,
              leftVector,
              rightVector,
              getBinaryTransform(leftCategory, rightCategory),
              getBinaryFloatTensor(leftCategory, rightCategory));
    } else {
      deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
    }

    FloatMatrix leftDerivative = activationFunction.applyDerivative(leftVector);
    FloatMatrix rightDerivative = activationFunction.applyDerivative(rightVector);
    FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1));
    FloatMatrix rightDeltaDown =
        deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1));
    backpropDerivativesAndError(
        tree.children().get(0),
        binaryTD,
        binaryCD,
        binaryFloatTensorTD,
        unaryCD,
        wordVectorD,
        leftDerivative.mul(leftDeltaDown));
    backpropDerivativesAndError(
        tree.children().get(1),
        binaryTD,
        binaryCD,
        binaryFloatTensorTD,
        unaryCD,
        wordVectorD,
        rightDerivative.mul(rightDeltaDown));
  }
}
public FloatDataSet() {
  this(FloatMatrix.zeros(1), FloatMatrix.zeros(1));
}
/**
 * Appends the given feature vector to the specified example's row
 *
 * @param feature the feature vector to add
 * @param example the number of the example to append to
 */
public void addFeatureVector(FloatMatrix feature, int example) {
  getFirst().putRow(example, FloatMatrix.concatHorizontally(getFirst().getRow(example), feature));
}
public static FloatDataSet empty() {
  return new FloatDataSet(FloatMatrix.zeros(1), FloatMatrix.zeros(1));
}
private void init() {
  if (rng == null) rng = new MersenneTwister(123);

  MultiDimensionalSet<String, String> binaryProductions = MultiDimensionalSet.hashSet();
  if (simplifiedModel) {
    binaryProductions.add("", "");
  } else {
    // TODO
    // figure out what binary productions we have in these trees
    // Note: the current sentiment training data does not actually
    // have any constituent labels
    throw new UnsupportedOperationException("Not yet implemented");
  }

  Set<String> unaryProductions = new HashSet<>();
  if (simplifiedModel) {
    unaryProductions.add("");
  } else {
    // TODO
    // figure out what unary productions we have in these trees (preterminals only, after the
    // collapsing)
    throw new UnsupportedOperationException("Not yet implemented");
  }

  identity = FloatMatrix.eye(numHidden);

  binaryTransform = MultiDimensionalMap.newTreeBackedMap();
  binaryFloatTensors = MultiDimensionalMap.newTreeBackedMap();
  binaryClassification = MultiDimensionalMap.newTreeBackedMap();

  // When making a flat model (no semantic untying) the
  // basicCategory function will return the same basic category for
  // all labels, so all entries will map to the same matrix
  for (Pair<String, String> binary : binaryProductions) {
    String left = basicCategory(binary.getFirst());
    String right = basicCategory(binary.getSecond());
    if (binaryTransform.contains(left, right)) {
      continue;
    }
    binaryTransform.put(left, right, randomTransformMatrix());
    if (useFloatTensors) {
      binaryFloatTensors.put(left, right, randomBinaryFloatTensor());
    }
    if (!combineClassification) {
      binaryClassification.put(left, right, randomClassificationMatrix());
    }
  }

  numBinaryMatrices = binaryTransform.size();
  binaryTransformSize = numHidden * (2 * numHidden + 1);
  if (useFloatTensors) {
    binaryFloatTensorSize = numHidden * numHidden * numHidden * 4;
  } else {
    binaryFloatTensorSize = 0;
  }
  binaryClassificationSize = (combineClassification) ? 0 : numOuts * (numHidden + 1);

  unaryClassification = new TreeMap<>();

  // When making a flat model (no semantic untying) the
  // basicCategory function will return the same basic category for
  // all labels, so all entries will map to the same matrix
  for (String unary : unaryProductions) {
    unary = basicCategory(unary);
    if (unaryClassification.containsKey(unary)) {
      continue;
    }
    unaryClassification.put(unary, randomClassificationMatrix());
  }

  featureVectors.put(UNKNOWN_FEATURE, randomWordVector());
  numUnaryMatrices = unaryClassification.size();
  unaryClassificationSize = numOuts * (numHidden + 1);

  classWeights = new HashMap<>();
}
public static FloatMatrix conv2d(FloatMatrix input, FloatMatrix kernel, Type type) {
  FloatMatrix xShape = new FloatMatrix(1, 2);
  xShape.put(0, input.rows);
  xShape.put(1, input.columns);

  FloatMatrix yShape = new FloatMatrix(1, 2);
  yShape.put(0, kernel.rows);
  yShape.put(1, kernel.columns);

  // the FULL convolution has shape (rows + kernelRows - 1) x (cols + kernelCols - 1)
  FloatMatrix zShape = xShape.add(yShape).sub(1);
  int retRows = (int) zShape.get(0);
  int retCols = (int) zShape.get(1);

  // convolve via pointwise multiplication in the frequency domain
  ComplexFloatMatrix fftInput = complexDisceteFourierTransform(input, retRows, retCols);
  ComplexFloatMatrix fftKernel = complexDisceteFourierTransform(kernel, retRows, retCols);
  ComplexFloatMatrix mul = fftKernel.mul(fftInput);
  ComplexFloatMatrix retComplex = complexInverseDisceteFourierTransform(mul);
  FloatMatrix ret = retComplex.getReal();

  if (type == Type.VALID) {
    // crop the FULL result down to the region where kernel and input fully overlap;
    // indices are zero-based, so only negative starts or out-of-range ends are illegal
    FloatMatrix validShape = xShape.subi(yShape).add(1);
    FloatMatrix start = zShape.sub(validShape).div(2);
    FloatMatrix end = start.add(validShape);
    if (start.get(0) < 0 || start.get(1) < 0)
      throw new IllegalStateException("Illegal start index " + start);
    if (end.get(0) > retRows || end.get(1) > retCols)
      throw new IllegalStateException("Illegal end index " + end);
    ret =
        ret.get(
            RangeUtils.interval((int) start.get(0), (int) end.get(0)),
            RangeUtils.interval((int) start.get(1), (int) end.get(1)));
  }
  return ret;
}
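/*
 * Usage sketch for conv2d (hypothetical sizes; Type.FULL is assumed to be the enum
 * constant complementing Type.VALID). With a 3x3 input and a 2x2 kernel, the FULL
 * output is (3 + 2 - 1) x (3 + 2 - 1) = 4x4 and the VALID crop is
 * (3 - 2 + 1) x (3 - 2 + 1) = 2x2.
 */
public static void conv2dExample() {
  FloatMatrix input = FloatMatrix.rand(3, 3);
  FloatMatrix kernel = FloatMatrix.rand(2, 2);
  FloatMatrix full = conv2d(input, kernel, Type.FULL); // 4 x 4
  FloatMatrix valid = conv2d(input, kernel, Type.VALID); // 2 x 2
}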
@Override
public Object compute(FloatMatrix params, int flag) {
  // unpack the flattened parameter vector into X (item features) and theta (user features)
  x = params.getRange(0, rows * features);
  FloatMatrix theta = params.getRange(rows * features, params.length);
  x = x.reshape(rows, features);
  theta = theta.reshape(columns, features);

  // flag 1 or 3: the regularized squared-error cost over the rated entries
  if (flag == 1 || flag == 3) {
    FloatMatrix M = MatrixFunctions.pow(x.mmul(theta.transpose()).sub(y), 2);
    this.cost = M.mul(r).columnSums().rowSums().get(0) / 2;
    if (lambda != 0) {
      float cost1 =
          (lambda / 2)
              * (MatrixFunctions.pow(theta, 2).columnSums().rowSums().get(0)
                  + MatrixFunctions.pow(x, 2).columnSums().rowSums().get(0));
      this.cost += cost1;
    }
  }

  // flag 2 or 3: the gradients, restricted to entries with a rating (r == 1)
  if (flag == 2 || flag == 3) {
    FloatMatrix xGrad = FloatMatrix.zeros(x.rows, x.columns);
    FloatMatrix thetaGrad = FloatMatrix.zeros(theta.rows, theta.columns);
    int[] indices;
    FloatMatrix thetaTemp;
    FloatMatrix xTemp;
    FloatMatrix yTemp;
    for (int i = 0; i < rows; i++) {
      indices = r.getRow(i).eq(1).findIndices();
      if (indices.length == 0) continue;
      thetaTemp = theta.getRows(indices);
      yTemp = y.getRow(i).get(indices);
      xGrad.putRow(i, x.getRow(i).mmul(thetaTemp.transpose()).sub(yTemp).mmul(thetaTemp));
    }
    xGrad = xGrad.add(x.mmul(lambda));

    for (int i = 0; i < columns; i++) {
      indices = r.getColumn(i).eq(1).findIndices();
      if (indices.length == 0) continue;
      xTemp = x.getRows(indices);
      yTemp = y.getColumn(i).get(indices);
      thetaGrad.putRow(
          i, xTemp.mmul(theta.getRow(i).transpose()).sub(yTemp).transpose().mmul(xTemp));
    }
    thetaGrad = thetaGrad.add(theta.mmul(lambda));

    this.gradient = MatrixUtil.merge(xGrad.data, thetaGrad.data);
  }
  return flag == 1 ? cost : gradient;
}
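/*
 * Shape sketch for compute (hypothetical sizes): the flattened parameter vector
 * packs the item-feature matrix X (rows x features) followed by the user-feature
 * matrix theta (columns x features). Flag 1 returns the cost, flag 2 the gradient,
 * flag 3 computes both and returns the gradient.
 */
public Object computeExample() {
  FloatMatrix params = FloatMatrix.rand(rows * features + columns * features, 1);
  return compute(params, 3); // fills both cost and gradient, returns the gradient
}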
public void pmmuli(FloatMatrix self, FloatMatrix other, FloatMatrix result) {
  pool.invoke(new MulitplyPartly(self, other, result, 0, self.getRows(), 0, other.getColumns()));
}
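/*
 * Usage sketch for pmmuli (hypothetical sizes): multiplies two matrices on the
 * fork/join pool, writing block-wise into a preallocated result. Assumes this class
 * holds the pool and the MulitplyPartly task defined above.
 */
public void pmmuliExample() {
  FloatMatrix a = FloatMatrix.rand(512, 256);
  FloatMatrix b = FloatMatrix.rand(256, 128);
  FloatMatrix result = new FloatMatrix(512, 128);
  pmmuli(a, b, result);
  // result now equals a.mmul(b), computed in parallel row/column blocks
}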