Skip to content

Commit

Permalink
Merge pull request #2 from RUCAIBox/master
Browse files Browse the repository at this point in the history
up to date
  • Loading branch information
RichardHGL authored Sep 1, 2020
2 parents b460eb5 + be21361 commit 1541529
Show file tree
Hide file tree
Showing 38 changed files with 1,233 additions and 252 deletions.
9 changes: 4 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from logging import getLogger
from recbox.config import Config
from recbox.data import Dataset, data_preparation
from recbox.model.general_recommender.bprmf import BPRMF
from recbox.trainer import Trainer
from recbox.data import create_dataset, data_preparation
from recbox.trainer import get_trainer
from recbox.utils import init_logger, get_model

config = Config('properties/overall.config')
config.init()
init_logger(config)
logger = getLogger()

dataset = Dataset(config)
dataset = create_dataset(config)
logger.info(dataset)

# If you want to customize the evaluation setting,
Expand All @@ -20,7 +19,7 @@
model = get_model(config['model'])(config, train_data).to(config['device'])
logger.info(model)

trainer = Trainer(config, model)
trainer = get_trainer(config['MODEL_TYPE'])(config, model)

# trainer.resume_checkpoint('saved/model_best.pth')
best_valid_score, _ = trainer.fit(train_data, valid_data)
Expand Down
1 change: 1 addition & 0 deletions properties/dataset/amazon-game.config
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
########define the UIRT columns
USER_ID_FIELD='reviewerID'
ITEM_ID_FIELD='asin'
RATING_FIELD='overall'
NEG_PREFIX='neg_'
LABEL_FIELD='label'

Expand Down
22 changes: 22 additions & 0 deletions properties/dataset/kgdata_example.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[data]

########define the UIRT columns
USER_ID_FIELD='user_id'
ITEM_ID_FIELD='item_id'
NEG_PREFIX='neg_'
HEAD_ENTITY_ID_FIELD='head_id'
TAIL_ENTITY_ID_FIELD='tail_id'
RELATION_ID_FIELD='relation_id'
ENTITY_ID_FIELD='entity_id'
LABEL_FIELD='label'

#########select load columns
load_col={'inter': ['user_id', 'item_id', 'rating'], 'kg': ['head_id', 'relation_id', 'tail_id'], 'link': ['item_id', 'entity_id']}

########data separator
field_separator='\t'
seq_separator=' '

########data filter
lowest_val={'rating':3}
drop_filter_field=True
1 change: 1 addition & 0 deletions properties/dataset/ml-100k.config
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ USER_ID_FIELD='user_id'
ITEM_ID_FIELD='item_id'
NEG_PREFIX='neg_'
LABEL_FIELD='label'
RATING_FIELD='rating'

#########select load columns
# USER_ID_FIELD & ITEM_ID_FIELD can be omitted
Expand Down
1 change: 1 addition & 0 deletions properties/dataset/ml-1m.config
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ USER_ID_FIELD='user_id'
ITEM_ID_FIELD='item_id'
NEG_PREFIX='neg_'
LABEL_FIELD='label'
RATING_FIELD='rating'

#########select load columns
load_col={'inter': ['user_id', 'item_id', 'rating']}
Expand Down
1 change: 1 addition & 0 deletions properties/dataset/yelp.config
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ USER_ID_FIELD='user_id'
ITEM_ID_FIELD='business'
NEG_PREFIX='neg_'
LABEL_FIELD='label'
RATING_FIELD='stars'

#########select load columns
load_col={'inter': ['user_id', 'business', 'stars']}
Expand Down
7 changes: 5 additions & 2 deletions properties/model/DMF.config
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[model]

layers = [64,64]
reg_weight = 0.0
# The dimensions of the last layer of users and items must be the same
inter_matrix_type = '01'
# inter_matrix_type = 'rating'
user_layers_dim = [64, 64]
item_layers_dim = [64, 64]
min_y_hat = 1e-6
11 changes: 11 additions & 0 deletions properties/model/FPMC.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[model]

MAX_ITEM_LIST_LENGTH=50
TARGET_PREFIX='target_'
LIST_SUFFIX='_list'
ITEM_LIST_LENGTH_FIELD='item_length'
USER_ID_FIELD='user_id'
NEG_PREFIX='neg_'

embedding_size=64
neg_count=1
2 changes: 1 addition & 1 deletion properties/model/GRU4Rec.config
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
MAX_ITEM_LIST_LENGTH=50
TARGET_PREFIX='target_'
LIST_SUFFIX='_list'
POSITION_FIELD='position_id'
ITEM_LIST_LENGTH_FIELD='item_length'

embedding_size=64
hidden_size=32
num_layers=1
dropout=0
5 changes: 5 additions & 0 deletions properties/model/LightGCN.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[model]

embedding_size=64
layers=3
delay=1e-4
11 changes: 11 additions & 0 deletions properties/model/NARM.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[model]

MAX_ITEM_LIST_LENGTH=50
TARGET_PREFIX='target_'
LIST_SUFFIX='_list'
ITEM_LIST_LENGTH_FIELD='item_length'

embedding_size=64
hidden_size=32
n_layers=1
dropout=[0.25,0.5]
5 changes: 4 additions & 1 deletion properties/overall.config
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ eval_setting=RO_RS, full
group_by_user=True
split_ratio=[0.8, 0.1, 0.1]
# leave_one_num=2
real_time_neg_sampling=True

# DataLoader args
real_time_process=True

# training
epochs=400
Expand All @@ -29,6 +31,7 @@ eval_step=1
valid_metric=MRR@10
stopping_step=10
seed=2020
training_neg_sample_num=1

