21.3. Matrix Factorization
Matrix factorization (Koren et al., 2009) is a well-established algorithm in the recommender systems literature. The first version of the matrix factorization model was proposed by Simon Funk in a famous blog post in which he described the idea of factorizing the interaction matrix. It then became widely known through the Netflix contest, which was launched in 2006. At that time, Netflix, a media-streaming and video-rental company, announced a contest to improve the performance of its recommender system. The best team that could improve on the Netflix baseline (i.e., Cinematch) by 10 percent would win a prize of one million USD. As such, the contest attracted a lot of attention to the field of recommender system research. The grand prize was eventually won by the BellKor’s Pragmatic Chaos team, a combined team of BellKor, Pragmatic Theory, and BigChaos (you do not need to worry about these algorithms now). Although the final score was the result of an ensemble solution (i.e., a combination of many algorithms), the matrix factorization algorithm played a critical role in the final blend. The technical report of the Netflix Grand Prize solution (Töscher et al., 2009) provides a detailed introduction to the adopted model. In this section, we will dive into the details of the matrix factorization model and its implementation.
21.3.1. The Matrix Factorization Model
Matrix factorization is a class of collaborative filtering models. Specifically, the model factorizes the user-item interaction matrix (e.g., rating matrix) into the product of two lower-rank matrices, capturing the low-rank structure of the user-item interactions.
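Let $\mathbf{R} \in \mathbb{R}^{m \times n}$ denote the interaction matrix with $m$ users and $n$ items, whose values represent explicit ratings. The user-item interaction is factorized into a user latent matrix $\mathbf{P} \in \mathbb{R}^{m \times k}$ and an item latent matrix $\mathbf{Q} \in \mathbb{R}^{n \times k}$, where $k \ll m, n$ is the latent factor size. Let $\mathbf{p}_u$ denote the $u$-th row of $\mathbf{P}$ and $\mathbf{q}_i$ the $i$-th row of $\mathbf{Q}$. For a given item $i$, the elements of $\mathbf{q}_i$ measure the extent to which the item possesses certain characteristics, such as the genre or language of a movie. For a given user $u$, the elements of $\mathbf{p}_u$ measure how interested the user is in the corresponding characteristics. These latent factors may correspond to obvious dimensions like those in the examples, or they may be completely uninterpretable. The predicted ratings can be estimated by

$$\hat{\mathbf{R}} = \mathbf{P}\mathbf{Q}^\top$$

where $\hat{\mathbf{R}} \in \mathbb{R}^{m \times n}$ is the predicted rating matrix, which has the same shape as $\mathbf{R}$. One major problem with this prediction rule is that user and item biases cannot be modeled. For example, some users tend to give higher ratings, and some items consistently receive lower ratings. To capture these biases, user-specific and item-specific bias terms are introduced. Specifically, the predicted rating that user $u$ gives to item $i$ is calculated by

$$\hat{\mathbf{R}}_{ui} = \mathbf{p}_u\mathbf{q}_i^\top + b_u + b_i$$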
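As a small illustration of the factorization idea, the following framework-free sketch builds a dense predicted rating matrix as the product of a user factor matrix and the transpose of an item factor matrix, without bias terms. The toy sizes (3 users, 4 items, 2 latent factors), the numeric values, and the helper name `matmul_transpose` are assumptions made for this example only.

```python
# Toy factorization: 3 users, 4 items, 2 latent factors (values are made up).
P = [[1.0, 0.5],
     [0.0, 2.0],
     [1.5, 1.0]]   # user latent matrix, one row per user
Q = [[1.0, 0.0],
     [0.5, 1.0],
     [2.0, 0.5],
     [0.0, 1.0]]   # item latent matrix, one row per item


def matmul_transpose(P, Q):
    """Return P @ Q^T: entry (u, i) is the dot product of row u of P and row i of Q."""
    return [[sum(p * q for p, q in zip(p_u, q_i)) for q_i in Q] for p_u in P]


# Predicted rating matrix with one row per user and one column per item.
R_hat = matmul_transpose(P, Q)
```

Every entry of `R_hat` is determined by only two short vectors, one for the user and one for the item, which is what gives the predicted matrix its low-rank structure.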
Then, we train the matrix factorization model by minimizing the mean squared error between predicted rating scores and real rating scores. The objective function is defined as follows:
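$$\underset{\mathbf{P}, \mathbf{Q}, b}{\mathrm{argmin}} \sum_{(u, i) \in \mathcal{K}} \| \mathbf{R}_{ui} - \hat{\mathbf{R}}_{ui} \|^2 + \lambda \left(\| \mathbf{P} \|^2_F + \| \mathbf{Q} \|^2_F + b_u^2 + b_i^2 \right)$$

where $\mathbf{R}_{ui}$ is the observed rating of user $u$ for item $i$, $\hat{\mathbf{R}}_{ui}$ is the corresponding predicted rating, $\mathbf{P}$ and $\mathbf{Q}$ are the user and item latent factor matrices, $b_u$ and $b_i$ are the user and item biases, and $\lambda$ denotes the regularization rate. The regularization term $\lambda (\| \mathbf{P} \|^2_F + \| \mathbf{Q} \|^2_F + b_u^2 + b_i^2)$ helps avoid overfitting by penalizing the magnitude of the parameters. The $(u, i)$ pairs for which $\mathbf{R}_{ui}$ is known are stored in the set $\mathcal{K}=\{(u, i) \mid \mathbf{R}_{ui} \textrm{ is known}\}$. The model parameters can be learned with an optimization algorithm such as stochastic gradient descent or Adam.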
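To show how such a regularized squared-error objective can be minimized, here is a minimal sketch of plain stochastic gradient descent on a single observed rating, written in framework-free Python. The prediction is assumed to be the dot product of the user and item factors plus a user bias and an item bias. The helper names `sgd_step` and `squared_error`, the learning rate, the regularization rate, and all numeric values are assumptions for illustration; the training code in this section relies on Gluon's optimizers rather than hand-written updates.

```python
def predict(p_u, q_i, b_u, b_i):
    """Predicted rating: dot(p_u, q_i) + b_u + b_i."""
    return sum(p * q for p, q in zip(p_u, q_i)) + b_u + b_i


def squared_error(p_u, q_i, b_u, b_i, r_ui):
    return (r_ui - predict(p_u, q_i, b_u, b_i)) ** 2


def sgd_step(p_u, q_i, b_u, b_i, r_ui, lr=0.05, lam=0.01):
    """One SGD step on (r - r_hat)^2 + lam * (|p|^2 + |q|^2 + b_u^2 + b_i^2)."""
    err = r_ui - predict(p_u, q_i, b_u, b_i)
    # Gradients of the regularized squared error w.r.t. each parameter.
    grad_p = [-2 * err * q + 2 * lam * p for p, q in zip(p_u, q_i)]
    grad_q = [-2 * err * p + 2 * lam * q for p, q in zip(p_u, q_i)]
    grad_bu = -2 * err + 2 * lam * b_u
    grad_bi = -2 * err + 2 * lam * b_i
    new_p = [p - lr * g for p, g in zip(p_u, grad_p)]
    new_q = [q - lr * g for q, g in zip(q_i, grad_q)]
    return new_p, new_q, b_u - lr * grad_bu, b_i - lr * grad_bi


# Toy example: repeatedly fit one observed rating of 4.0 (values are made up).
params = ([0.1, 0.2], [0.3, -0.1], 0.0, 0.0)
before = squared_error(*params, 4.0)
for _ in range(50):
    params = sgd_step(*params, 4.0)
after = squared_error(*params, 4.0)
```

Because of the regularization term, the error is driven close to zero but not exactly to zero: the penalty keeps pulling the parameters slightly toward the origin.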
An intuitive illustration of the matrix factorization model is shown below:
Fig. 21.3.1 Illustration of the matrix factorization model
In the rest of this section, we will explain the implementation of matrix factorization and train the model on the MovieLens dataset.
21.3.2. Model Implementation
First, we implement the matrix factorization model described above. The user and item latent factors can be created with `nn.Embedding`. The `input_dim` is the number of users or items, and the `output_dim` is the dimension of the latent factors. We can also use `nn.Embedding` to create the user and item biases by setting the `output_dim` to one. In the `forward` function, user and item ids are used to look up the embeddings.
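```python
import mxnet as mx
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()  # Use the NumPy-compatible MXNet interface


class MF(nn.Block):
    def __init__(self, num_factors, num_users, num_items, **kwargs):
        super(MF, self).__init__(**kwargs)
        # Latent factor tables: one row of size num_factors per user/item
        self.P = nn.Embedding(input_dim=num_users, output_dim=num_factors)
        self.Q = nn.Embedding(input_dim=num_items, output_dim=num_factors)
        # One scalar bias per user and per item
        self.user_bias = nn.Embedding(num_users, 1)
        self.item_bias = nn.Embedding(num_items, 1)

    def forward(self, user_id, item_id):
        P_u = self.P(user_id)
        Q_i = self.Q(item_id)
        b_u = self.user_bias(user_id)
        b_i = self.item_bias(item_id)
        # Dot product of the latent factors plus the two bias terms
        outputs = (P_u * Q_i).sum(axis=1) + np.squeeze(b_u) + np.squeeze(b_i)
        return outputs.flatten()
```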
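To illustrate what the embedding lookups in `forward` amount to, here is a framework-free sketch in plain Python that mimics the same computation on a batch of ids. It is not MXNet code: the tables are ordinary lists, and the class name `ToyMF` and all numeric values are hypothetical.

```python
class ToyMF:
    """Plain-Python stand-in for the MF block: embeddings are lookup tables."""

    def __init__(self, P, Q, user_bias, item_bias):
        self.P, self.Q = P, Q  # rows are indexed by user id / item id
        self.user_bias, self.item_bias = user_bias, item_bias

    def forward(self, user_ids, item_ids):
        outputs = []
        for u, i in zip(user_ids, item_ids):
            p_u, q_i = self.P[u], self.Q[i]  # the "embedding lookup" by id
            dot = sum(a * b for a, b in zip(p_u, q_i))
            outputs.append(dot + self.user_bias[u] + self.item_bias[i])
        return outputs  # one score per (user, item) pair in the batch


toy = ToyMF(P=[[1.0, 0.0], [0.5, 0.5]],
            Q=[[2.0, 1.0], [0.0, 4.0], [1.0, 1.0]],
            user_bias=[0.5, -0.5],
            item_bias=[0.0, 1.0, 0.25])
scores = toy.forward(user_ids=[0, 1, 1], item_ids=[2, 0, 1])
```

An embedding layer does the same row selection on device arrays, and its rows are trainable parameters rather than fixed numbers.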
21.3.3. Evaluation Measures
We then implement the RMSE (root-mean-square error) measure, which is commonly used to measure the differences between rating scores predicted by the model and the actually observed ratings (ground truth) (Gunawardana and Shani, 2015). RMSE is defined as:
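$$\textrm{RMSE} = \sqrt{\frac{1}{|\mathcal{T}|}\sum_{(u, i) \in \mathcal{T}}(\mathbf{R}_{ui} -\hat{\mathbf{R}}_{ui})^2}$$

where $\mathcal{T}$ is the set of user-item pairs to evaluate on, $|\mathcal{T}|$ is the size of this set, $\mathbf{R}_{ui}$ is the observed rating, and $\hat{\mathbf{R}}_{ui}$ is the predicted rating. We can use the RMSE function provided by `mx.metric`.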
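For reference, RMSE can also be computed by hand. The following is a minimal sketch in plain Python; the function name `rmse` and the toy ratings are assumptions for this example and are independent of the `mx.metric` implementation.

```python
import math


def rmse(observed, predicted):
    """Root-mean-square error over paired observed and predicted ratings."""
    if len(observed) != len(predicted) or not observed:
        raise ValueError("inputs must be non-empty and of equal length")
    squared = [(r - r_hat) ** 2 for r, r_hat in zip(observed, predicted)]
    return math.sqrt(sum(squared) / len(squared))


# Toy check: the errors are 1, 0, and -2, so the mean squared error is 5 / 3.
value = rmse([4.0, 3.0, 1.0], [3.0, 3.0, 3.0])
```

Because the errors are squared before averaging, a single large miss (the error of 2 here) weighs more heavily than several small ones.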
```python
def evaluator(net, test_iter, devices):
    rmse = mx.metric.RMSE()  # Get the RMSE
    rmse_list = []
    for users, items, ratings in test_iter:
        # Split each batch across the available devices
        u = gluon.utils.split_and_load(users, devices, even_split=False)
        i = gluon.utils.split_and_load(items, devices, even_split=False)
        r_ui = gluon.utils.split_and_load(ratings, devices, even_split=False)
        r_hat = [net(u_shard, i_shard) for u_shard, i_shard in zip(u, i)]
        rmse.update(labels=r_ui, preds=r_hat)
        # rmse.get() returns the metric's running value over the batches so far
        rmse_list.append(rmse.get()[1])
    return float(np.mean(np.array(rmse_list)))
```
21.3.4. Training and Evaluating the Model
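In the training function, we adopt the $\ell_2$ loss with weight decay. The weight decay mechanism has the same effect as $\ell_2$ regularization.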
```python
#@save
def train_recsys_rating(net, train_iter, test_iter, loss, trainer, num_epochs,
                        devices=d2l.try_all_gpus(), evaluator=None,
                        **kwargs):
    timer = d2l.Timer()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 2],
                            legend=['train loss', 'test RMSE'])
    for epoch in range(num_epochs):
        metric, l = d2l.Accumulator(3), 0.
        for i, values in enumerate(train_iter):
            timer.start()
            input_data = []
            values = values if isinstance(values, list) else [values]
            # Split every input array across the available devices
            for v in values:
                input_data.append(gluon.utils.split_and_load(v, devices))
            # The last array holds the labels; the rest are features
            train_feat = input_data[:-1] if len(values) > 1 else input_data
            train_label = input_data[-1]
            with autograd.record():
                preds = [net(*t) for t in zip(*train_feat)]
                ls = [loss(p, s) for p, s in zip(preds, train_label)]
            for device_l in ls:
                device_l.backward()
            l += sum([device_l.asnumpy() for device_l in ls]).mean() / len(devices)
            trainer.step(values[0].shape[0])
            metric.add(l, values[0].shape[0], values[0].size)
            timer.stop()
        if len(kwargs) > 0:  # It will be used in section AutoRec
            test_rmse = evaluator(net, test_iter, kwargs['inter_mat'],
                                  devices)
        else:
            test_rmse = evaluator(net, test_iter, devices)
        train_l = l / (i + 1)
        animator.add(epoch + 1, (train_l, test_rmse))
    print(f'train loss {metric[0] / metric[1]:.3f}, '
          f'test RMSE {test_rmse:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(devices)}')
```
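The relationship between weight decay and an explicit $\ell_2$ penalty can be checked numerically for plain SGD. The sketch below is a framework-free illustration under stated assumptions: the optimizer is assumed to apply weight decay by adding `wd * w` to the gradient, the one-parameter model and all numeric values are made up, and the helper names are hypothetical. Adaptive optimizers such as Adam rescale the gradient afterwards, so this check covers the SGD case only.

```python
def loss(w, x=2.0, y=3.0):
    """Halved squared error of a one-parameter model: 0.5 * (w * x - y)^2."""
    return 0.5 * (w * x - y) ** 2


def loss_grad(w, x=2.0, y=3.0):
    """Analytic gradient of `loss` with respect to w."""
    return (w * x - y) * x


def numeric_grad(f, w, eps=1e-6):
    """Central finite-difference approximation of df/dw."""
    return (f(w + eps) - f(w - eps)) / (2 * eps)


w, lr, wd = 0.8, 0.1, 0.01

# (a) Optimizer-style weight decay: add wd * w to the data gradient.
step_decay = w - lr * (loss_grad(w) + wd * w)


# (b) Explicit L2 penalty: differentiate loss + (wd / 2) * w^2 numerically.
def penalized_loss(v):
    return loss(v) + 0.5 * wd * v ** 2


step_penalty = w - lr * numeric_grad(penalized_loss, w)
```

Both updates land on the same value up to finite-difference error, since the gradient of `(wd / 2) * w^2` is `wd * w`.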
Finally, let’s put all things together and train the model. Here, we set the latent factor dimension to 30.
```python
devices = d2l.try_all_gpus()
num_users, num_items, train_iter, test_iter = d2l.split_and_load_ml100k(
    test_ratio=0.1, batch_size=512)
net = MF(30, num_users, num_items)
net.initialize(ctx=devices, force_reinit=True, init=mx.init.Normal(0.01))
lr, num_epochs, wd, optimizer = 0.002, 20, 1e-5, 'adam'
loss = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(), optimizer,
                        {"learning_rate": lr, 'wd': wd})
train_recsys_rating(net, train_iter, test_iter, loss, trainer, num_epochs,
                    devices, evaluator)
```
train loss 0.065, test RMSE 1.066
56722.8 examples/sec on [gpu(0), gpu(1)]
Below, we use the trained model to predict the rating that a user (ID 20) might give to an item (ID 30).
array([3.1718655], ctx=gpu(0))
21.3.5. Summary
The matrix factorization model is widely used in recommender systems. It can be used to predict ratings that a user might give to an item.
The model can be implemented with embedding layers for the latent factors and biases, trained with a squared-error loss plus weight decay, and evaluated with RMSE.
21.3.6. Exercises
Vary the size of latent factors. How does the size of latent factors influence the model performance?
Try different optimizers, learning rates, and weight decay rates.
Check the predicted rating scores of other users for a specific movie.