/**
 * Matrix factorization for recommender systems.
 * <p>
 * Ratings are read from a comma-separated text file ({@code user,item,rating}
 * per line) and approximated by the product of two low-rank matrices: one row
 * per user in {@link #userSubspace} and one row per item in
 * {@link #itemSubspace}.
 *
 * @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
 */
public class MatrixFactorization {
    /**
     * Used to generate random numbers.
     */
    Random rand = new Random();

    /**
     * Number of users.
     */
    int numUsers;

    /**
     * Number of items.
     */
    int numItems;

    /**
     * Number of ratings.
     */
    int numRatings;

    /**
     * Training data.
     */
    Triple[] dataset;

    /**
     * Step size applied to every gradient update in {@link #updateNoRegular()}.
     * NOTE(review): the earlier comment on this field called it the
     * regularization parameter; confirm which role alpha is meant to play.
     */
    double alpha;

    /**
     * Second tuning parameter. It is stored by {@link #setParameters} but is
     * not read by any method of this class.
     * NOTE(review): the earlier comment called it the learning speed;
     * presumably it is a regularization weight for an update rule that is not
     * part of this class - confirm.
     */
    double lambda;

    /**
     * The low rank of the small matrices.
     */
    int rank;

    /**
     * The user matrix U (numUsers x rank).
     */
    double[][] userSubspace;

    /**
     * The item matrix V (numItems x rank).
     */
    double[][] itemSubspace;

    /**
     * The lower bound of the rating value.
     */
    double ratingLowerBound;

    /**
     * The upper bound of the rating value.
     */
    double ratingUpperBound;

    /**
     * ***********************
     * The first constructor. Stores the sizes and rating bounds, then loads
     * the ratings from the file. If loading fails, a message is printed and
     * the JVM exits with a non-zero status.
     *
     * @param paraFilename         The data filename.
     * @param paraNumUsers         The number of users.
     * @param paraNumItems         The number of items.
     * @param paraNumRatings       The number of ratings.
     * @param paraRatingLowerBound The smallest valid rating value.
     * @param paraRatingUpperBound The largest valid rating value.
     * ***********************
     */
    public MatrixFactorization(String paraFilename, int paraNumUsers, int paraNumItems,
            int paraNumRatings, double paraRatingLowerBound, double paraRatingUpperBound) {
        numUsers = paraNumUsers;
        numItems = paraNumItems;
        numRatings = paraNumRatings;
        ratingLowerBound = paraRatingLowerBound;
        ratingUpperBound = paraRatingUpperBound;

        try {
            readData(paraFilename, paraNumUsers, paraNumItems, paraNumRatings);
        } catch (Exception ee) {
            System.out.println("File " + paraFilename + " cannot be read! " + ee);
            // A non-zero status tells the calling shell that the run failed.
            System.exit(1);
        } // Of try
    }// Of the first constructor

    /**
     * ***********************
     * Set parameters.
     *
     * @param paraRank   The given rank.
     * @param paraAlpha  The value stored in {@link #alpha}.
     * @param paraLambda The value stored in {@link #lambda}.
     * ***********************
     */
    public void setParameters(int paraRank, double paraAlpha, double paraLambda) {
        rank = paraRank;
        alpha = paraAlpha;
        lambda = paraLambda;
    }// Of setParameters

    /**
     * ***********************
     * Read the data from the file. Exactly {@code paraNumRatings} lines are
     * consumed; each line must look like {@code user,item,rating}.
     *
     * @param paraFilename   The given file.
     * @param paraNumUsers   The number of users (not used while reading).
     * @param paraNumItems   The number of items (not used while reading).
     * @param paraNumRatings The number of lines to read.
     * @throws IOException If the file cannot be read or holds fewer than
     *                     {@code paraNumRatings} lines.
     * ***********************
     */
    public void readData(String paraFilename, int paraNumUsers, int paraNumItems,
            int paraNumRatings) throws IOException {
        File tempFile = new File(paraFilename);
        if (!tempFile.exists()) {
            System.out.println("File " + paraFilename + " does not exists.");
            System.exit(1);
        } // Of if

        // Allocate space.
        Triple[] tempDataset = new Triple[paraNumRatings];

        // try-with-resources closes the reader even when a line fails to parse.
        try (BufferedReader tempBufferReader = new BufferedReader(new FileReader(tempFile))) {
            for (int i = 0; i < paraNumRatings; i++) {
                String tempString = tempBufferReader.readLine();
                if (tempString == null) {
                    // readLine() returns null at end of file; report it
                    // instead of failing with a NullPointerException.
                    throw new IOException("File " + paraFilename + " has only " + i
                            + " lines, but " + paraNumRatings + " ratings were expected.");
                } // Of if

                String[] tempStringArray = tempString.split(",");
                tempDataset[i] = new Triple(Integer.parseInt(tempStringArray[0]),
                        Integer.parseInt(tempStringArray[1]),
                        Double.parseDouble(tempStringArray[2]));
            } // Of for i
        } // Of try

        // Publish the data only after the whole file was read successfully.
        dataset = tempDataset;
    }// Of readData

