@Override public DataBuffer createShapeInformation( int[] shape, int[] stride, int offset, int elementWiseStride, char order) { offset = 0; Integer deviceId = AtomicAllocator.getInstance().getDeviceId(); ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order); if (!protector.containsDataBuffer(deviceId, descriptor)) { // logger.info("Cache miss"); DataBuffer buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order); buffer.setConstant(true); if (configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) { Nd4j.getConstantHandler().moveToConstantSpace(buffer); } // deviceCache.get(deviceId).put(descriptor, buffer); protector.persistDataBuffer(deviceId, descriptor, buffer); cacheMiss.incrementAndGet(); return buffer; } else { // logger.info("Cache hit"); cacheHit.incrementAndGet(); } return protector.getDataBuffer( deviceId, descriptor); // deviceCache.get(deviceId).get(descriptor); }
/** @author [email protected] */ public class ProtectedCudaShapeInfoProvider extends BaseShapeInfoProvider { private AtomicAllocator allocator; private AtomicLong cacheHit = new AtomicLong(1); private AtomicLong cacheMiss = new AtomicLong(1); private Semaphore lock = new Semaphore(1); private Configuration configuration = CudaEnvironment.getInstance().getConfiguration(); private static final ConstantProtector protector = ConstantProtector.getInstance(); private static ProtectedCudaShapeInfoProvider ourInstance = new ProtectedCudaShapeInfoProvider(); private ProtectedCudaShapeInfoProvider() {} public static ProtectedCudaShapeInfoProvider getInstance() { return ourInstance; } @Override public DataBuffer createShapeInformation( int[] shape, int[] stride, int offset, int elementWiseStride, char order) { offset = 0; Integer deviceId = AtomicAllocator.getInstance().getDeviceId(); ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order); if (!protector.containsDataBuffer(deviceId, descriptor)) { // logger.info("Cache miss"); DataBuffer buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order); buffer.setConstant(true); if (configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) { Nd4j.getConstantHandler().moveToConstantSpace(buffer); } // deviceCache.get(deviceId).put(descriptor, buffer); protector.persistDataBuffer(deviceId, descriptor, buffer); cacheMiss.incrementAndGet(); return buffer; } else { // logger.info("Cache hit"); cacheHit.incrementAndGet(); } return protector.getDataBuffer( deviceId, descriptor); // deviceCache.get(deviceId).get(descriptor); } }