Skip to content

Commit

Permalink
fix random, admmslim, slimelastic
Browse files Browse the repository at this point in the history
  • Loading branch information
BishopLiu committed Feb 22, 2024
1 parent 4121d5c commit b6b90a4
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/admmslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def predict(self, interaction):
.flatten()
)

return add_noise(torch.from_numpy(r))
return add_noise(torch.from_numpy(r)).to(self.device)

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def calculate_loss(self, interaction):
return torch.nn.Parameter(torch.zeros(1))

def predict(self, interaction):
return torch.rand(len(interaction)).squeeze(-1)
return torch.rand(len(interaction), device=self.device).squeeze(-1)

def full_sort_predict(self, interaction):
batch_user_num = interaction[self.USER_ID].shape[0]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/slimelastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def predict(self, interaction):
(self.interaction_matrix[user, :].multiply(self.item_similarity[:, item].T))
.sum(axis=1)
.getA1()
)
).to(self.device)

return r

Expand Down

0 comments on commit b6b90a4

Please sign in to comment.