    /**
     * ***********************
     * Initialize subspaces. Each value is in [0, 1).
     * The user matrix is filled first, then the item matrix, row by row.
     * ***********************
     */
    void initializeSubspaces() {
        userSubspace = randomMatrix(numUsers, rank);
        itemSubspace = randomMatrix(numItems, rank);
    }// Of initializeSubspaces

    /**
     * Build a matrix whose entries are drawn from {@link #rand} in row-major
     * order.
     *
     * @param paraRows    The number of rows.
     * @param paraColumns The number of columns.
     * @return The freshly filled matrix.
     */
    private double[][] randomMatrix(int paraRows, int paraColumns) {
        double[][] resultMatrix = new double[paraRows][paraColumns];
        for (int i = 0; i < paraRows; i++) {
            for (int j = 0; j < paraColumns; j++) {
                resultMatrix[i][j] = rand.nextDouble();
            } // Of for j
        } // Of for i
        return resultMatrix;
    }// Of randomMatrix

    /**
     * ***********************
     * Predict the rating of the user to the item as the dot product of the
     * user row and the item row. The result is not clipped to the rating
     * bounds.
     *
     * @param paraUser The user index.
     * @param paraItem The item index.
     * @return The predicted rating.
     * ***********************
     */
    public double predict(int paraUser, int paraItem) {
        double resultValue = 0;
        for (int i = 0; i < rank; i++) {
            // The row vector of a user and the column vector of an item
            resultValue += userSubspace[paraUser][i] * itemSubspace[paraItem][i];
        } // Of for i
        return resultValue;
    }// Of predict

    /**
     * ***********************
     * Train. The subspaces are re-initialized, then updated once per round.
     * MAE and RSME on the training data are printed every 500 rounds.
     *
     * @param paraRounds The number of rounds.
     * ***********************
     */
    public void train(int paraRounds) {
        initializeSubspaces();

        for (int i = 0; i < paraRounds; i++) {
            updateNoRegular();
            if (i % 500 == 0) {
                // Show the process
                System.out.println("Round " + i);
                System.out.println("MAE: " + mae());
                System.out.println("RSME: " + rsme());
            } // Of if
        } // Of for i
    }// Of train

    /**
     * ***********************
     * Update sub-spaces using the training data, without regularization.
     * Every rating triggers one gradient step on its squared error.
     * <p>
     * NOTE(review): the body of this method was incomplete in the reviewed
     * text; the update below is the plain squared-error gradient step with
     * step size alpha - verify against the intended algorithm.
     * ***********************
     */
    public void updateNoRegular() {
        for (int i = 0; i < numRatings; i++) {
            int tempUserId = dataset[i].user;
            int tempItemId = dataset[i].item;
            double tempRate = dataset[i].rating;

            double tempResidual = tempRate - predict(tempUserId, tempItemId);

            for (int j = 0; j < rank; j++) {
                // Read both factors first so each update uses the values
                // that produced the residual.
                double tempUserValue = userSubspace[tempUserId][j];
                double tempItemValue = itemSubspace[tempItemId][j];

                userSubspace[tempUserId][j] = tempUserValue + alpha * 2 * tempResidual * tempItemValue;
                itemSubspace[tempItemId][j] = tempItemValue + alpha * 2 * tempResidual * tempUserValue;
            } // Of for j
        } // Of for i
    }// Of updateNoRegular

    /**
     * Predict a rating and clip it into
     * [{@link #ratingLowerBound}, {@link #ratingUpperBound}].
     *
     * @param paraUser The user index.
     * @param paraItem The item index.
     * @return The clipped prediction.
     */
    private double boundedPrediction(int paraUser, int paraItem) {
        double tempPrediction = predict(paraUser, paraItem);
        if (tempPrediction < ratingLowerBound) {
            tempPrediction = ratingLowerBound;
        } // Of if
        if (tempPrediction > ratingUpperBound) {
            tempPrediction = ratingUpperBound;
        } // Of if
        return tempPrediction;
    }// Of boundedPrediction

