/**
 * Runs the OpenCL k-means loop over {@code images} for {@code noIterations} iterations, then
 * rebuilds the cluster images via {@code constructImageClusters()}.
 *
 * <p>Per iteration: {@code clustersUpdates} is zeroed and uploaded; each full batch of
 * {@code batchItems} images is copied (as short data) into {@code inputImages}, uploaded, and
 * processed by the {@code updateCenters} kernel; the accumulated updates are then downloaded,
 * reduced ({@code reduceCenters()}), adjusted ({@code subtractMeanClusters()}) and the cluster
 * centers are uploaded again for the next iteration.
 *
 * <p>OpenCL resources are released in a {@code finally} block so they are freed even when an
 * iteration throws. If {@code prepareOpenCl()} itself throws, nothing is released here.
 *
 * <p>NOTE(review): only {@code images.size() / batchItems} whole batches are processed; the
 * trailing {@code images.size() % batchItems} images never contribute to the clustering.
 */
public void cluster() {
    prepareOpenCl();
    try {
        for (int iteration = 0; iteration < noIterations; iteration++) {
            // Start each iteration with an empty update accumulator on the device.
            Arrays.fill(clustersUpdates, 0);
            memUpdates.copyHtoD();
            for (int batch = 0; batch < images.size() / batchItems; batch++) {
                int firstImage = batch * batchItems;
                // Pack this batch contiguously into the host-side input buffer.
                for (int i = 0; i < batchItems; i++) {
                    System.arraycopy(
                            images.get(firstImage + i).getDataShort(), 0,
                            inputImages, i * imageSize, imageSize);
                }
                memImages.copyHtoD();
                // 256 is presumably the local work-group size — TODO confirm against Kernel.run.
                updateCenters.run(batchItems, 256);
                program.finish();
            }
            // Progress output (kept as in the original implementation).
            System.out.println(iteration);
            memUpdates.copyDtoH();
            reduceCenters();
            subtractMeanClusters();
            memClusters.copyHtoD();
        }
    } finally {
        // Always free device resources, even if an iteration failed.
        releaseOpenCl();
    }
    constructImageClusters();
}
/**
 * Releases the OpenCL objects created by {@code prepareOpenCl()}.
 *
 * <p>The three memory buffers go first, then the {@code updateCenters} kernel. The program is
 * released last; the buffers and the kernel are all constructed from it in
 * {@code prepareOpenCl()}. The fields themselves are not cleared, so host-side data reachable
 * through them (e.g. {@code memClusters.getSrc()}) stays accessible afterwards.
 */
public void releaseOpenCl() {
    // Device buffers.
    this.memClusters.release();
    this.memImages.release();
    this.memUpdates.release();

    // Kernel, then the program it was built from.
    this.updateCenters.release();
    this.program.release();
}
/**
 * Rebuilds {@code clusterImages} with one {@code dimFilterX x dimFilterY} {@code ImageFloat} per
 * cluster.
 *
 * <p>The data comes from {@code memClusters.getSrc()}. Cluster {@code c} occupies the contiguous
 * slice starting at {@code c * dimFilterX * dimFilterY}. Values are copied unchanged.
 *
 * <p>NOTE(review): despite the name, this method performs no normalization itself. Presumably
 * the cluster data is normalized before this is called — confirm.
 */
public void constructNormalizedImageClusters() {
    clusterImages.clear();
    final int filterSize = dimFilterX * dimFilterY;
    for (int cluster = 0; cluster < noClusters; cluster++) {
        Image clusterImage = new ImageFloat(dimFilterX, dimFilterY);
        int offset = cluster * filterSize;
        System.arraycopy(memClusters.getSrc(), offset, clusterImage.getDataFloat(), 0, filterSize);
        clusterImages.add(clusterImage);
    }
}
/**
 * Compiles {@code /dot/SubImageKmeansDotProductShort.c} with the clustering geometry as
 * compile-time parameters. It then creates the cluster, image and update buffers and the
 * {@code updateCenters} kernel, with those buffers bound as arguments 0, 1 and 2.
 *
 * <p>Image dimensions are read from {@code images.get(0)}; presumably all images share them —
 * TODO confirm. An empty {@code images} list causes {@code images.get(0)} to throw.
 *
 * <p>If any step fails, every object already created here is released before the exception
 * propagates. Failures during that cleanup are attached as suppressed exceptions. The
 * {@code program}, {@code memClusters}, {@code memImages}, {@code memUpdates} and
 * {@code updateCenters} fields are assigned only once everything has been created successfully.
 */
private void prepareOpenCl() {
    int dimImageX = images.get(0).imageX;
    int dimImageY = images.get(0).imageY;
    // Pooling window is twice the filter size, clamped to the image size.
    int dimPoolingX = Math.min(2 * dimFilterX, dimImageX);
    int dimPoolingY = Math.min(2 * dimFilterY, dimImageY);

    Map<String, Object> params = new HashMap<>();
    params.put("IMAGE_SIZE", imageSize);
    params.put("FILTER_SIZE", dimFilterX * dimFilterY);
    params.put("NO_CLUSTERS", noClusters);
    params.put("DIM_FILTER_X", dimFilterX);
    params.put("DIM_FILTER_Y", dimFilterY);
    params.put("DIM_POOLING_X", dimPoolingX);
    params.put("DIM_POOLING_Y", dimPoolingY);
    params.put("DIM_IMAGE_X", dimImageX);
    params.put("DIM_IMAGE_Y", dimImageY);
    params.put("STRIDE_X", strideX);
    params.put("STRIDE_POOLING_X", stridePoolingX);
    params.put("STRIDE_Y", strideY);
    params.put("STRIDE_POOLING_Y", stridePoolingY);

    // Build into locals so a failure part-way through can release exactly what was created,
    // without touching objects left in the fields by an earlier run.
    Program newProgram = null;
    MemoryFloat newClusters = null;
    MemoryShort newImages = null;
    MemoryFloat newUpdates = null;
    Kernel newKernel = null;
    try {
        newProgram = new Program(
                Program.readResource("/dot/SubImageKmeansDotProductShort.c"), params);
        newClusters = new MemoryFloat(newProgram);
        newClusters.addReadWrite(clustersCenters);
        newImages = new MemoryShort(newProgram);
        newImages.addReadOnly(inputImages);
        newUpdates = new MemoryFloat(newProgram);
        newUpdates.addReadWrite(clustersUpdates);
        newKernel = new Kernel(newProgram, "updateCenters");
        newKernel.setArgument(newClusters, 0);
        newKernel.setArgument(newImages, 1);
        newKernel.setArgument(newUpdates, 2);
    } catch (RuntimeException | Error e) {
        try {
            // Same release order as releaseOpenCl(): buffers, kernel, program.
            if (newClusters != null) newClusters.release();
            if (newImages != null) newImages.release();
            if (newUpdates != null) newUpdates.release();
            if (newKernel != null) newKernel.release();
            if (newProgram != null) newProgram.release();
        } catch (RuntimeException cleanupFailure) {
            e.addSuppressed(cleanupFailure);
        }
        throw e;
    }

    program = newProgram;
    memClusters = newClusters;
    memImages = newImages;
    memUpdates = newUpdates;
    updateCenters = newKernel;
}