# evaluating
metrics=["Recall", "MRR","NDCG"]
Expand Down
70 changes: 43 additions & 27 deletions recbox/config/eval_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# @Email : [email protected]

# UPDATE:
# @Time : 2020/8/19 18:56
# @Author : Yupeng Hou
# @Email : [email protected]
# @Time : 2020/8/19 18:56, 2020/8/31
# @Author : Yupeng Hou, Yushuo Chen
# @Email : [email protected], [email protected]


class EvalSetting(object):
Expand All @@ -23,12 +23,29 @@ def __init__(self, config):
setattr(self, args, config[args])

def __str__(self):
info = 'Evaluation Setting:\n'
info += ('\tGroup by {}\n'.format(self.group_field) if self.group_field is not None else '\tNo Grouping\n')
info += ('\tOrdering: {}\n'.format(self.ordering_args) if (self.ordering_args is not None and self.ordering_args['strategy'] != 'none') else '\tNo Ordering\n')
info += ('\tSplitting: {}\n'.format(self.split_args) if (self.split_args is not None and self.split_args['strategy'] != 'none') else '\tNo Splitting\n')
info += ('\tNegative Sampling: {}'.format(self.neg_sample_args) if (self.neg_sample_args is not None and self.neg_sample_args['strategy'] != 'none') else '\tNo Negative Sampling\n')
return info
info = ['Evaluation Setting:']

if self.group_field:
info.append('Group by {}'.format(self.group_field))
else:
info.append('No Grouping')

if self.ordering_args is not None and self.ordering_args['strategy'] != 'none':
info.append('Ordering: {}'.format(self.ordering_args))
else:
info.append('No Ordering')

if self.split_args is not None and self.split_args['strategy'] != 'none':
info.append('Splitting: {}'.format(self.split_args))
else:
info.append('No Splitting')

if self.neg_sample_args is not None and self.neg_sample_args['strategy'] != 'none':
info.append('Negative Sampling: {}'.format(self.neg_sample_args))
else:
info.append('No Negative Sampling')

return '\n\t'.join(info)

def __repr__(self):
return self.__str__()
Expand Down Expand Up @@ -135,24 +152,23 @@ def split_by_value(self, field, values, ascending=True):
Args:
strategy (str): Either 'none', 'full' or 'by'.
by (int): Negative Sampling `by` neg cases for one pos case.
real_time (bool): real time negative sampling if True, else negative cases will be pre-sampled and stored.
distribution (str): distribution of sampler, either 'uniform' or 'popularity'.
Example:
>>> es.neg_sample_to(100, real_time=True)
>>> es.neg_sample_to(100)
>>> es.neg_sample_by(1)
"""
def set_neg_sampling(self, strategy='none', real_time=False, distribution='uniform', **kwargs):
def set_neg_sampling(self, strategy='none', distribution='uniform', **kwargs):
legal_strategy = {'none', 'full', 'by'}
if strategy not in legal_strategy:
raise ValueError('Negative Sampling Strategy [{}] should in {}'.format(strategy, list(legal_strategy)))
if strategy == 'full' and distribution != 'uniform':
raise ValueError('Full Sort can not be sampled by distribution [{}]'.format(distribution))
self.neg_sample_args = {'strategy': strategy, 'real_time': real_time, 'distribution': distribution}
self.neg_sample_args = {'strategy': strategy, 'distribution': distribution}
self.neg_sample_args.update(kwargs)

def neg_sample_by(self, by, real_time=False, distribution='uniform'):
self.set_neg_sampling(strategy='by', by=by, real_time=real_time, distribution=distribution)
def neg_sample_by(self, by, distribution='uniform'):
self.set_neg_sampling(strategy='by', by=by, distribution=distribution)

r"""Presets
Expand All @@ -161,13 +177,13 @@ def neg_sample_by(self, by, real_time=False, distribution='uniform'):
full: all non-ground-truth items
uni: uniform sampling pop: popularity sampling neg_sample_by 100 by default.
"""
def RO_RS(self, ratios=[0.8, 0.1, 0.1], group_by_user=True):
def RO_RS(self, ratios=(0.8, 0.1, 0.1), group_by_user=True):
if group_by_user:
self.group_by_user()
self.random_ordering()
self.split_by_ratio(ratios)

def TO_RS(self, ratios=[0.8, 0.1, 0.1], group_by_user=True):
def TO_RS(self, ratios=(0.8, 0.1, 0.1), group_by_user=True):
if group_by_user:
self.group_by_user()
self.temporal_ordering()
Expand All @@ -185,17 +201,17 @@ def TO_LS(self, leave_one_num=1, group_by_user=True):
self.temporal_ordering()
self.leave_one_out(leave_one_num=leave_one_num)

def uni100(self, real_time=False):
self.neg_sample_by(100, real_time=real_time)
def uni100(self):
self.neg_sample_by(100)

def pop100(self, real_time=False):
self.neg_sample_by(100, real_time=real_time, distribution='popularity')
def pop100(self):
self.neg_sample_by(100, distribution='popularity')

def uni1000(self, real_time=False):
self.neg_sample_by(1000, real_time=real_time)
def uni1000(self):
self.neg_sample_by(1000)

def pop1000(self, real_time=False):
self.neg_sample_by(1000, real_time=real_time, distribution='popularity')
def pop1000(self):
self.neg_sample_by(1000, distribution='popularity')

def full(self, real_time=True):
self.set_neg_sampling(strategy='full', real_time=real_time)
def full(self):
self.set_neg_sampling(strategy='full')
Loading

0 comments on commit 1541529

Please sign in to comment.