/**
 * {@inheritDoc}
 *
 * <p>The dimension of {@code v} is validated first; an {@link OpenMapRealVector} argument is
 * forwarded to the dedicated overload, any other vector is handled through its data array.
 */
public double getLInfDistance(RealVector v) throws IllegalArgumentException {
    checkVectorDimensions(v.getDimension());
    return (v instanceof OpenMapRealVector)
        ? getLInfDistance((OpenMapRealVector) v)
        : getLInfDistance(v.getData());
}
/** {@inheritDoc} */ public RealVector solve(RealVector b) throws IllegalArgumentException, InvalidMatrixException { try { return solve((RealVectorImpl) b); } catch (ClassCastException cce) { final int m = lTData.length; if (b.getDimension() != m) { throw MathRuntimeException.createIllegalArgumentException( "vector length mismatch: got {0} but expected {1}", b.getDimension(), m); } final double[] x = b.getData(); // Solve LY = b for (int j = 0; j < m; j++) { final double[] lJ = lTData[j]; x[j] /= lJ[j]; final double xJ = x[j]; for (int i = j + 1; i < m; i++) { x[i] -= xJ * lJ[i]; } } // Solve LTX = Y for (int j = m - 1; j >= 0; j--) { x[j] /= lTData[j][j]; final double xJ = x[j]; for (int i = 0; i < j; i++) { x[i] -= xJ * lTData[i][j]; } } return new RealVectorImpl(x, false); } }
/**
 * {@inheritDoc}
 *
 * <p>The dimension of {@code v} is validated first; an {@link OpenMapRealVector} argument is
 * forwarded to the dedicated overload, any other vector is added through its data array.
 */
public OpenMapRealVector add(RealVector v) throws IllegalArgumentException {
    checkVectorDimensions(v.getDimension());
    return (v instanceof OpenMapRealVector)
        ? add((OpenMapRealVector) v)
        : add(v.getData());
}
/**
 * Intersects a ray with a sphere object.
 *
 * <p>The ray is first transformed with the object's inverse transform. The quadratic
 * coefficients are then built in single precision from the ray's {@code p0}/{@code p1} and the
 * sphere centre and size, and handed to {@code quadraticEquationRoot1}. A missing root, or a
 * root below {@code EPSILON}, yields a non-hit {@link Intersection}.
 *
 * @param ray the ray to test; replaced locally by its transformed counterpart
 * @param obj the sphere being tested
 * @param model not used by this method
 * @return an intersection carrying the object, the hit point, the normalised normal and the
 *     distance, or a non-hit intersection
 */
private static Intersection getIntersection(Ray ray, SphereObject obj, Model model) {
    final RealMatrix forward = obj.getTransform();
    final RealMatrix inverse = obj.getTransformInverse();
    ray = ray.transform(inverse);

    final Vector3D centre = VectorUtils.toVector3D(obj.getCenter());
    final Vector3D start = VectorUtils.toVector3D(ray.getP0());
    final Vector3D step = VectorUtils.toVector3D(ray.getP1());
    final Vector3D startToCentre = start.subtract(centre);

    // Coefficients of the quadratic in the ray parameter, deliberately kept as float.
    final float quadA = (float) step.dotProduct(step);
    final float quadB = (float) step.dotProduct(startToCentre) * 2.0f;
    final float quadC =
        (float) (startToCentre.dotProduct(startToCentre)) - obj.getSize() * obj.getSize();

    final Double root = quadraticEquationRoot1(quadA, quadB, quadC);
    if (root == null || root < EPSILON) {
        return new Intersection(false);
    }

    final Intersection hit = new Intersection(true);
    hit.setObject(obj);

    // Hit point along the transformed ray, with entry 3 forced to 1.0.
    final RealVector localPoint = VectorUtils.toRealVector(start.add(step.scalarMultiply(root)));
    localPoint.setEntry(3, 1.0);
    hit.setP(VectorUtils.toVector3D(forward.preMultiply(localPoint)));

    // Normal: (point - centre) pre-multiplied by the transposed inverse transform.
    final RealVector localNormal = localPoint.subtract(obj.getCenter());
    final RealVector mappedNormal = inverse.transpose().preMultiply(localNormal);
    hit.setN(VectorUtils.toVector3D(mappedNormal).normalize());

    hit.setDistance(root);
    return hit;
}
/** {@inheritDoc} */ public RealVector solve(RealVector b) { final int m = pivot.length; if (b.getDimension() != m) { throw new DimensionMismatchException(b.getDimension(), m); } if (singular) { throw new SingularMatrixException(); } final double[] bp = new double[m]; // Apply permutations to b for (int row = 0; row < m; row++) { bp[row] = b.getEntry(pivot[row]); } // Solve LY = b for (int col = 0; col < m; col++) { final double bpCol = bp[col]; for (int i = col + 1; i < m; i++) { bp[i] -= bpCol * lu[i][col]; } } // Solve UX = Y for (int col = m - 1; col >= 0; col--) { bp[col] /= lu[col][col]; final double bpCol = bp[col]; for (int i = 0; i < col; i++) { bp[i] -= bpCol * lu[i][col]; } } return new ArrayRealVector(bp, false); }
@Override public Label predict(Instance instance) { Label l = null; if (instance.getLabel() instanceof ClassificationLabel || instance.getLabel() == null) { // ----------------- declare variables ------------------ double lambda = 0.0; RealVector x_instance = new ArrayRealVector(matrixX.getColumnDimension(), 0); double result = 0.0; // -------------------------- initialize xi ------------------------- for (int idx = 0; idx < matrixX.getColumnDimension(); idx++) { x_instance.setEntry(idx, instance.getFeatureVector().get(idx + 1)); } // ------------------ get lambda ----------------------- for (int j = 0; j < alpha.getDimension(); j++) { lambda += alpha.getEntry(j) * kernelFunction(matrixX.getRowVector(j), x_instance); } // ----------------- make prediction ----------------- Sigmoid g = new Sigmoid(); // helper function result = g.value(lambda); l = new ClassificationLabel(result < 0.5 ? 0 : 1); } else { System.out.println("label type error!"); } return l; }
/**
 * Writes a {@link RealVector} to an object stream as its dimension followed by each entry.
 *
 * <p>Intended to be invoked from a private <code>writeObject</code> method, right after
 * <code>oos.defaultWriteObject()</code>, by a class holding a {@link RealVector} field. That
 * field should be declared <code>transient</code> so that default serialization skips it (the
 * {@link RealVector} interface is not serializable by default) and this method takes over.
 *
 * <p>Example of a class with a name and a real vector:
 *
 * <pre><code>
 * public class NamedVector implements Serializable {
 *
 *     private final String name;
 *     private final transient RealVector coefficients;
 *
 *     // omitted constructors, getters ...
 *
 *     private void writeObject(ObjectOutputStream oos) throws IOException {
 *         oos.defaultWriteObject();  // takes care of name field
 *         MatrixUtils.serializeRealVector(coefficients, oos);
 *     }
 *
 *     private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
 *         ois.defaultReadObject();  // takes care of name field
 *         MatrixUtils.deserializeRealVector(this, "coefficients", ois);
 *     }
 *
 * }
 * </code></pre>
 *
 * @param vector real vector to serialize
 * @param oos stream where the real vector should be written
 * @exception IOException if object cannot be written to stream
 * @see #deserializeRealVector(Object, String, ObjectInputStream)
 */
public static void serializeRealVector(final RealVector vector, final ObjectOutputStream oos)
    throws IOException {
    final int dimension = vector.getDimension();
    oos.writeInt(dimension);
    int index = 0;
    while (index < dimension) {
        oos.writeDouble(vector.getEntry(index));
        index++;
    }
}
/**
 * Returns an estimate of the solution to the linear system A &middot; x = b, obtained by
 * delegating to {@code solveInPlace} with a freshly allocated vector set to zero.
 *
 * @param a the linear operator A of the system
 * @param b the right-hand side vector
 * @return a new vector containing the solution
 * @throws NullArgumentException if one of the parameters is {@code null}
 * @throws NonSquareOperatorException if {@code a} is not square
 * @throws DimensionMismatchException if {@code b} has dimensions inconsistent with {@code a}
 * @throws MaxCountExceededException at exhaustion of the iteration count, unless a custom {@link
 *     org.apache.commons.math3.util.Incrementor.MaxCountExceededCallback callback} has been set
 *     at construction of the {@link IterationManager}
 */
public RealVector solve(final RealLinearOperator a, final RealVector b)
    throws NullArgumentException,
        NonSquareOperatorException,
        DimensionMismatchException,
        MaxCountExceededException {
    MathUtils.checkNotNull(a);
    final RealVector start = new ArrayRealVector(a.getColumnDimension());
    start.set(0.);
    return solveInPlace(a, b, start);
}
/**
 * Packs the first three entries of a vector into a single {@code 0xRRGGBB} integer.
 *
 * <p>Each entry is scaled by 255 and truncated to an integer. The result is then clamped to
 * {@code 0..255}: without clamping, an entry above 1.0 would carry into the neighbouring
 * channel and a negative entry would borrow from it, corrupting the packed colour. Entries
 * already in {@code [0, 1]} produce exactly the same value as the unclamped arithmetic.
 *
 * @param ambient vector whose entries 0, 1 and 2 hold the red, green and blue intensities
 * @return the packed colour, red in bits 16-23, green in bits 8-15, blue in bits 0-7
 */
private static int toColour(RealVector ambient) {
    final int red = Math.max(0, Math.min(255, (int) (255. * ambient.getEntry(0))));
    final int green = Math.max(0, Math.min(255, (int) (255. * ambient.getEntry(1))));
    final int blue = Math.max(0, Math.min(255, (int) (255. * ambient.getEntry(2))));
    return (red << 16) | (green << 8) | blue;
}
/**
 * {@inheritDoc}
 *
 * <p>The row index is validated, the vector length is compared against the column count, and
 * the entries are then copied one by one through {@code setEntry}.
 */
public void setRowVector(final int row, final RealVector vector) {
    MatrixUtils.checkRowIndex(this, row);

    final int width = getColumnDimension();
    final int supplied = vector.getDimension();
    if (supplied != width) {
        throw new MatrixDimensionMismatchException(1, supplied, 1, width);
    }

    int col = 0;
    while (col < width) {
        setEntry(row, col, vector.getEntry(col));
        ++col;
    }
}
/**
 * {@inheritDoc}
 *
 * <p>Only the entries stored in this vector are visited; each one is multiplied by the
 * matching entry of {@code v} and written into a copy of this vector.
 */
public OpenMapRealVector ebeMultiply(RealVector v) throws IllegalArgumentException {
    checkVectorDimensions(v.getDimension());
    OpenMapRealVector res = new OpenMapRealVector(this);
    // Iterate over this vector's own map, not res.entries: res.setEntry() may modify
    // res.entries, and changing a map while one of its iterators is live makes the iterator
    // fail (MATH-645). res starts as a copy of this vector, so the visited keys and values
    // are the same.
    Iterator iter = entries.iterator();
    while (iter.hasNext()) {
        iter.advance();
        res.setEntry(iter.key(), iter.value() * v.getEntry(iter.key()));
    }
    return res;
}
/**
 * {@inheritDoc}
 *
 * <p>The column index is validated, the vector length is compared against the row count, and
 * the entries are then copied one by one through {@code setEntry}.
 */
public void setColumnVector(final int column, final RealVector vector) {
    MatrixUtils.checkColumnIndex(this, column);

    final int height = getRowDimension();
    final int supplied = vector.getDimension();
    if (supplied != height) {
        throw new MatrixDimensionMismatchException(supplied, 1, height, 1);
    }

    int row = 0;
    while (row < height) {
        setEntry(row, column, vector.getEntry(row));
        ++row;
    }
}
/**
 * {@inheritDoc}
 *
 * <p>Only the entries stored in this vector contribute to the sum.
 */
public double dotProduct(RealVector v) throws IllegalArgumentException {
    checkVectorDimensions(v.getDimension());
    double sum = 0;
    for (Iterator it = entries.iterator(); it.hasNext(); ) {
        it.advance();
        sum += v.getEntry(it.key()) * it.value();
    }
    return sum;
}
/** * Generic copy constructor. * * @param v The instance to copy from */ public OpenMapRealVector(RealVector v) { virtualSize = v.getDimension(); entries = new OpenIntToDoubleHashMap(0.0); epsilon = DEFAULT_ZERO_TOLERANCE; for (int key = 0; key < virtualSize; key++) { double value = v.getEntry(key); if (!isZero(value)) { entries.put(key, value); } } }
/**
 * Builds row {@code i} as a vector of length {@code alpha.length}.
 *
 * <p>The row holds {@code alpha[i]} at position {@code i}, {@code beta[i - 1]} at position
 * {@code i - 1} when that position exists, and {@code beta[i]} at position {@code i + 1} when
 * that position exists.
 *
 * @param i index of the requested row
 * @return the newly created row vector
 */
@Override
public Vector<Double> row(int i) {
    final int size = alpha.length;
    final RealVector result = new RealVector(size);

    final boolean hasLeftNeighbour = i > 0;
    if (hasLeftNeighbour) {
        result.put(i - 1, beta[i - 1]);
    }

    result.put(i, alpha[i]);

    final boolean hasRightNeighbour = i + 1 < size;
    if (hasRightNeighbour) {
        result.put(i + 1, beta[i]);
    }

    return result;
}
/**
 * Normalises the first {@code itemSize} rows of {@code rowStochastic} in place: each row with
 * a non-zero L1 norm is divided by that norm, and rows with a zero norm are left untouched.
 *
 * @return A matrix containing the row stochastic values of the matrix that contains the
 *     information about the item categorization, to be used by a {@code HIRItemScorer}.
 */
public RealMatrix RowStochastic() {
    int row = 0;
    while (row < itemSize) {
        final RealVector current = rowStochastic.getRowVector(row);
        final double l1Norm = current.getL1Norm();
        if (l1Norm != 0) {
            current.mapDivideToSelf(l1Norm);
            rowStochastic.setRowVector(row, current);
        }
        row++;
    }
    return rowStochastic;
}
/**
 * Checks the genre vectors returned by {@code gdao} for items 0 and 5 against fixed expected
 * vectors, and checks that the vector length agrees with the reported genre size.
 */
@Test
public void testGenreVector() {
    final double[] expectedGenres0 = {0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    final double[] expectedGenres5 = {0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    final RealVector expectedVector0 = MatrixUtils.createRealVector(expectedGenres0);
    final RealVector expectedVector5 = MatrixUtils.createRealVector(expectedGenres5);

    // Contents.
    assertThat(gdao.getItemGenre(0), equalTo(expectedVector0));
    assertThat(gdao.getItemGenre(5), equalTo(expectedVector5));

    // Dimensions, checked three ways against the genre size.
    assertThat(expectedVector0.getDimension(), equalTo(gdao.getGenreSize()));
    assertThat(gdao.getItemGenre(0).getDimension(), equalTo(gdao.getGenreSize()));
    assertThat(expectedGenres0.length, equalTo(gdao.getGenreSize()));
}
@Override public void train(List<Instance> instances) { // ------------------------ initialize rows and columns --------------------- int rows = instances.size(); int columns = 0; // get max columns for (Instance i : instances) { int localColumns = Collections.max(i.getFeatureVector().getFeatureMap().keySet()); if (localColumns > columns) columns = localColumns; } // ------------------------ initialize alpha vector ----------------------- alpha = new ArrayRealVector(rows, 0); // ------------------------ initialize base X and Y for use -------------------------- double[][] X = new double[rows][columns]; double[] Y = new double[rows]; for (int i = 0; i < rows; i++) { Y[i] = ((ClassificationLabel) instances.get(i).getLabel()).getLabelValue(); for (int j = 0; j < columns; j++) { X[i][j] = instances.get(i).getFeatureVector().get(j + 1); } } // ---------------------- gram matrix ------------------- matrixX = new Array2DRowRealMatrix(X); RealMatrix gram = new Array2DRowRealMatrix(rows, rows); for (int i = 0; i < rows; i++) { for (int j = 0; j < rows; j++) { gram.setEntry(i, j, kernelFunction(matrixX.getRowVector(i), matrixX.getRowVector(j))); } } // ---------------------- gradient ascent -------------------------- Sigmoid g = new Sigmoid(); // helper function System.out.println("Training start..."); System.out.println( "Learning rate: " + _learning_rate + " Training times: " + _training_iterations); for (int idx = 0; idx < _training_iterations; idx++) { System.out.println("Training iteration: " + (idx + 1)); for (int k = 0; k < rows; k++) { double gradient_ascent = 0.0; RealVector alpha_gram = gram.operate(alpha); for (int i = 0; i < rows; i++) { double lambda = alpha_gram.getEntry(i); double kernel = gram.getEntry(i, k); gradient_ascent = gradient_ascent + Y[i] * g.value(-lambda) * kernel + (1 - Y[i]) * g.value(lambda) * (-kernel); } alpha.setEntry(k, alpha.getEntry(k) + _learning_rate * gradient_ascent); } } System.out.println("Training done!"); }
/**
 * Builds an 11x11 coefficient matrix, sets its diagonal to -1, transposes it and solves the
 * resulting linear system by LU decomposition for a right-hand side of (-1, 0, ..., 0). The
 * solution and several statistics derived from it and from {@code arr} are printed to
 * standard output.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    final double[][] rawCoefficients = {
        {0.0D, 1.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D},
        {0.0D, 0.0D, 0.857D, 0.0D, 0.054D, 0.018D, 0.0D, 0.071D, 0.0D, 0.0D, 0.0D},
        {0.0D, 0.0D, 0.0D, 1.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D},
        {0.0D, 0.0D, 0.857D, 0.0D, 0.054D, 0.018D, 0.0D, 0.071D, 0.0D, 0.0D, 0.0D},
        {0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 1.0D, 0.0D, 0.0D, 0.0D, 0.0D},
        {0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 1.0D, 0.0D, 0.0D, 0.0D, 0.0D},
        {0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 1.0D, 0.0D, 0.0D},
        {0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.6D, 0.4D},
        {0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 1.0D},
        {0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 1.0D, 0.0D, 0.0D, 1.0D},
        {0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D, 0.0D}
    };

    // The array is wrapped without copying, the diagonal is overwritten, then transposed.
    RealMatrix system = new Array2DRowRealMatrix(rawCoefficients, false);
    for (int diagonal = 0; diagonal < 11; diagonal++) {
        system.setEntry(diagonal, diagonal, -1d);
    }
    system = system.transpose();

    final DecompositionSolver luSolver = new LUDecompositionImpl(system).getSolver();
    System.out.println("1 method my Value :");

    final RealVector rightHandSide =
        new ArrayRealVector(new double[] {-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false);
    final double[] data = luSolver.solve(rightHandSide).getData();

    // Values are printed with the default pattern, rounded towards zero.
    final DecimalFormat formatter = new DecimalFormat();
    formatter.setRoundingMode(RoundingMode.DOWN);

    System.out.println("Корни уравнения:");
    for (double root : data) {
        System.out.print(formatter.format(root) + " ");
    }
    System.out.println();

    System.out.println(
        "Среднее число процессорных операций, выполняемых при одном прогоне алгоритма: "
            + operationsByProcess(data, arr));

    System.out.println("Среднее число обращений к файлам:");
    for (int file = 1; file < 4; file++) {
        System.out.println(" Файл " + file + " : " + fileMiddleRequest(data, arr, file));
    }

    System.out.println("Среднее количество информации передаваемой при одном обращении к файлам:");
    for (int file = 1; file < 4; file++) {
        System.out.println(" Файл " + file + " : " + bitsPerFileTransfer(data, arr, file));
    }

    System.out.println(
        "Сумма среднего числа обращений к основным операторам: " + operatorExecute(data, arr));
    System.out.println("Средняя трудоемкость этапа: " + middleWork(data, arr));
}
/**
 * Validates the arguments shared by {@link #solve(RealLinearOperator, RealVector, RealVector)
 * solve} and {@link #solveInPlace(RealLinearOperator, RealVector, RealVector) solveInPlace}.
 * The checks run in this order: null checks on all three arguments, squareness of {@code a},
 * size of {@code b} against the rows of {@code a}, size of {@code x0} against its columns. The
 * first failing check throws.
 *
 * @param a the linear operator A of the system
 * @param b the right-hand side vector
 * @param x0 the initial guess of the solution
 * @throws NullArgumentException if one of the parameters is {@code null}
 * @throws NonSquareOperatorException if {@code a} is not square
 * @throws DimensionMismatchException if {@code b} or {@code x0} have dimensions inconsistent with
 *     {@code a}
 */
protected static void checkParameters(
    final RealLinearOperator a, final RealVector b, final RealVector x0)
    throws NullArgumentException, NonSquareOperatorException, DimensionMismatchException {
    MathUtils.checkNotNull(a);
    MathUtils.checkNotNull(b);
    MathUtils.checkNotNull(x0);

    final int rows = a.getRowDimension();
    final int cols = a.getColumnDimension();
    if (rows != cols) {
        throw new NonSquareOperatorException(rows, cols);
    }

    final int rhsSize = b.getDimension();
    if (rhsSize != rows) {
        throw new DimensionMismatchException(rhsSize, rows);
    }

    final int guessSize = x0.getDimension();
    if (guessSize != cols) {
        throw new DimensionMismatchException(guessSize, cols);
    }
}
/**
 * {@inheritDoc}
 *
 * <p>An {@link OpenMapRealVector} argument is forwarded to the dedicated overload. For any
 * other vector, only the rows matching entries stored in this vector are written into the
 * {@code virtualSize x virtualSize} result.
 */
public RealMatrix outerProduct(RealVector v) throws IllegalArgumentException {
    checkVectorDimensions(v.getDimension());
    if (v instanceof OpenMapRealVector) {
        return outerproduct((OpenMapRealVector) v);
    }

    final RealMatrix product = new OpenMapRealMatrix(virtualSize, virtualSize);
    for (Iterator it = entries.iterator(); it.hasNext(); ) {
        it.advance();
        final int row = it.key();
        final double factor = it.value();
        for (int col = 0; col < virtualSize; col++) {
            product.setEntry(row, col, factor * v.getEntry(col));
        }
    }
    return product;
}
/**
 * Decides whether a point lies inside a triangle by solving a linear system.
 *
 * <p>{@code params} is used as the right-hand side for {@code solver}. The first two entries
 * of the solution must both be non-negative and sum to at most 1.0 for the intersection to be
 * flagged as a hit.
 *
 * @param obj the triangle, recorded on the returned intersection
 * @param solver solver for the system
 * @param p point recorded on the returned intersection
 * @param pc not used by this method
 * @param params right-hand side of the system (wrapped, not copied)
 * @return an intersection carrying the hit flag, {@code obj} and {@code p}
 */
private static Intersection getIntersection(
    TriangleObject obj, DecompositionSolver solver, Vector3D p, Vector3D pc, double[] params) {
    final RealVector weights = solver.solve(new ArrayRealVector(params, false));
    final double first = weights.getEntry(0);
    final double second = weights.getEntry(1);

    final boolean inside = first >= 0 && second >= 0 && first + second <= 1.0;

    final Intersection outcome = new Intersection(inside);
    outcome.setObject(obj);
    outcome.setP(p);
    return outcome;
}
/**
 * Computes the dot product of this vector's {@code data} with {@code v}.
 *
 * <p>The loop runs over {@code data.length} entries and performs no dimension check of its
 * own; entries of {@code v} are read through {@code getEntry}.
 *
 * @param v the other operand
 * @return the sum of the pairwise products
 */
@Override
public double dotProduct(RealVector v) {
    double sum = 0;
    int index = 0;
    while (index < data.length) {
        sum += data[index] * v.getEntry(index);
        index++;
    }
    return sum;
}
@Test public void testConstructors() { OpenMapRealVector v0 = new OpenMapRealVector(); Assert.assertEquals("testData len", 0, v0.getDimension()); OpenMapRealVector v1 = new OpenMapRealVector(7); Assert.assertEquals("testData len", 7, v1.getDimension()); Assert.assertEquals("testData is 0.0 ", 0.0, v1.getEntry(6), 0); OpenMapRealVector v3 = new OpenMapRealVector(vec1); Assert.assertEquals("testData len", 3, v3.getDimension()); Assert.assertEquals("testData is 2.0 ", 2.0, v3.getEntry(1), 0); // SparseRealVector v4 = new SparseRealVector(vec4, 3, 2); // Assert.assertEquals("testData len", 2, v4.getDimension()); // Assert.assertEquals("testData is 4.0 ", 4.0, v4.getEntry(0)); // try { // new SparseRealVector(vec4, 8, 3); // Assert.fail("MathIllegalArgumentException expected"); // } catch (MathIllegalArgumentException ex) { // expected behavior // } RealVector v5_i = new OpenMapRealVector(dvec1); Assert.assertEquals("testData len", 9, v5_i.getDimension()); Assert.assertEquals("testData is 9.0 ", 9.0, v5_i.getEntry(8), 0); OpenMapRealVector v5 = new OpenMapRealVector(dvec1); Assert.assertEquals("testData len", 9, v5.getDimension()); Assert.assertEquals("testData is 9.0 ", 9.0, v5.getEntry(8), 0); OpenMapRealVector v7 = new OpenMapRealVector(v1); Assert.assertEquals("testData len", 7, v7.getDimension()); Assert.assertEquals("testData is 0.0 ", 0.0, v7.getEntry(6), 0); SparseRealVectorTestImpl v7_i = new SparseRealVectorTestImpl(vec1); OpenMapRealVector v7_2 = new OpenMapRealVector(v7_i); Assert.assertEquals("testData len", 3, v7_2.getDimension()); Assert.assertEquals("testData is 0.0 ", 2.0d, v7_2.getEntry(1), 0); OpenMapRealVector v8 = new OpenMapRealVector(v1); Assert.assertEquals("testData len", 7, v8.getDimension()); Assert.assertEquals("testData is 0.0 ", 0.0, v8.getEntry(6), 0); }
/**
 * Solve the linear equation A &times; X = B in least square sense.
 *
 * <p>The m&times;n matrix A may not be square, the solution X is such that ||A &times; X - B||
 * is minimal. The computation applies {@code uT} to {@code b}, divides each of the leading
 * components by the matching singular value, and applies {@code v} to the result.
 *
 * @param b right-hand side of the equation A &times; X = B
 * @return a vector X that minimizes the two norm of A &times; X - B
 * @exception IllegalArgumentException if matrices dimensions don't match
 * @exception InvalidMatrixException if decomposed matrix is singular
 */
public RealVector solve(final RealVector b)
    throws IllegalArgumentException, InvalidMatrixException {
    final int expected = uT.getColumnDimension();
    if (b.getDimension() != expected) {
        throw MathRuntimeException.createIllegalArgumentException(
            "vector length mismatch: got {0} but expected {1}", b.getDimension(), expected);
    }

    final RealVector scaled = uT.operate(b);
    int index = 0;
    for (final double sigma : singularValues) {
        // An exactly zero singular value makes the system singular.
        if (sigma == 0) {
            throw new SingularMatrixException();
        }
        scaled.setEntry(index, scaled.getEntry(index) / sigma);
        index++;
    }
    return v.operate(scaled);
}
/*
 * Element-by-element multiply and divide must complete without throwing when the operands
 * hold zeros in different positions (cf. MATH-645).
 */
@Test
public void testConcurrentModification() {
    final RealVector left = new OpenMapRealVector(3, 1e-6);
    left.setEntry(0, 1);
    left.setEntry(1, 0);
    left.setEntry(2, 2);

    final RealVector right = new OpenMapRealVector(3, 1e-6);
    right.setEntry(0, 0);
    right.setEntry(1, 3);
    right.setEntry(2, 0);

    // No assertion: the test passes as long as neither call throws.
    left.ebeMultiply(right);
    left.ebeDivide(right);
}
/**
 * {@inheritDoc}
 *
 * <p>An {@link ArrayRealVector} argument is handled through the {@code double[]} overload on
 * its data reference. Any other vector is multiplied entry by entry through
 * {@code getEntry}.
 *
 * @throws DimensionMismatchException if the dimension of {@code v} differs from the row count
 */
public RealVector preMultiply(final RealVector v) throws DimensionMismatchException {
    // Test the runtime type explicitly. Catching ClassCastException for dispatch would also
    // intercept a ClassCastException raised inside the delegated call and silently rerun
    // the generic path.
    if (v instanceof ArrayRealVector) {
        return new ArrayRealVector(preMultiply(((ArrayRealVector) v).getDataRef()), false);
    }

    final int nRows = getRowDimension();
    final int nCols = getColumnDimension();
    if (v.getDimension() != nRows) {
        throw new DimensionMismatchException(v.getDimension(), nRows);
    }

    // out[col] = sum over rows of entry(row, col) * v[row]
    final double[] out = new double[nCols];
    for (int col = 0; col < nCols; ++col) {
        double sum = 0;
        for (int i = 0; i < nRows; ++i) {
            sum += getEntry(i, col) * v.getEntry(i);
        }
        out[col] = sum;
    }

    return new ArrayRealVector(out, false);
}
/**
 * {@inheritDoc}
 *
 * <p>An {@link ArrayRealVector} argument is handled through the {@code double[]} overload on
 * its data reference. Any other vector is multiplied entry by entry through
 * {@code getEntry}.
 *
 * @throws DimensionMismatchException if the dimension of {@code v} differs from the column
 *     count
 */
@Override
public RealVector operate(final RealVector v) throws DimensionMismatchException {
    // Test the runtime type explicitly. Catching ClassCastException for dispatch would also
    // intercept a ClassCastException raised inside the delegated call and silently rerun
    // the generic path.
    if (v instanceof ArrayRealVector) {
        return new ArrayRealVector(operate(((ArrayRealVector) v).getDataRef()), false);
    }

    final int nRows = getRowDimension();
    final int nCols = getColumnDimension();
    if (v.getDimension() != nCols) {
        throw new DimensionMismatchException(v.getDimension(), nCols);
    }

    // out[row] = sum over columns of entry(row, col) * v[col]
    final double[] out = new double[nRows];
    for (int row = 0; row < nRows; ++row) {
        double sum = 0;
        for (int i = 0; i < nCols; ++i) {
            sum += getEntry(row, i) * v.getEntry(i);
        }
        out[row] = sum;
    }

    return new ArrayRealVector(out, false);
}
/**
 * Performs one stochastic update for a center word and its set of context words, modifying
 * {@code W1} and {@code W2} in place.
 *
 * <p>For every context word, a set of negative contexts is drawn (through {@code noiseSampler}
 * when {@code sampleUnigram} is set, otherwise through the single-argument sampler). The
 * positive context is processed with target 1.0 and each negative context with target 0.0. In
 * both cases the column of {@code W2} is updated first, and the row of {@code W1} is then
 * updated using the already-updated column.
 *
 * <p>NOTE(review): this looks like skip-gram training with negative sampling - confirm against
 * the caller.
 *
 * @param wordPlusContexts pair of the center word index and the indices of its context words
 * @param s value passed to {@code learningRateDecay} to obtain the learning rate for this step
 */
private void stochasticUpdateStep(Pair<Integer, Set<Integer>> wordPlusContexts, int s) {
    double eta = learningRateDecay(s);
    int wordIndex = wordPlusContexts.getFirst(); // actual center word
    // Set h vector equal to the kth row of weight matrix W1. h = x' * W = W[k,:] = v(input)
    RealVector h = W1.getRowVector(wordIndex); // 1xN row vector
    // h is carried over between context words: each update below replaces it with the new row.
    for (int contextWordIndex : wordPlusContexts.getSecond()) {
        Set<Integer> negativeContexts;
        if (sampleUnigram) {
            negativeContexts = negativeSampleContexts(wordIndex, noiseSampler);
        } else {
            negativeContexts = negativeSampleContexts(wordIndex);
        }
        // wordIndex is the input word
        // negativeContexts is the k negative contexts
        // contextWordIndex is 1 positive context
        // First update the output vectors for 1 positive context
        RealVector vPrime_j = W2.getColumnVector(contextWordIndex); // Nx1 column vector
        double u = h.dotProduct(vPrime_j); // u_j = vPrime(output) * v(input)
        double t_j = 1.0; // t_j := 1{j == contextWordIndex}
        double scale = sigmoid(u) - t_j;
        scale = eta * scale;
        RealVector gradientOut2Hidden = h.mapMultiply(scale);
        vPrime_j = vPrime_j.subtract(gradientOut2Hidden);
        W2.setColumnVector(contextWordIndex, vPrime_j);
        // Next backpropagate the error to the hidden layer and update the input vectors
        RealVector v_I = h;
        // u is recomputed with the column that was just updated, not the one read above.
        u = h.dotProduct(vPrime_j);
        scale = sigmoid(u) - t_j;
        scale = eta * scale;
        RealVector gradientHidden2In = vPrime_j.mapMultiply(scale);
        v_I = v_I.subtract(gradientHidden2In);
        h = v_I;
        W1.setRowVector(wordIndex, v_I);
        // Repeat update process for k negative contexts
        t_j = 0.0; // t_j := 1{j == contextWordIndex}
        for (int negContext : negativeContexts) {
            vPrime_j = W2.getColumnVector(negContext);
            u = h.dotProduct(vPrime_j);
            scale = sigmoid(u) - t_j;
            scale = eta * scale;
            gradientOut2Hidden = h.mapMultiply(scale);
            vPrime_j = vPrime_j.subtract(gradientOut2Hidden);
            W2.setColumnVector(negContext, vPrime_j);
            // Backpropagate the error to the hidden layer and update the input vectors
            v_I = h;
            // As above, the freshly updated column is used for the input-side update.
            u = h.dotProduct(vPrime_j);
            scale = sigmoid(u) - t_j;
            scale = eta * scale;
            gradientHidden2In = vPrime_j.mapMultiply(scale);
            v_I = v_I.subtract(gradientHidden2In);
            h = v_I;
            W1.setRowVector(wordIndex, v_I);
        }
    }
}
/**
 * Applies the scalar {@code sigmoid} to the dot product of two vectors.
 *
 * @param x first vector
 * @param y second vector
 * @return {@code sigmoid(x . y)}
 */
private static double sigmoid(RealVector x, RealVector y) {
    return sigmoid(x.dotProduct(y));
}