/**
 * Multiplies {@code source} by this diagonal matrix and writes the result into {@code dest}.
 *
 * <p>Each output element is the matching input element scaled by the diagonal value in the
 * same row.
 *
 * @param source input vector; its length must equal {@code rowCount()}
 * @param dest output vector; its length must equal {@code rowCount()}
 * @throws IllegalArgumentException if either vector has the wrong length
 */
@Override
public void transform(Vector source, Vector dest) {
    // Square matrix: the expected input length equals the row count.
    final int n = rowCount();
    if (source.length() != n) {
        throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source));
    }
    if (dest.length() != n) {
        throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest));
    }
    final double[] in = source.getArray();
    final double[] out = dest.getArray();
    int i = 0;
    while (i < n) {
        out[i] = in[i] * unsafeGetDiagonalValue(i);
        i++;
    }
}
/**
 * Returns the element at position ({@code x}, {@code y}) of a two-dimensional array.
 *
 * @param x index along dimension 0
 * @param y index along dimension 1
 * @return the element value
 * @throws UnsupportedOperationException if this array is not two-dimensional
 * @throws IndexOutOfBoundsException if either index lies outside this array's shape
 */
@Override
public double get(int x, int y) {
    if (dimensions != 2) {
        throw new UnsupportedOperationException(ErrorMessages.invalidIndex(this, x, y));
    }
    // Without this check, an out-of-range index on one axis can still map to an in-bounds
    // offset in the backing array and silently return an unrelated element.
    if (x < 0 || x >= getShape(0) || y < 0 || y >= getShape(1)) {
        throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, x, y));
    }
    return data[offset + x * getStride(0) + y * getStride(1)];
}
/**
 * Returns the single value held by a zero-dimensional array.
 *
 * @return the scalar value stored at this array's offset
 * @throws UnsupportedOperationException if this array has one or more dimensions
 */
@Override
public double get() {
    if (dimensions != 0) {
        throw new UnsupportedOperationException(ErrorMessages.invalidIndex(this));
    }
    return data[offset];
}
/**
 * Returns the number of slices along the first dimension.
 *
 * @return the size of dimension 0
 * @throws IllegalArgumentException if this array is zero-dimensional and therefore has no slices
 */
@Override
public int sliceCount() {
    if (dimensions != 0) {
        return getShape(0);
    }
    throw new IllegalArgumentException(ErrorMessages.noSlices(this));
}
/**
 * Returns the requested band of this diagonal matrix.
 *
 * <p>Band 0 is the leading diagonal. Any other accepted band is returned as a zero vector of
 * length {@code bandLength(band)}.
 *
 * @param band band offset; must lie in {@code [-dimensions, dimensions]}
 * @return the band as a vector
 * @throws IndexOutOfBoundsException if {@code band} is outside the accepted range
 */
@Override
public AVector getBand(int band) {
    if (band == 0) {
        return getLeadingDiagonal();
    }
    boolean outOfRange = band < -dimensions || band > dimensions;
    if (outOfRange) {
        throw new IndexOutOfBoundsException(ErrorMessages.invalidBand(this, band));
    }
    // Off-diagonal bands of a diagonal matrix contain only zeros.
    return Vectorz.createZeroVector(bandLength(band));
}
/**
 * Multiplies this diagonal matrix by another diagonal matrix of the same size.
 *
 * @param a the right-hand diagonal matrix
 * @return a new {@code DiagonalMatrix} whose diagonal holds the element-wise products
 * @throws IllegalArgumentException if the two matrices have different dimensions
 */
public AMatrix innerProduct(ADiagonalMatrix a) {
    final int n = this.dimensions;
    if (a.dimensions != n) {
        throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
    }
    DiagonalMatrix product = DiagonalMatrix.createDimensions(n);
    for (int k = 0; k < n; k++) {
        product.data[k] = unsafeGetDiagonalValue(k) * a.unsafeGetDiagonalValue(k);
    }
    return product;
}
/**
 * Returns the slice at {@code index} along the given dimension.
 *
 * <p>Dimension 0 delegates to {@code slice(index)}. For deeper dimensions, each top-level
 * slice is sliced one dimension lower, and the pieces are reassembled with
 * {@code SliceArray.create}.
 *
 * @param dimension the dimension to slice along; must be non-negative
 * @param index the position along that dimension
 * @return the resulting slice
 * @throws IllegalArgumentException if {@code dimension} is negative
 */
