Skip to content

Commit

Permalink
Merge pull request #2 from RUCAIBox/master
Browse files Browse the repository at this point in the history
update fork
  • Loading branch information
KyrieIrving24 authored Aug 22, 2020
2 parents fe13af3 + 27e3e0c commit 86fc6f4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
6 changes: 3 additions & 3 deletions recbox/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : [email protected]

# UPDATE
# @Time : 2020/8/19, 2020/8/18
# @Time : 2020/8/21, 2020/8/18
# @Author : Yupeng Hou, Yushuo Chen
# @email : [email protected], [email protected]

Expand Down Expand Up @@ -86,8 +86,8 @@ def set_batch_size(self, batch_size): # TODO batch size is useless...
# Thin delegating wrapper: hands `df` to the wrapped dataset's `join` and
# returns whatever that call returns.
def join(self, df):
return self.dataset.join(df)

def inter_matrix(self, form='coo'):
return self.dataset.inter_matrix(form=form)
# Thin delegating wrapper: forwards `form` and `value_field` unchanged (as
# keyword arguments) to the wrapped dataset's `inter_matrix` and returns its
# result. Both defaults ('coo', None) are passed through as-is.
def inter_matrix(self, form='coo', value_field=None):
return self.dataset.inter_matrix(form=form, value_field=value_field)


class GeneralDataLoader(AbstractDataLoader):
Expand Down
38 changes: 33 additions & 5 deletions recbox/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : [email protected]

# UPDATE:
# @Time : 2020/8/19, 2020/8/5, 2020/8/16
# @Time : 2020/8/21, 2020/8/5, 2020/8/16
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen
# @Email : [email protected], [email protected], [email protected]

Expand Down Expand Up @@ -68,6 +68,9 @@ def _from_scratch(self, config):
if self.config['fill_nan']:
self._fill_nan()

if self.config['normalize_field'] or self.config['normalize_all']:
self._normalize(self.config['normalize_field'])

def _restore_saved_dataset(self, saved_dataset):
if (saved_dataset is None) or (not os.path.isdir(saved_dataset)):
raise ValueError('filepath [{}] need to be a dir'.format(saved_dataset))
Expand Down Expand Up @@ -185,8 +188,6 @@ def _load_feat(self, filepath, source):
df.columns = field_names
df = df[columns]

# TODO fill nan in df

seq_separator = self.config['seq_separator']
def _token(df, field): pass
def _float(df, field): pass
Expand Down Expand Up @@ -222,6 +223,26 @@ def _fill_nan(self):
elif ftype.endswith('seq'):
self.logger.warning('feature [{}] (type: {}) probably has nan, while has not been filled.'.format(field, ftype))

def _normalize(self, fields=None):
    """Min-max scale FLOAT feature columns in place.

    Every selected column ``x`` is replaced by ``(x - min) / (max - min)``
    in whichever of ``inter_feat`` / ``user_feat`` / ``item_feat`` holds it.

    Args:
        fields (list of str, optional): names of the fields to normalize.
            ``None`` (the default) selects every field in
            ``self.field2type``. Explicitly listed fields whose type is not
            ``FeatureType.FLOAT`` are skipped with a warning.

    Raises:
        ValueError: if a listed field is not in ``self.field2type``, or if a
            selected column has identical minimum and maximum (zero range,
            so it cannot be rescaled).
    """
    if fields is None:
        fields = list(self.field2type)
    else:
        for field in fields:
            if field not in self.field2type:
                raise ValueError('Field [{}] doesn\'t exist'.format(field))
            elif self.field2type[field] != FeatureType.FLOAT:
                # Logger.warn is a deprecated alias of Logger.warning.
                self.logger.warning(
                    '{} is not a FLOAT feat, which will not be normalized.'.format(field))

    # Set membership keeps the per-column check O(1).
    selected = set(fields)

    # Pair each feature table with its source name so the error message below
    # can say where the offending column lives.
    named_feats = (
        ('inter', self.inter_feat),
        ('user', self.user_feat),
        ('item', self.item_feat),
    )
    for source, feat in named_feats:
        if feat is None:
            continue
        for field in feat:
            if field not in selected or self.field2type[field] != FeatureType.FLOAT:
                continue
            column = feat[field]
            # Vectorized reductions on the column itself instead of Python's
            # builtin max/min over the raw values.
            # NOTE(review): assumes `feat` is a pandas DataFrame, in which
            # case these skip NaN by default -- confirm.
            mx, mn = column.max(), column.min()
            if mx == mn:
                raise ValueError(
                    'All the same value in [{}] from [{}_feat]'.format(field, source))
            feat[field] = (column - mn) / (mx - mn)

def filter_by_inter_num(self, max_user_inter_num=None, min_user_inter_num=None,
max_item_inter_num=None, min_item_inter_num=None):
ban_users = self._get_illegal_ids_by_inter_num(source='user', max_num=max_user_inter_num,
Expand Down Expand Up @@ -644,13 +665,20 @@ def get_item_feature(self):
else:
return self.item_feat

def inter_matrix(self, form='coo'):
def inter_matrix(self, form='coo', value_field=None):
if not self.uid_field or not self.iid_field:
raise ValueError('dataset doesn\'t exist uid/iid, thus can not converted to sparse matrix')

uids = self.inter_feat[self.uid_field].values
iids = self.inter_feat[self.iid_field].values
data = np.ones(len(self.inter_feat))
if value_field is None:
data = np.ones(len(self.inter_feat))
else:
if value_field not in self.field2source:
raise ValueError('value_field [{}] not exist.'.format(value_field))
if self.field2source[value_field] != FeatureSource.INTERACTION:
raise ValueError('value_field [{}] can only be one of the interaction features'.format(value_field))
data = self.inter_feat[value_field].values
mat = coo_matrix((data, (uids, iids)), shape=(self.user_num, self.item_num))

if form == 'coo':
Expand Down
1 change: 1 addition & 0 deletions run_test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
'Test Criteo': {
'model': 'FM',
'dataset': 'criteo',
'normalize_all': True,
'group_by_user': False,
'epochs': 1,
'valid_metric': 'AUC',
Expand Down

0 comments on commit 86fc6f4

Please sign in to comment.