diff --git a/recbole/data/utils.py b/recbole/data/utils.py
index 91cb08727..accf76dfc 100644
--- a/recbole/data/utils.py
+++ b/recbole/data/utils.py
@@ -94,7 +94,8 @@ def data_preparation(config, dataset, save=False):
 
     kwargs = {}
     if config['training_neg_sample_num']:
-        es.neg_sample_by(config['training_neg_sample_num'])
+        train_distribution = config['training_neg_sample_distribution'] or 'uniform'
+        es.neg_sample_by(by=config['training_neg_sample_num'], distribution=train_distribution)
         if model_type != ModelType.SEQUENTIAL:
             sampler = Sampler(phases, builded_datasets, es.neg_sample_args['distribution'])
         else:
@@ -120,6 +121,7 @@ def data_preparation(config, dataset, save=False):
         getattr(es, es_str[1])()
         if 'sampler' not in locals():
             sampler = Sampler(phases, builded_datasets, es.neg_sample_args['distribution'])
+        sampler.set_distribution(es.neg_sample_args['distribution'])
         kwargs['sampler'] = [sampler.set_phase('valid'), sampler.set_phase('test')]
         kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
     valid_data, test_data = dataloader_construct(
diff --git a/recbole/sampler/sampler.py b/recbole/sampler/sampler.py
index 150324c84..e16a417fa 100644
--- a/recbole/sampler/sampler.py
+++ b/recbole/sampler/sampler.py
@@ -155,6 +155,20 @@ def get_used_ids(self):
             last = used_item_id[phase] = cur
         return used_item_id
 
+    def set_distribution(self, distribution):
+        """Switch the distribution used for sampling.
+
+        ``random_list`` is rebuilt through ``get_random_list()`` only when
+        ``distribution`` differs from the current ``self.distribution``.
+
+        Args:
+            distribution (str): Name of the distribution, e.g. ``'uniform'``.
+        """
+        if self.distribution == distribution:
+            return
+        self.distribution = distribution
+        self.random_list = self.get_random_list()
+
     def set_phase(self, phase):
         """Get the sampler of corresponding phase.
 
@@ -295,6 +309,23 @@ def get_random_list(self):
         else:
             raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))
 
+    def set_distribution(self, distribution):
+        """Switch the distribution used for sampling.
+
+        ``random_list`` is rebuilt through ``get_random_list()`` only when
+        ``distribution`` differs from the current ``self.distribution``.
+
+        Args:
+            distribution (str): Name of the distribution, e.g. ``'uniform'``.
+
+        Raises:
+            NotImplementedError: If ``get_random_list()`` does not implement ``distribution``.
+        """
+        if self.distribution == distribution:
+            return
+        self.distribution = distribution
+        self.random_list = self.get_random_list()
+
     def get_used_ids(self):
         """
         Returns: