/**
   * Returns the shape-information buffer for the given layout, served from a per-device cache.
   *
   * <p>The supplied {@code offset} is not used: {@code 0} is used instead, both in the cache key and
   * when a new buffer has to be built. Layouts that differ only in offset therefore map to the same
   * cached buffer.
   *
   * <p>On a cache hit, {@code cacheHit} is incremented and the cached buffer is returned. On a miss,
   * the superclass builds the buffer. The buffer is then marked constant and, when the memory model
   * is {@code IMMEDIATE}, moved to constant space. Finally it is stored in the cache, counted in
   * {@code cacheMiss}, and returned.
   *
   * @param shape dimension sizes, passed to the descriptor and the superclass unchanged
   * @param stride strides, passed to the descriptor and the superclass unchanged
   * @param offset ignored; always treated as 0
   * @param elementWiseStride passed to the descriptor and the superclass unchanged
   * @param order ordering flag, passed to the descriptor and the superclass unchanged
   * @return the cached buffer for this device/descriptor, or a newly created and cached one
   */
  @Override
  public DataBuffer createShapeInformation(
      int[] shape, int[] stride, int offset, int elementWiseStride, char order) {
    // NOTE(review): the caller's offset is deliberately replaced by 0 here; confirm this is intended.
    final int zeroOffset = 0;

    final Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    final ShapeDescriptor key =
        new ShapeDescriptor(shape, stride, zeroOffset, elementWiseStride, order);

    // Fast path: this device already has a buffer for an identical descriptor.
    if (protector.containsDataBuffer(deviceId, key)) {
      cacheHit.incrementAndGet();
      return protector.getDataBuffer(deviceId, key);
    }

    // Slow path: build the buffer, pin it as constant, and remember it for this device.
    final DataBuffer created =
        super.createShapeInformation(shape, stride, zeroOffset, elementWiseStride, order);
    created.setConstant(true);
    if (configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) {
      Nd4j.getConstantHandler().moveToConstantSpace(created);
    }
    protector.persistDataBuffer(deviceId, key, created);
    cacheMiss.incrementAndGet();
    return created;
  }
// Example #2
// 0
  /**
   * Creates a typed buffer from {@code underlyingBuffer}. The concrete class is chosen from the
   * underlying buffer's data type.
   *
   * <p>Note that the concrete constructors take {@code (underlyingBuffer, length, offset)}, in that
   * order.
   *
   * @param underlyingBuffer the buffer passed to the new buffer's constructor
   * @param offset offset forwarded to the new buffer's constructor
   * @param length length forwarded to the new buffer's constructor
   * @return a {@code DoubleBuffer}, {@code FloatBuffer} or {@code IntBuffer} for DOUBLE, FLOAT and
   *     INT data respectively, or {@code null} for any other data type
   */
  @Override
  public DataBuffer create(DataBuffer underlyingBuffer, long offset, long length) {
    final DataBuffer.Type type = underlyingBuffer.dataType();
    if (type == DataBuffer.Type.DOUBLE) {
      return new DoubleBuffer(underlyingBuffer, length, offset);
    }
    if (type == DataBuffer.Type.FLOAT) {
      return new FloatBuffer(underlyingBuffer, length, offset);
    }
    if (type == DataBuffer.Type.INT) {
      return new IntBuffer(underlyingBuffer, length, offset);
    }
    // NOTE(review): unsupported types yield null rather than an exception; callers must check.
    return null;
  }
  /**
   * Verifies that {@code getReal()} on a complex array returns the real values it was filled with.
   * Two cases are checked: an element-by-element fill from {@code linspace(1, 8, 8)}, and an
   * all-ones array.
   */
  @Test
  public void testGetReal() {
    // Fill a 1x8 complex array element by element with real scalars taken from a linspace buffer.
    final DataBuffer values = Nd4j.linspace(1, 8, 8).data();
    final int[] rowShape = {1, 8};
    final IComplexNDArray complex = Nd4j.createComplex(rowShape);
    for (int idx = 0; idx < complex.length(); idx++) {
      complex.put(idx, Nd4j.scalar(values.getFloat(idx)));
    }

    // The real part must match a real array built over the same data and shape.
    final INDArray expected = Nd4j.create(values, rowShape);
    assertEquals(expected, complex.getReal());

    // The real part of complex ones must equal real ones.
    final INDArray realOnes = Nd4j.ones(10);
    final IComplexNDArray complexOnes = Nd4j.complexOnes(10);
    assertEquals(getFailureMessage(), realOnes, complexOnes.getReal());
  }
// Example #4
// 0
  /**
   * Single-precision general matrix multiply, {@code C := alpha * op(A) * op(B) + beta * C}. The
   * work is delegated to {@code BLAS.getInstance().sgemm}, the array-plus-offset interface
   * (presumably netlib-java).
   *
   * <p>A and B are first normalized with {@code Shape.toOffsetZero}. C is read into a
   * {@code float[]}, which the BLAS call updates in place, and the result is written back into C.
   *
   * <p>NOTE(review): {@code Order} is not consulted by this implementation. Confirm that callers
   * always pass data in the layout the BLAS backend expects.
   *
   * @param Order storage order flag (unused here)
   * @param TransA transpose flag for A, passed to BLAS as a one-character string
   * @param TransB transpose flag for B, passed to BLAS as a one-character string
   * @param M rows of op(A) and C
   * @param N columns of op(B) and C
   * @param K columns of op(A) / rows of op(B)
   * @param alpha scalar multiplier for op(A) * op(B)
   * @param A left operand
   * @param lda leading dimension of A
   * @param B right operand
   * @param ldb leading dimension of B
   * @param beta scalar multiplier for the existing contents of C
   * @param C output matrix, read then overwritten with the result
   * @param ldc leading dimension of C
   */
  @Override
  protected void sgemm(
      char Order,
      char TransA,
      char TransB,
      int M,
      int N,
      int K,
      float alpha,
      INDArray A,
      int lda,
      INDArray B,
      int ldb,
      float beta,
      INDArray C,
      int ldc) {
    final INDArray left = Shape.toOffsetZero(A);
    final INDArray right = Shape.toOffsetZero(B);
    final DataBuffer leftBuffer = left.data();
    final DataBuffer rightBuffer = right.data();

    // BLAS accumulates into this array; it is copied back into C after the call.
    final float[] result = getFloatData(C);

    final String transposeA = String.valueOf(TransA);
    final String transposeB = String.valueOf(TransB);

    BLAS.getInstance()
        .sgemm(
            transposeA, transposeB,
            M, N, K,
            alpha,
            leftBuffer.asFloat(), getBlasOffset(left), lda,
            rightBuffer.asFloat(), getBlasOffset(right), ldb,
            beta,
            result, getBlasOffset(C), ldc);

    setData(result, C);
  }