@Override
public INDArray slice(int dimension, int index) {
    if (dimension < 0) {
        throw new IllegalArgumentException(ErrorMessages.invalidDimension(this, dimension));
    }
    if (dimension == 0) {
        return slice(index);
    }
    final int inner = dimension - 1;
    ArrayList<INDArray> parts = new ArrayList<INDArray>(sliceCount());
    for (INDArray sub : this) {
        parts.add(sub.slice(inner, index));
    }
    return SliceArray.create(parts);
}
/**
 * Multiplies {@code v} by this diagonal matrix in place.
 *
 * <p>Array-backed vectors are routed to the {@code AArrayVector} overload. Other vectors are
 * scaled element by element.
 *
 * @param v the vector to transform; its length must equal {@code dimensions}
 * @throws IllegalArgumentException if a non-array vector has the wrong length
 */
@Override
public void transformInPlace(AVector v) {
    if (v instanceof AArrayVector) {
        transformInPlace((AArrayVector) v);
    } else {
        if (v.length() != dimensions) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, v));
        }
        for (int k = 0; k < dimensions; k++) {
            double scaled = v.unsafeGet(k) * unsafeGetDiagonalValue(k);
            v.unsafeSet(k, scaled);
        }
    }
}
/**
 * Left-multiplies {@code a} by this diagonal matrix.
 *
 * <p>Row {@code r} of the result is row {@code r} of {@code a} scaled by the r-th diagonal
 * value.
 *
 * @param a the right-hand matrix; its row count must equal {@code dimensions}
 * @return a new {@code Matrix} of shape {@code dimensions x a.columnCount()}
 * @throws IllegalArgumentException if the shapes are incompatible
 */
@Override
public Matrix innerProduct(Matrix a) {
    if (a.rowCount() != dimensions) {
        throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
    }
    final int cols = a.columnCount();
    Matrix result = Matrix.create(dimensions, cols);
    for (int r = 0; r < dimensions; r++) {
        final double scale = unsafeGetDiagonalValue(r);
        for (int c = 0; c < cols; c++) {
            result.unsafeSet(r, c, scale * a.unsafeGet(r, c));
        }
    }
    return result;
}
/**
 * Computes the product of this object's transpose with {@code s}, where {@code s} must have
 * exactly one row.
 *
 * <p>The result has {@code this.columnCount()} rows and {@code s.columnCount()} columns.
 * Element (i, j) is {@code this.get(i) * s.unsafeGet(0, j)}.
 *
 * @param s a single-row matrix
 * @return a new {@code Matrix} holding the product
 * @throws IllegalArgumentException if {@code s} does not have exactly one row
 */
@Override
public Matrix transposeInnerProduct(Matrix s) {
    if (s.rowCount() != 1) {
        throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, s));
    }
    final int outRows = this.columnCount();
    final int outCols = s.columnCount();
    Matrix out = Matrix.create(outRows, outCols);
    for (int i = 0; i < outRows; i++) {
        // NOTE(review): assumes get(i) addresses element i along this object's single
        // row/column — confirm against the enclosing class.
        final double left = this.get(i);
        for (int j = 0; j < outCols; j++) {
            out.unsafeSet(i, j, left * s.unsafeGet(0, j));
        }
    }
    return out;
}
/**
 * Always rejects the write, because this object is immutable.
 *
 * @param i ignored
 * @param value ignored
 * @throws UnsupportedOperationException always
 */
@Override
public void unsafeSet(int i, double value) {
    String message = ErrorMessages.immutable(this);
    throw new UnsupportedOperationException(message);
}
/**
 * Returns the element at index {@code i}, which is always zero.
 *
 * @param i the element index; must lie in {@code [0, length)}
 * @return {@code 0.0}
 * @throws IndexOutOfBoundsException if {@code i} is negative or not less than {@code length}
 */
@Override
public double get(int i) {
    if (i >= 0 && i < length) {
        return 0.0;
    }
    throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, i));
}
/**
 * Always rejects the write: this method unconditionally throws.
 *
 * @param row the row that was targeted (included in the error message)
 * @param column the column that was targeted (included in the error message)
 * @param value ignored
 * @throws UnsupportedOperationException always
 */
@Override
public void set(int row, int column, double value) {
    String message = ErrorMessages.notFullyMutable(this, row, column);
    throw new UnsupportedOperationException(message);
}