@Override public DataBuffer createShapeInformation( int[] shape, int[] stride, int offset, int elementWiseStride, char order) { offset = 0; Integer deviceId = AtomicAllocator.getInstance().getDeviceId(); ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order); if (!protector.containsDataBuffer(deviceId, descriptor)) { // logger.info("Cache miss"); DataBuffer buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order); buffer.setConstant(true); if (configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) { Nd4j.getConstantHandler().moveToConstantSpace(buffer); } // deviceCache.get(deviceId).put(descriptor, buffer); protector.persistDataBuffer(deviceId, descriptor, buffer); cacheMiss.incrementAndGet(); return buffer; } else { // logger.info("Cache hit"); cacheHit.incrementAndGet(); } return protector.getDataBuffer( deviceId, descriptor); // deviceCache.get(deviceId).get(descriptor); }
/**
 * Wraps {@code underlyingBuffer} in a new buffer whose element type matches the underlying one.
 *
 * <p>{@code length} and {@code offset} are forwarded to the type-specific buffer constructor.
 * That constructor takes them in (length, offset) order.
 *
 * @param underlyingBuffer the buffer to wrap; its {@code dataType()} selects the buffer class
 * @param offset offset forwarded to the new buffer
 * @param length length forwarded to the new buffer
 * @return a DOUBLE, FLOAT or INT buffer matching the underlying buffer's type
 * @throws IllegalArgumentException if the underlying buffer's type is not DOUBLE, FLOAT or INT
 */
@Override
public DataBuffer create(DataBuffer underlyingBuffer, long offset, long length) {
    final DataBuffer.Type type = underlyingBuffer.dataType();
    if (type == DataBuffer.Type.DOUBLE) {
        return new DoubleBuffer(underlyingBuffer, length, offset);
    }
    if (type == DataBuffer.Type.FLOAT) {
        return new FloatBuffer(underlyingBuffer, length, offset);
    }
    if (type == DataBuffer.Type.INT) {
        return new IntBuffer(underlyingBuffer, length, offset);
    }
    // Fail fast instead of returning null. Otherwise callers hit an unrelated NPE later.
    throw new IllegalArgumentException(
            "Unsupported data type for buffer creation from underlying buffer: " + type);
}
/**
 * Checks that {@code getReal()} on a complex array returns the real components that were put into it.
 */
@Test
public void testGetReal() {
    // Fill a 1x8 complex array with linspace(1, 8, 8), one scalar per element.
    int[] rowShape = {1, 8};
    DataBuffer values = Nd4j.linspace(1, 8, 8).data();
    IComplexNDArray complex = Nd4j.createComplex(rowShape);
    for (int idx = 0; idx < complex.length(); idx++) {
        complex.put(idx, Nd4j.scalar(values.getFloat(idx)));
    }

    // The real part should equal a real array built from the same buffer and shape.
    INDArray expectedReal = Nd4j.create(values, rowShape);
    assertEquals(expectedReal, complex.getReal());

    // The real part of complexOnes(10) is expected to equal ones(10).
    INDArray realOnes = Nd4j.ones(10);
    IComplexNDArray complexOnes = Nd4j.complexOnes(10);
    assertEquals(getFailureMessage(), realOnes, complexOnes.getReal());
}
/**
 * Single-precision GEMM ({@code C = alpha * op(A) * op(B) + beta * C}) through the
 * {@code BLAS.getInstance()} implementation.
 *
 * <p>A and B are first rebased with {@code Shape.toOffsetZero}. Their buffers are then passed as
 * float arrays together with their BLAS offsets. C's data is obtained via {@code getFloatData},
 * filled by BLAS, and written back with {@code setData}.
 *
 * <p>NOTE(review): {@code Order} is not used here. Presumably the caller has already arranged
 * the data for the BLAS backend's (column-major) layout. Confirm this.
 */
@Override
protected void sgemm(
        char Order, char TransA, char TransB, int M, int N, int K,
        float alpha, INDArray A, int lda, INDArray B, int ldb,
        float beta, INDArray C, int ldc) {
    final INDArray left = Shape.toOffsetZero(A);
    final INDArray right = Shape.toOffsetZero(B);
    final DataBuffer leftBuffer = left.data();
    final DataBuffer rightBuffer = right.data();
    final float[] result = getFloatData(C);

    // Gather call arguments in the same order the original inline call evaluated them.
    final BLAS blas = BLAS.getInstance();
    final String transposeA = String.valueOf(TransA);
    final String transposeB = String.valueOf(TransB);
    final float[] leftValues = leftBuffer.asFloat();
    final int leftOffset = getBlasOffset(left);
    final float[] rightValues = rightBuffer.asFloat();
    final int rightOffset = getBlasOffset(right);
    final int resultOffset = getBlasOffset(C);

    blas.sgemm(
            transposeA, transposeB, M, N, K,
            alpha, leftValues, leftOffset, lda,
            rightValues, rightOffset, ldb,
            beta, result, resultOffset, ldc);

    // Copy the computed product back into C.
    setData(result, C);
}