protected void init_spaces() { userspace = new float[maxuid + 1][]; itemspace = new float[maxiid + 1][]; float frag = (float) Math.sqrt((avgrating * 0.1f / factor)); // pu * qi is about 10% effect of r^ui for (int u = 0; u < userspace.length; u++) { userspace[u] = new float[factor]; for (int f = 0; f < factor; f++) { userspace[u][f] = (float) (frag * (Utilities.randomDouble())); } } System.out.println(itemspace.length); for (int i = 0; i < itemspace.length; i++) { itemspace[i] = new float[factor]; for (int f = 0; f < factor; f++) { itemspace[i][f] = (float) (frag * (Utilities.randomDouble())); } } userbias = new float[maxuid + 1]; itembias = new float[maxiid + 1]; for (int u = 0; u < userbias.length; u++) { userbias[u] = (float) Utilities.randomDouble(-0.5, 0.5); // } for (int i = 0; i < itembias.length; i++) { itembias[i] = (float) Utilities.randomDouble(-1.5, 1.5); } }
private void _train() throws Exception { System.out.println("--------------------"); float learningSpeed = this.alpha; for (int loop = 0; loop < this.loops; loop++) { dataEntry.reOpen(); double totalError = 0; int n = 0; long timeStart = System.currentTimeMillis(); // core computation for (Vector v = dataEntry.getNextVector(); v != null; v = dataEntry.getNextVector()) { UserRatings ur = new UserRatings(v); for (RatingInfo ri = ur.getNormalNextRating(); ri != null; ri = ur.getNormalNextRating()) { float eui = ri.rating - (this.avgrating + userbias[ri.userId] + itembias[ri.itemId] + Utilities.innerProduct(userspace[ri.userId], itemspace[ri.itemId])); // perform gradient on user/item bias userbias[ri.userId] += learningSpeed * (eui - this.lambda * userbias[ri.userId]); itembias[ri.itemId] += learningSpeed * (eui - this.lambda * itembias[ri.itemId]); // perform gradient on pu/qi for (int f = 0; f < this.factor; f++) { userspace[ri.userId][f] = userspace[ri.userId][f] + learningSpeed * (eui * itemspace[ri.itemId][f] - this.lambda * userspace[ri.userId][f]); itemspace[ri.itemId][f] = itemspace[ri.itemId][f] + learningSpeed * (eui * userspace[ri.userId][f] - this.lambda * itemspace[ri.itemId][f]); } totalError += Math.abs(eui); n += 1; } } long timeSpent = System.currentTimeMillis() - timeStart; learningSpeed *= this.convergence; System.out.println( String.format( "loop:%d\ttime(ms):%d\tavgerror:%.6f\tnext alpha:%.5f", loop, timeSpent, (totalError / n), learningSpeed)); // System.out.print("loop " + loop + " finished~ Time spent: " + (timeSpent / 1000.0) + " // next speed :" + learningSpeed); // System.out.println(" total training ratings = " + n); } dataEntry.close(); }
/**
 * Computes the model's rating estimate for a user/item pair as the global
 * average rating plus the user bias, the item bias and the inner product of
 * the user's and the item's latent vectors.
 *
 * @param userId index into {@code userbias} and {@code userspace}
 * @param itemId index into {@code itembias} and {@code itemspace}
 * @return the predicted rating
 */
protected float predict(int userId, int itemId) {
    // Baseline part: global mean shifted by the two bias terms.
    final float baseline = this.avgrating + userbias[userId] + itembias[itemId];
    // Interaction part: dot product of the two latent vectors.
    final float interaction = Utilities.innerProduct(userspace[userId], itemspace[itemId]);
    return baseline + interaction;
}