From b6b90a4a7ea6bba258397c8bab6a6a41a9872cd4 Mon Sep 17 00:00:00 2001 From: David <1282675518@qq.com> Date: Thu, 22 Feb 2024 15:33:56 +0800 Subject: [PATCH] fix random, admmslim, slimelastic --- recbole/model/general_recommender/admmslim.py | 2 +- recbole/model/general_recommender/random.py | 2 +- recbole/model/general_recommender/slimelastic.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/recbole/model/general_recommender/admmslim.py b/recbole/model/general_recommender/admmslim.py index 31aa710c1..f741b9df8 100644 --- a/recbole/model/general_recommender/admmslim.py +++ b/recbole/model/general_recommender/admmslim.py @@ -111,7 +111,7 @@ def predict(self, interaction): .flatten() ) - return add_noise(torch.from_numpy(r)) + return add_noise(torch.from_numpy(r)).to(self.device) def full_sort_predict(self, interaction): user = interaction[self.USER_ID].cpu().numpy() diff --git a/recbole/model/general_recommender/random.py b/recbole/model/general_recommender/random.py index e2b8d2e26..6faa2a51f 100644 --- a/recbole/model/general_recommender/random.py +++ b/recbole/model/general_recommender/random.py @@ -35,7 +35,7 @@ def calculate_loss(self, interaction): return torch.nn.Parameter(torch.zeros(1)) def predict(self, interaction): - return torch.rand(len(interaction)).squeeze(-1) + return torch.rand(len(interaction), device=self.device).squeeze(-1) def full_sort_predict(self, interaction): batch_user_num = interaction[self.USER_ID].shape[0] diff --git a/recbole/model/general_recommender/slimelastic.py b/recbole/model/general_recommender/slimelastic.py index 18b4033d0..07475f4db 100644 --- a/recbole/model/general_recommender/slimelastic.py +++ b/recbole/model/general_recommender/slimelastic.py @@ -99,7 +99,7 @@ def predict(self, interaction): (self.interaction_matrix[user, :].multiply(self.item_similarity[:, item].T)) .sum(axis=1) .getA1() - ) + ).to(self.device) return r