コード例 #1
0
  /**
   * Returns the shape-information buffer for the given layout, looked up in {@code protector} by
   * (device id, descriptor). On a hit the stored buffer is returned; on a miss a new buffer is
   * built by the parent implementation, marked constant, stored, and returned.
   *
   * <p>The {@code offset} argument is ignored: both the descriptor and the created buffer always
   * use an offset of 0.
   *
   * @param shape the array shape
   * @param stride the array stride
   * @param offset ignored; 0 is always used instead
   * @param elementWiseStride the element-wise stride
   * @param order the ordering character forwarded to the descriptor and parent implementation
   * @return the cached buffer on a hit, otherwise the newly created constant buffer
   */
  @Override
  public DataBuffer createShapeInformation(
      int[] shape, int[] stride, int offset, int elementWiseStride, char order) {
    // The caller-supplied offset is deliberately discarded; shape info is always keyed with 0.
    final int zeroOffset = 0;

    final Integer device = AtomicAllocator.getInstance().getDeviceId();
    final ShapeDescriptor key =
        new ShapeDescriptor(shape, stride, zeroOffset, elementWiseStride, order);

    // Cache hit: record it and hand back the stored buffer.
    if (protector.containsDataBuffer(device, key)) {
      cacheHit.incrementAndGet();
      return protector.getDataBuffer(device, key);
    }

    // Cache miss: build the buffer through the parent implementation and mark it constant.
    final DataBuffer created =
        super.createShapeInformation(shape, stride, zeroOffset, elementWiseStride, order);
    created.setConstant(true);

    final boolean immediateModel =
        configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE;
    if (immediateModel) {
      // Under the IMMEDIATE memory model the buffer is moved to constant space right away.
      Nd4j.getConstantHandler().moveToConstantSpace(created);
    }

    protector.persistDataBuffer(device, key, created);
    cacheMiss.incrementAndGet();
    return created;
  }
コード例 #2
0
  /**
   * Creates a new buffer from {@code underlyingBuffer}, choosing the concrete buffer class from
   * the underlying buffer's data type.
   *
   * <p>Note that the buffer constructors receive {@code length} before {@code offset}.
   *
   * @param underlyingBuffer the source buffer; its {@code dataType()} selects the result class
   * @param offset offset forwarded to the buffer constructor
   * @param length length forwarded to the buffer constructor
   * @return a {@code DoubleBuffer}, {@code FloatBuffer} or {@code IntBuffer} for DOUBLE, FLOAT or
   *     INT data respectively; {@code null} for any other (or a null) data type
   */
  @Override
  public DataBuffer create(DataBuffer underlyingBuffer, long offset, long length) {
    final DataBuffer.Type type = underlyingBuffer.dataType();
    // Guard explicitly: a switch on a null enum would throw, whereas no type means no buffer.
    if (type == null) {
      return null;
    }
    switch (type) {
      case DOUBLE:
        return new DoubleBuffer(underlyingBuffer, length, offset);
      case FLOAT:
        return new FloatBuffer(underlyingBuffer, length, offset);
      case INT:
        return new IntBuffer(underlyingBuffer, length, offset);
      default:
        return null;
    }
  }
コード例 #3
0
  /**
   * Checks two things:
   *
   * <ul>
   *   <li>{@code getReal()} on a complex array returns the real values written into it.
   *   <li>The real part of a complex ones vector equals a real ones vector.
   * </ul>
   */
  @Test
  public void testGetReal() {
    DataBuffer data = Nd4j.linspace(1, 8, 8).data();
    int[] shape = new int[] {1, 8};
    IComplexNDArray arr = Nd4j.createComplex(shape);
    for (int i = 0; i < arr.length(); i++) {
      // Each element receives a real scalar taken from the linspace(1, 8, 8) buffer.
      arr.put(i, Nd4j.scalar(data.getFloat(i)));
    }
    INDArray arr2 = Nd4j.create(data, shape);
    assertEquals(getFailureMessage(), arr2, arr.getReal());

    INDArray ones = Nd4j.ones(10);
    IComplexNDArray n2 = Nd4j.complexOnes(10);
    assertEquals(getFailureMessage(), ones, n2.getReal());
  }
コード例 #4
0
ファイル: CpuLevel3.java プロジェクト: KillEdision/nd4j
  /**
   * Single-precision general matrix multiply, delegated to {@code BLAS.getInstance().sgemm}.
   *
   * <p>{@code A} and {@code B} are first normalised with {@code Shape.toOffsetZero}. Their data
   * buffers are then passed to BLAS as float arrays obtained via {@code asFloat()}. The output
   * array comes from {@code getFloatData(C)}; after the call it is written back into {@code C}
   * with {@code setData}.
   *
   * <p>The {@code Order} argument is not consulted by this implementation. NOTE(review): the BLAS
   * backend presumably assumes column-major layout — confirm against callers.
   *
   * @param Order storage order (unused here)
   * @param TransA transpose flag for A, passed to BLAS as a one-character string
   * @param TransB transpose flag for B, passed to BLAS as a one-character string
   * @param M number of rows of op(A) and C
   * @param N number of columns of op(B) and C
   * @param K shared inner dimension
   * @param alpha scalar multiplier for op(A) * op(B)
   * @param A left operand
   * @param lda leading dimension of A
   * @param B right operand
   * @param ldb leading dimension of B
   * @param beta scalar multiplier for the existing contents of C
   * @param C output operand, updated in place
   * @param ldc leading dimension of C
   */
  @Override
  protected void sgemm(
      char Order,
      char TransA,
      char TransB,
      int M,
      int N,
      int K,
      float alpha,
      INDArray A,
      int lda,
      INDArray B,
      int ldb,
      float beta,
      INDArray C,
      int ldc) {
    final INDArray left = Shape.toOffsetZero(A);
    final INDArray right = Shape.toOffsetZero(B);

    final DataBuffer leftBuffer = left.data();
    final DataBuffer rightBuffer = right.data();

    // Output storage: fetched up front, filled by BLAS, then written back into C below.
    final float[] out = getFloatData(C);

    // Arguments are gathered in the same order the original call evaluated them.
    final BLAS blas = BLAS.getInstance();
    final String transLeft = String.valueOf(TransA);
    final String transRight = String.valueOf(TransB);
    final float[] leftValues = leftBuffer.asFloat();
    final int leftOffset = getBlasOffset(left);
    final float[] rightValues = rightBuffer.asFloat();
    final int rightOffset = getBlasOffset(right);
    final int outOffset = getBlasOffset(C);

    blas.sgemm(
        transLeft, transRight,
        M, N, K,
        alpha,
        leftValues, leftOffset, lda,
        rightValues, rightOffset, ldb,
        beta,
        out, outOffset, ldc);

    setData(out, C);
  }