From ee14ba35bd5e980e7960bdb0721c6a4d4392717b Mon Sep 17 00:00:00 2001
From: Tian Zhen <1204216974@qq.com>
Date: Fri, 8 Jul 2022 13:12:23 +0800
Subject: [PATCH] FIX: add parameter docs and shuffle interface

---
 docs/source/user_guide/config/environment_settings.rst | 6 +++---
 recbole/data/dataloader/knowledge_dataloader.py        | 8 ++++++++
 recbole/properties/overall.yaml                        | 4 ++--
 recbole/trainer/trainer.py                             | 4 ++++
 4 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/docs/source/user_guide/config/environment_settings.rst b/docs/source/user_guide/config/environment_settings.rst
index 19b8edf55..2096e9308 100644
--- a/docs/source/user_guide/config/environment_settings.rst
+++ b/docs/source/user_guide/config/environment_settings.rst
@@ -2,9 +2,8 @@ Environment settings
 ===========================
 Environment settings are designed to set basic parameters of running environment.
 
-- ``gpu_id (int or str)`` : The id of GPU device. Defaults to ``0``.
-- ``use_gpu (bool)`` : Whether or not to use GPU. If True, using GPU, else using CPU.
-  Defaults to ``True``.
+- ``gpu_id (str)`` : The id of available GPU devices. Defaults to ``0``.
+- ``worker (int)`` : The number of workers processing the data. Defaults to ``0``.
 - ``seed (int)`` : Random seed. Defaults to ``2020``.
 - ``state (str)`` : Logging level. Defaults to ``'INFO'``.
   Range in ``['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']``.
@@ -39,3 +38,4 @@ Environment settings are designed to set basic parameters of running environment
   Defaults to ``False``.
 - ``wandb_project (str)``: The project to conduct experiment in W&B.
   Defaults to ``'recbole'``.
+- ``shuffle (bool)``: Whether or not to shuffle the training data before each epoch. Defaults to ``True``.
\ No newline at end of file
diff --git a/recbole/data/dataloader/knowledge_dataloader.py b/recbole/data/dataloader/knowledge_dataloader.py
index fc26c5106..311a3ab63 100644
--- a/recbole/data/dataloader/knowledge_dataloader.py
+++ b/recbole/data/dataloader/knowledge_dataloader.py
@@ -99,6 +99,7 @@ def __init__(self, config, dataset, sampler, kg_sampler, shuffle=False):
         # using kg_sampler
         self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler, shuffle=True)
 
+        self.shuffle = False
         self.state = None
         self._dataset = dataset
         self.kg_iter, self.gen_iter = None, None
@@ -155,3 +156,10 @@ def get_model(self, model):
         """Let the general_dataloader get the model, used for dynamic sampling.
""" self.general_dataloader.get_model(model) + + def knowledge_shuffle(self, epoch_seed): + """Reset the seed to ensure that each subprocess generates the same index squence.""" + self.kg_dataloader.sampler.set_epoch(epoch_seed) + + if self.general_dataloader.shuffle: + self.general_dataloader.sampler.set_epoch(epoch_seed) diff --git a/recbole/properties/overall.yaml b/recbole/properties/overall.yaml index 4eabdb382..4951d1fbf 100644 --- a/recbole/properties/overall.yaml +++ b/recbole/properties/overall.yaml @@ -1,5 +1,5 @@ # general -gpu_id: '0,1,2,3' +gpu_id: '0' worker: 0 use_gpu: True seed: 2020 @@ -29,7 +29,7 @@ clip_grad_norm: ~ weight_decay: 0.0 loss_decimal_place: 4 require_pow: False -shuffle: False +shuffle: True # evaluation settings eval_args: diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index 6b8487665..75e1abb8b 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -573,6 +573,8 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=Fals interaction_state = KGDataLoaderState.RS else: interaction_state = KGDataLoaderState.KG + if not self.config['single_spec']: + train_data.knowledge_shuffle(epoch_idx) train_data.set_mode(interaction_state) if interaction_state in [KGDataLoaderState.RSKG, KGDataLoaderState.RS]: return super()._train_epoch(train_data, epoch_idx, show_progress=show_progress) @@ -593,6 +595,8 @@ def __init__(self, config, model): def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False): # train rs + if not self.config['single_spec']: + train_data.knowledge_shuffle(epoch_idx) train_data.set_mode(KGDataLoaderState.RS) rs_total_loss = super()._train_epoch(train_data, epoch_idx, show_progress=show_progress)