예제 #1
0
  /**
   * Copies the region registered for the current thread under {@code (offset, length)} from the
   * device pointer to the host pointer.
   *
   * <p>The device pointer is looked up in {@code pointersToContexts} by the current thread's name
   * and the {@code (offset, length)} pair. The stream is synchronized once before the asynchronous
   * copy is issued and once after it, so the method returns only after that second synchronization.
   *
   * @param offset offset of the region, in elements (it is multiplied by the element size to form
   *     the byte offset applied to the device pointer)
   * @param length length the region was registered with; also used as the number of elements to
   *     copy when the region starts at offset zero and is shorter than the whole buffer
   * @throws IllegalStateException if no pointer is registered for this thread and
   *     {@code (offset, length)}, or if the registered pointer reports a different offset
   */
  @Override
  public void copyToHost(int offset, int length) {
    DevicePointerInfo devicePointerInfo =
        pointersToContexts.get(Thread.currentThread().getName(), new Pair<>(offset, length));
    if (devicePointerInfo == null) {
      throw new IllegalStateException("No pointer found for offset " + offset);
    }
    // prevent inconsistent pointers
    if (devicePointerInfo.getOffset() != offset) {
      throw new IllegalStateException(
          "Device pointer offset didn't match specified offset in pointer map");
    }

    ContextHolder.syncStream();
    int deviceStride = devicePointerInfo.getStride();
    int deviceOffset = devicePointerInfo.getOffset();
    long deviceLength = devicePointerInfo.getLength();

    // Edge case: the offset is zero and the requested length is smaller than the buffer itself.
    // Then the caller-supplied length is the element count; otherwise the length recorded on the
    // device pointer is used.
    // NOTE(review): the narrowing cast silently truncates a device length above
    // Integer.MAX_VALUE; kept as in the original because the cuBLAS call takes an int count.
    int elementCount = (deviceOffset == 0 && length < length()) ? length : (int) deviceLength;

    JCublas2.cublasGetVectorAsync(
        elementCount,
        getElementSize(),
        devicePointerInfo.getPointer().withByteOffset(offset * getElementSize()),
        deviceStride,
        getHostPointer(deviceOffset),
        deviceStride,
        ContextHolder.getInstance().getCudaStream());

    ContextHolder.syncStream();
  }
예제 #2
0
  /**
   * Issues an asynchronous cuBLAS vector copy of {@code length} elements from {@code from} to the
   * pointer returned by {@code getHostPointer()}, offset by {@code index} elements, and then
   * synchronizes the stream.
   *
   * <p>The {@code modified} flag is set before the index is validated, so it remains set even when
   * this method throws.
   *
   * @param index the element index at which the destination starts; must be in
   *     {@code [0, length())}
   * @param length the number of elements to copy
   * @param from the pointer to read the data from
   * @param inc the stride between consecutive elements in {@code from}
   * @throws IllegalArgumentException if {@code index} is negative or not smaller than
   *     {@code length()}
   */
  protected void set(int index, int length, Pointer from, int inc) {

    modified.set(true);

    int offset = getElementSize() * index;
    // A negative index yields a negative byte offset, which would slip past the upper-bound
    // comparison and reach the native copy, so reject it explicitly.
    if (index < 0 || offset >= length() * getElementSize()) {
      throw new IllegalArgumentException(
          "Illegal offset " + offset + " with index of " + index + " and length " + length());
    }

    // NOTE(review): only the start index is validated; index + length is not checked against
    // length() here — confirm callers guarantee the copied range fits.
    JCublas2.cublasSetVectorAsync(
        length,
        getElementSize(),
        from,
        inc,
        getHostPointer().withByteOffset(offset),
        1,
        ContextHolder.getInstance().getCudaStream());

    ContextHolder.syncStream();
  }