/**
 * Shape-information provider that caches shape-info {@code DataBuffer}s per device.
 *
 * <p>Buffers are stored in the process-wide {@link ConstantProtector}, keyed by device id and a
 * {@code ShapeDescriptor}. A freshly built buffer is flagged constant. When the CUDA environment
 * uses the {@code IMMEDIATE} memory model, the buffer is also moved to constant space before it
 * is cached.
 *
 * <p>Obtain the single instance through {@link #getInstance()}.
 *
 * @author [email protected]
 */
public class ProtectedCudaShapeInfoProvider extends BaseShapeInfoProvider {

  // Hit/miss counters start at 1 (kept from the original). They are only incremented in this
  // class; nothing here reads them.
  private final AtomicLong cacheHit = new AtomicLong(1);
  private final AtomicLong cacheMiss = new AtomicLong(1);

  private final Configuration configuration = CudaEnvironment.getInstance().getConfiguration();

  private static final ConstantProtector protector = ConstantProtector.getInstance();

  private static final ProtectedCudaShapeInfoProvider ourInstance =
      new ProtectedCudaShapeInfoProvider();

  private ProtectedCudaShapeInfoProvider() {}

  /**
   * Returns the singleton provider.
   *
   * @return the shared {@code ProtectedCudaShapeInfoProvider} instance
   */
  public static ProtectedCudaShapeInfoProvider getInstance() {
    return ourInstance;
  }

  /**
   * Returns a shape-information buffer for the current device, building and caching it on the
   * first request for a given descriptor.
   *
   * <p>The caller-supplied {@code offset} is ignored. Both the cache key and the built buffer
   * always use an offset of 0.
   *
   * <p>NOTE(review): the contains-check and the later persist are not guarded by any lock in this
   * method. Two threads missing on the same descriptor may therefore each build a buffer. Which
   * buffer stays cached depends on {@code ConstantProtector.persistDataBuffer}; confirm whether it
   * overwrites or keeps the first entry.
   *
   * @param shape the array shape
   * @param stride the array strides
   * @param offset ignored; the cached shape info always uses offset 0
   * @param elementWiseStride the element-wise stride
   * @param order the memory order character
   * @return the cached or newly created shape-information buffer for the current device
   */
  @Override
  public DataBuffer createShapeInformation(
      int[] shape, int[] stride, int offset, int elementWiseStride, char order) {
    // Shape info is cached independently of the view offset, so always key and build with 0.
    final int cachedOffset = 0;

    final Integer deviceId = AtomicAllocator.getInstance().getDeviceId();

    final ShapeDescriptor descriptor =
        new ShapeDescriptor(shape, stride, cachedOffset, elementWiseStride, order);

    if (protector.containsDataBuffer(deviceId, descriptor)) {
      cacheHit.incrementAndGet();
      return protector.getDataBuffer(deviceId, descriptor);
    }

    final DataBuffer buffer =
        super.createShapeInformation(shape, stride, cachedOffset, elementWiseStride, order);
    buffer.setConstant(true);

    // In the IMMEDIATE memory model the buffer is copied into constant space right away.
    if (configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) {
      Nd4j.getConstantHandler().moveToConstantSpace(buffer);
    }

    protector.persistDataBuffer(deviceId, descriptor, buffer);
    cacheMiss.incrementAndGet();
    return buffer;
  }
}
  /**
   * Lazily creates the test {@code CudaEnvironment} and registers two fake devices with it.
   *
   * <p>Runs only while the {@code environment} field is still {@code null}. The devices have ids
   * 0 and 1, compute capability 5.2, and 4 GiB of total and available memory each.
   *
   * @throws Exception declared for JUnit; nothing in this body throws a checked exception itself
   */
  @Before
  public void setUp() throws Exception {
    if (environment == null) {
      environment = new CudaEnvironment(configuration);

      // Put the long literal on the LEFTMOST operand so every intermediate product is 64-bit.
      // With the long only at the end, the earlier products are evaluated as int.
      final long fourGiB = 4L * 1024 * 1024 * 1024;

      // Build every device first, then register them, preserving the original ordering.
      final int deviceCount = 2;
      final DeviceInformation[] devices = new DeviceInformation[deviceCount];
      for (int id = 0; id < deviceCount; id++) {
        DeviceInformation device = new DeviceInformation();
        device.setDeviceId(id);
        device.setCcMajor(5);
        device.setCcMinor(2);
        device.setTotalMemory(fourGiB);
        device.setAvailableMemory(fourGiB);
        devices[id] = device;
      }

      for (DeviceInformation device : devices) {
        environment.addDevice(device);
      }
    }
  }