diff --git a/recbole/utils/case_study.py b/recbole/utils/case_study.py index 3775b4a3e..7b7d8b5c7 100644 --- a/recbole/utils/case_study.py +++ b/recbole/utils/case_study.py @@ -43,7 +43,11 @@ def full_sort_scores(uid_series, model, test_data, device=None): if not test_data.is_sequential: input_interaction = dataset.join(Interaction({uid_field: uid_series})) - history_item = test_data.uid2history_item[uid_series] + if len(uid_series) == 1: + history_item = np.array([None]) + history_item[0] = test_data.uid2history_item[uid_series] + else: + history_item = test_data.uid2history_item[uid_series] history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)]) history_col = torch.cat(list(history_item)) history_index = history_row, history_col