Skip to content

Commit

Permalink
FIX: add parameter docs and shuffle interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Ethan-TZ committed Jul 8, 2022
1 parent ac1859a commit ee14ba3
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
6 changes: 3 additions & 3 deletions docs/source/user_guide/config/environment_settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ Environment settings
===========================
Environment settings are designed to set basic parameters of running environment.

- ``gpu_id (int or str)`` : The id of GPU device. Defaults to ``0``.
- ``use_gpu (bool)`` : Whether or not to use GPU. If True, using GPU, else using CPU.
Defaults to ``True``.
- ``gpu_id (str)`` : The id of available GPU devices. Defaults to ``0``.
- ``worker (int)`` : The number of workers processing the data.
- ``seed (int)`` : Random seed. Defaults to ``2020``.
- ``state (str)`` : Logging level. Defaults to ``'INFO'``.
Range in ``['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']``.
Expand Down Expand Up @@ -39,3 +38,4 @@ Environment settings are designed to set basic parameters of running environment
Defaults to ``False``.
- ``wandb_project (str)``: The project to conduct experiment in W&B.
Defaults to ``'recbole'``.
- ``shuffle (bool)``: Whether or not shuffle the training data before each epoch. Defaults to ``True``.
8 changes: 8 additions & 0 deletions recbole/data/dataloader/knowledge_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(self, config, dataset, sampler, kg_sampler, shuffle=False):
# using kg_sampler
self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler, shuffle=True)

self.shuffle = False
self.state = None
self._dataset = dataset
self.kg_iter, self.gen_iter = None, None
Expand Down Expand Up @@ -155,3 +156,10 @@ def get_model(self, model):
"""Let the general_dataloader get the model, used for dynamic sampling.
"""
self.general_dataloader.get_model(model)

def knowledge_shuffle(self, epoch_seed):
"""Reset the seed to ensure that each subprocess generates the same index squence."""
self.kg_dataloader.sampler.set_epoch(epoch_seed)

if self.general_dataloader.shuffle:
self.general_dataloader.sampler.set_epoch(epoch_seed)
4 changes: 2 additions & 2 deletions recbole/properties/overall.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# general
gpu_id: '0,1,2,3'
gpu_id: '0'
worker: 0
use_gpu: True
seed: 2020
Expand Down Expand Up @@ -29,7 +29,7 @@ clip_grad_norm: ~
weight_decay: 0.0
loss_decimal_place: 4
require_pow: False
shuffle: False
shuffle: True

# evaluation settings
eval_args:
Expand Down
4 changes: 4 additions & 0 deletions recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,8 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=Fals
interaction_state = KGDataLoaderState.RS
else:
interaction_state = KGDataLoaderState.KG
if not self.config['single_spec']:
train_data.knowledge_shuffle(epoch_idx)
train_data.set_mode(interaction_state)
if interaction_state in [KGDataLoaderState.RSKG, KGDataLoaderState.RS]:
return super()._train_epoch(train_data, epoch_idx, show_progress=show_progress)
Expand All @@ -593,6 +595,8 @@ def __init__(self, config, model):

def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
# train rs
if not self.config['single_spec']:
train_data.knowledge_shuffle(epoch_idx)
train_data.set_mode(KGDataLoaderState.RS)
rs_total_loss = super()._train_epoch(train_data, epoch_idx, show_progress=show_progress)

Expand Down

0 comments on commit ee14ba3

Please sign in to comment.