Skip to content

Commit

Permalink
FIX: bug fix in case_study.py (for #976).
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyushuo committed Sep 23, 2021
1 parent 78e7d66 commit baa64be
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions recbole/utils/case_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
import torch

from recbole.data.interaction import Interaction


@torch.no_grad()
def full_sort_scores(uid_series, model, test_data, device=None):
Expand All @@ -34,20 +36,19 @@ def full_sort_scores(uid_series, model, test_data, device=None):
torch.Tensor: the scores of all items for each user in uid_series.
"""
device = device or torch.device('cpu')
uid_series = np.array(uid_series)
uid_series = torch.tensor(uid_series)
uid_field = test_data.dataset.uid_field
dataset = test_data.dataset
model.eval()

if not test_data.is_sequential:
index = np.isin(test_data.user_df[uid_field].numpy(), uid_series)
input_interaction = test_data.user_df[index]
input_interaction = dataset.join(Interaction({uid_field: uid_series}))
history_item = test_data.uid2history_item[uid_series]
history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
history_col = torch.cat(list(history_item))
history_index = history_row, history_col
else:
index = np.isin(dataset[uid_field].numpy(), uid_series)
_, index = (dataset[uid_field] == uid_series[:, None]).nonzero(as_tuple=True)
input_interaction = dataset[index]
history_index = None

Expand Down

0 comments on commit baa64be

Please sign in to comment.