/** @author [email protected] */ public class ProtectedCudaShapeInfoProvider extends BaseShapeInfoProvider { private AtomicAllocator allocator; private AtomicLong cacheHit = new AtomicLong(1); private AtomicLong cacheMiss = new AtomicLong(1); private Semaphore lock = new Semaphore(1); private Configuration configuration = CudaEnvironment.getInstance().getConfiguration(); private static final ConstantProtector protector = ConstantProtector.getInstance(); private static ProtectedCudaShapeInfoProvider ourInstance = new ProtectedCudaShapeInfoProvider(); private ProtectedCudaShapeInfoProvider() {} public static ProtectedCudaShapeInfoProvider getInstance() { return ourInstance; } @Override public DataBuffer createShapeInformation( int[] shape, int[] stride, int offset, int elementWiseStride, char order) { offset = 0; Integer deviceId = AtomicAllocator.getInstance().getDeviceId(); ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order); if (!protector.containsDataBuffer(deviceId, descriptor)) { // logger.info("Cache miss"); DataBuffer buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order); buffer.setConstant(true); if (configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) { Nd4j.getConstantHandler().moveToConstantSpace(buffer); } // deviceCache.get(deviceId).put(descriptor, buffer); protector.persistDataBuffer(deviceId, descriptor, buffer); cacheMiss.incrementAndGet(); return buffer; } else { // logger.info("Cache hit"); cacheHit.incrementAndGet(); } return protector.getDataBuffer( deviceId, descriptor); // deviceCache.get(deviceId).get(descriptor); } }
@Before
public void setUp() throws Exception {
    // Nothing to do when an environment has already been assigned.
    if (environment != null) {
        return;
    }

    environment = new CudaEnvironment(configuration);

    // Build both simulated devices first, then register them in id order.
    DeviceInformation firstDevice = buildSimulatedDeviceInfo(0);
    DeviceInformation secondDevice = buildSimulatedDeviceInfo(1);
    environment.addDevice(firstDevice);
    environment.addDevice(secondDevice);
}

/**
 * Creates a simulated device description: compute capability 5.2, with 4 GiB of total and
 * available memory.
 *
 * @param deviceId id to assign to the device
 * @return the populated device description
 */
private static DeviceInformation buildSimulatedDeviceInfo(int deviceId) {
    final long fourGiB = 4L * 1024 * 1024 * 1024;
    DeviceInformation info = new DeviceInformation();
    info.setDeviceId(deviceId);
    info.setCcMajor(5);
    info.setCcMinor(2);
    info.setTotalMemory(fourGiB);
    info.setAvailableMemory(fourGiB);
    return info;
}