Skip to content

Commit

Permalink
Merge pull request #977 from chenyushuo/master
Browse files Browse the repository at this point in the history
FIX: bug fix in case_study.py (for #976).
  • Loading branch information
Sherry-XLL authored Sep 23, 2021
2 parents 952b96c + baa64be commit 434e3f0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/source/get_started/quick_start.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Quick-start From API
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
Before running a model, firstly you need to prepare and load data. To help users quickly get start,
RecBole has a build-in dataset **ml-100k** and you can directly use it. However, if you want to use other datasets, you can read
:doc:`../usage/running_new_dataset` for more information.
:doc:`../user_guide/usage/running_new_dataset` for more information.

Then, you need to set data config for data loading. You can create a `yaml` file called `test.yaml` and write the following settings:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

=========================================================

`HomePage <https://recbole.io/>`_ | `Docs <https://recbole.io/docs/>`_ | `GitHub <https://github.com/RUCAIBox/RecBole>`_ | `Datasets <https://github.com/RUCAIBox/RecDatasets>`_ | `v0.1.2 </docs/v0.1.2/>`_
`HomePage <https://recbole.io/>`_ | `Docs <https://recbole.io/docs/>`_ | `GitHub <https://github.com/RUCAIBox/RecBole>`_ | `Datasets <https://github.com/RUCAIBox/RecDatasets>`_ | `v0.1.2 </docs/v0.1.2/>`_ | `v0.2.0 </docs/v0.2.0/>`_

Introduction
-------------------------
Expand Down
9 changes: 5 additions & 4 deletions recbole/utils/case_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
import torch

from recbole.data.interaction import Interaction


@torch.no_grad()
def full_sort_scores(uid_series, model, test_data, device=None):
Expand All @@ -34,20 +36,19 @@ def full_sort_scores(uid_series, model, test_data, device=None):
torch.Tensor: the scores of all items for each user in uid_series.
"""
device = device or torch.device('cpu')
uid_series = np.array(uid_series)
uid_series = torch.tensor(uid_series)
uid_field = test_data.dataset.uid_field
dataset = test_data.dataset
model.eval()

if not test_data.is_sequential:
index = np.isin(test_data.user_df[uid_field].numpy(), uid_series)
input_interaction = test_data.user_df[index]
input_interaction = dataset.join(Interaction({uid_field: uid_series}))
history_item = test_data.uid2history_item[uid_series]
history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
history_col = torch.cat(list(history_item))
history_index = history_row, history_col
else:
index = np.isin(dataset[uid_field].numpy(), uid_series)
_, index = (dataset[uid_field] == uid_series[:, None]).nonzero(as_tuple=True)
input_interaction = dataset[index]
history_index = None

Expand Down

0 comments on commit 434e3f0

Please sign in to comment.