Skip to content

Commit

Permalink
Merge pull request #4 from RUCAIBox/0.2.x
Browse files Browse the repository at this point in the history
merge from 0.2.x
  • Loading branch information
chenyushuo authored Nov 20, 2020
2 parents 9fa2719 + da5885a commit 91a6e92
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
4 changes: 3 additions & 1 deletion recbole/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def data_preparation(config, dataset, save=False):

kwargs = {}
if config['training_neg_sample_num']:
es.neg_sample_by(config['training_neg_sample_num'])
train_distribution = config['training_neg_sample_distribution'] or 'uniform'
es.neg_sample_by(by=config['training_neg_sample_num'], distribution=train_distribution)
if model_type != ModelType.SEQUENTIAL:
sampler = Sampler(phases, builded_datasets, es.neg_sample_args['distribution'])
else:
Expand All @@ -120,6 +121,7 @@ def data_preparation(config, dataset, save=False):
getattr(es, es_str[1])()
if 'sampler' not in locals():
sampler = Sampler(phases, builded_datasets, es.neg_sample_args['distribution'])
sampler.set_distribution(es.neg_sample_args['distribution'])
kwargs['sampler'] = [sampler.set_phase('valid'), sampler.set_phase('test')]
kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
valid_data, test_data = dataloader_construct(
Expand Down
12 changes: 12 additions & 0 deletions recbole/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ def get_used_ids(self):
last = used_item_id[phase] = cur
return used_item_id

def set_distribution(self, distribution):
if self.distribution == distribution:
return
self.distribution = distribution
self.random_list = self.get_random_list()

def set_phase(self, phase):
"""Get the sampler of corresponding phase.
Expand Down Expand Up @@ -295,6 +301,12 @@ def get_random_list(self):
else:
raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

def set_distribution(self, distribution):
if self.distribution == distribution:
return
self.distribution = distribution
self.random_list = self.get_random_list()

def get_used_ids(self):
"""
Returns:
Expand Down

0 comments on commit 91a6e92

Please sign in to comment.