    /**
     * ***********************
     * Compute the RSME (root of the mean squared error) over the training
     * data, using predictions clipped to the rating bounds.
     *
     * @return RSME of the current factorization.
     * ***********************
     */
    public double rsme() {
        double resultRsme = 0;
        int tempTestCount = 0;

        for (int i = 0; i < numRatings; i++) {
            int tempUserIndex = dataset[i].user;
            int tempItemIndex = dataset[i].item;
            double tempRate = dataset[i].rating;

            double tempPrediction = boundedPrediction(tempUserIndex, tempItemIndex);

            double tempError = tempRate - tempPrediction;
            resultRsme += tempError * tempError;
            tempTestCount++;
        } // Of for i

        return Math.sqrt(resultRsme / tempTestCount);
    }// Of rsme

    /**
     * ***********************
     * Compute the MAE (mean absolute error) over the training data, using
     * predictions clipped to the rating bounds.
     *
     * @return MAE of the current factorization.
     * ***********************
     */
    public double mae() {
        double resultMae = 0;
        int tempTestCount = 0;

        for (int i = 0; i < numRatings; i++) {
            int tempUserIndex = dataset[i].user;
            int tempItemIndex = dataset[i].item;
            double tempRate = dataset[i].rating;

            double tempPrediction = boundedPrediction(tempUserIndex, tempItemIndex);

            double tempError = tempRate - tempPrediction;

            resultMae += Math.abs(tempError);
            tempTestCount++;
        } // Of for i

        return (resultMae / tempTestCount);
    }// Of mae

    /**
     * ***********************
     * Test accuracy: load the data, train, and print the final MAE and RSME.
     * Training and evaluation use the same ratings.
     *
     * @param paraFilename         The data filename.
     * @param paraNumUsers         The number of users.
     * @param paraNumItems         The number of items.
     * @param paraNumRatings       The number of ratings.
     * @param paraRatingLowerBound The smallest valid rating value.
     * @param paraRatingUpperBound The largest valid rating value.
     * @param paraRounds           The number of training rounds.
     * ***********************
     */
    public static void testTrainingTesting(String paraFilename, int paraNumUsers,
            int paraNumItems, int paraNumRatings, double paraRatingLowerBound,
            double paraRatingUpperBound, int paraRounds) {
        try {
            // Step 1. read the data
            MatrixFactorization tempMF = new MatrixFactorization(paraFilename, paraNumUsers,
                    paraNumItems, paraNumRatings, paraRatingLowerBound, paraRatingUpperBound);

            // Step 2. set the rank and the two tuning parameters
            tempMF.setParameters(5, 0.0001, 0.005);

            // Step 3. update and predict
            System.out.println("Begin Training ! ! !");
            tempMF.train(paraRounds);

            double tempMAE = tempMF.mae();
            double tempRSME = tempMF.rsme();
            System.out.println("Finally, MAE = " + tempMAE + ", RSME = " + tempRSME);
        } catch (Exception e) {
            e.printStackTrace();
        } // Of try
    }// Of testTrainingTesting

    /**
     * ************************
     * Test this class.
     *
     * @param args Optional: args[0] is the path of the rating file. Without
     *             arguments the built-in sample path is used.
     * ************************
     */
    public static void main(String[] args) {
        String tempFilename = "D:/Work/sampledata/movielens-943u1682m.txt";
        if (args != null && args.length > 0) {
            tempFilename = args[0];
        } // Of if
        testTrainingTesting(tempFilename, 943, 1682, 10000, 1, 5, 2000);
    }// Of main

    /**
     * One rating record: which user gave which item which rating.
     * NOTE(review): the declaration line and fields of this class were absent
     * from the reviewed text and are reconstructed from the constructor and
     * from how the dataset is read; it is kept as a (non-static) inner class -
     * confirm against the original declaration.
     */
    public class Triple {
        /**
         * The user index.
         */
        public int user;

        /**
         * The item index.
         */
        public int item;

        /**
         * The rating value.
         */
        public double rating;

        /**
         * ********************
         * The constructor.
         *
         * @param paraUser   The user index.
         * @param paraItem   The item index.
         * @param paraRating The rating value.
         * ********************
         */
        public Triple(int paraUser, int paraItem, double paraRating) {
            user = paraUser;
            item = paraItem;
            rating = paraRating;
        }// Of the first constructor

        /**
         * ********************
         * Show me.
         *
         * @return The record as "user, item, rating".
         * ********************
         */
        @Override
        public String toString() {
            return "" + user + ", " + item + ", " + rating;
        }// Of toString
    }// Of class Triple
} // Of class MatrixFactorization