Skip to content

Commit

Permalink
Merge pull request #621 from deklanw/slimelastic
Browse files Browse the repository at this point in the history
add SLIMElastic
  • Loading branch information
ShanleiMu authored Mar 3, 2021
2 parents e9a864f + dc54dfe commit 02762ca
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 2 deletions.
1 change: 1 addition & 0 deletions recbole/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from recbole.model.general_recommender.neumf import NeuMF
from recbole.model.general_recommender.ngcf import NGCF
from recbole.model.general_recommender.pop import Pop
from recbole.model.general_recommender.slimelastic import SLIMElastic
from recbole.model.general_recommender.spectralcf import SpectralCF
from recbole.model.general_recommender.ease import EASE
from recbole.model.general_recommender.nncf import NNCF
110 changes: 110 additions & 0 deletions recbole/model/general_recommender/slimelastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
r"""
SLIMElastic
################################################
Reference:
10.1109/ICDM.2011.134
https://www.slideshare.net/MarkLevy/efficient-slides
Reference code:
https://github.com/KarypisLab/SLIM
https://github.com/MaurizioFD/RecSys2019_DeepLearning_Evaluation/blob/master/SLIM_ElasticNet/SLIMElasticNetRecommender.py
"""


from recbole.utils.enum_type import ModelType
import numpy as np
import scipy.sparse as sp
import torch
import warnings
from sklearn.linear_model import ElasticNet
from sklearn.exceptions import ConvergenceWarning

from recbole.utils import InputType
from recbole.model.abstract_recommender import GeneralRecommender


# https://github.com/RUCAIBox/RecBole/issues/622
def add_noise(t, mag=1e-5):
return t + mag * torch.rand(t.shape)


class SLIMElastic(GeneralRecommender):
input_type = InputType.POINTWISE
type = ModelType.TRADITIONAL

def __init__(self, config, dataset):
super().__init__(config, dataset)

# need at least one param
self.dummy_param = torch.nn.Parameter(torch.zeros(1))

X = dataset.inter_matrix(
form='csr').astype(np.float32)

X = X.tolil()
self.interaction_matrix = X

hide_item = config['hide_item']
alpha = config['alpha']
l1_ratio = config['l1_ratio']
positive_only = config['positive_only']

model = ElasticNet(alpha=alpha, l1_ratio=l1_ratio,
positive=positive_only,
fit_intercept=False,
copy_X=False,
precompute=True,
selection='random',
max_iter=100,
tol=1e-4)

item_coeffs = []

# ignore ConvergenceWarnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)

for j in range(X.shape[1]):
# target column
r = X[:, j]

if hide_item:
# set item column to 0
X[:, j] = 0

# fit the model
model.fit(X, r.todense().getA1())

# store the coefficients
coeffs = model.sparse_coef_

item_coeffs.append(coeffs)

if hide_item:
# reattach column if removed
X[:, j] = r

self.item_similarity = sp.vstack(item_coeffs).T

def forward(self):
pass

def calculate_loss(self, interaction):
return torch.nn.Parameter(torch.zeros(1))

def predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()
item = interaction[self.ITEM_ID].cpu().numpy()

r = torch.from_numpy((self.interaction_matrix[user, :].multiply(
self.item_similarity[:, item].T)).sum(axis=1).getA1())

return add_noise(r)

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()

r = self.interaction_matrix[user, :] @ self.item_similarity
r = torch.from_numpy(r.todense().getA1())

return add_noise(r)
4 changes: 4 additions & 0 deletions recbole/properties/model/SLIMElastic.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
alpha: 0.2
l1_ratio: 0.02
positive_only: False
hide_item: True
10 changes: 8 additions & 2 deletions run_test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@
'model': 'LINE',
'dataset': 'ml-100k',
},
'Test SLIMElastic': {
'model': 'SLIMElastic',
'dataset': 'ml-100k',
},
'Test EASE': {
'model': 'EASE',
'dataset': 'ml-100k',
Expand Down Expand Up @@ -359,13 +363,15 @@ def run_test_examples():
for idx, example in enumerate(test_examples.keys()):
if example in closed_examples:
continue
print('\n\n Begin to run %d / %d example: %s \n\n' % (idx + 1, n_examples, example))
print('\n\n Begin to run %d / %d example: %s \n\n' %
(idx + 1, n_examples, example))
try:
config_dict = test_examples[example]
if 'epochs' not in config_dict:
config_dict['epochs'] = 1
run_recbole(config_dict=config_dict, saved=False)
print('\n\n Running %d / %d example successfully: %s \n\n' % (idx + 1, n_examples, example))
print('\n\n Running %d / %d example successfully: %s \n\n' %
(idx + 1, n_examples, example))
success_examples.append(example)
except Exception:
print(traceback.format_exc())
Expand Down
7 changes: 7 additions & 0 deletions tests/model/test_model_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def test_NNCF(self):
quick_test(config_dict)


def test_slimelastic(self):
config_dict = {
'model': 'SLIMElastic',
}
quick_test(config_dict)


class TestContextRecommender(unittest.TestCase):
# todo: more complex context information should be test, such as criteo dataset
Expand Down Expand Up @@ -754,5 +760,6 @@ def test_kgnnls_with_concat(self):
quick_test(config_dict)



if __name__ == '__main__':
unittest.main()

0 comments on commit 02762ca

Please sign in to comment.