-
Notifications
You must be signed in to change notification settings - Fork 627
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from RUCAIBox/master
update fork
- Loading branch information
Showing
3 changed files
with
37 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
# @Email : [email protected] | ||
|
||
# UPDATE | ||
# @Time : 2020/8/19, 2020/8/18 | ||
# @Time : 2020/8/21, 2020/8/18 | ||
# @Author : Yupeng Hou, Yushuo Chen | ||
# @email : [email protected], [email protected] | ||
|
||
|
@@ -86,8 +86,8 @@ def set_batch_size(self, batch_size): # TODO batch size is useless... | |
def join(self, df): | ||
return self.dataset.join(df) | ||
|
||
def inter_matrix(self, form='coo'): | ||
return self.dataset.inter_matrix(form=form) | ||
def inter_matrix(self, form='coo', value_field=None): | ||
return self.dataset.inter_matrix(form=form, value_field=value_field) | ||
|
||
|
||
class GeneralDataLoader(AbstractDataLoader): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
# @Email : [email protected] | ||
|
||
# UPDATE: | ||
# @Time : 2020/8/19, 2020/8/5, 2020/8/16 | ||
# @Time : 2020/8/21, 2020/8/5, 2020/8/16 | ||
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen | ||
# @Email : [email protected], [email protected], [email protected] | ||
|
||
|
@@ -68,6 +68,9 @@ def _from_scratch(self, config): | |
if self.config['fill_nan']: | ||
self._fill_nan() | ||
|
||
if self.config['normalize_field'] or self.config['normalize_all']: | ||
self._normalize(self.config['normalize_field']) | ||
|
||
def _restore_saved_dataset(self, saved_dataset): | ||
if (saved_dataset is None) or (not os.path.isdir(saved_dataset)): | ||
raise ValueError('filepath [{}] need to be a dir'.format(saved_dataset)) | ||
|
@@ -185,8 +188,6 @@ def _load_feat(self, filepath, source): | |
df.columns = field_names | ||
df = df[columns] | ||
|
||
# TODO fill nan in df | ||
|
||
seq_separator = self.config['seq_separator'] | ||
def _token(df, field): pass | ||
def _float(df, field): pass | ||
|
@@ -222,6 +223,26 @@ def _fill_nan(self): | |
elif ftype.endswith('seq'): | ||
self.logger.warning('feature [{}] (type: {}) probably has nan, while has not been filled.'.format(field, ftype)) | ||
|
||
def _normalize(self, fields=None):
    """Min-max scale FLOAT features in place.

    Every selected FLOAT column found in ``inter_feat``, ``user_feat`` or
    ``item_feat`` is rewritten as ``(x - min) / (max - min)``.

    Args:
        fields: Names of the fields to normalize. ``None`` selects every
            field registered in ``self.field2type``. A single field name
            given as a string is treated as a one-element list.

    Raises:
        ValueError: If a requested field is not in ``self.field2type``, or
            if a selected column holds a single constant value (the scaling
            would divide by zero).
    """
    if fields is None:
        fields = list(self.field2type)
    else:
        # A bare string would otherwise be iterated character by character.
        if isinstance(fields, str):
            fields = [fields]
        for field in fields:
            if field not in self.field2type:
                raise ValueError('Field [{}] doesn\'t exist'.format(field))
            elif self.field2type[field] != FeatureType.FLOAT:
                self.logger.warning('{} is not a FLOAT feat, which will not be normalized.'.format(field))

    # Only FLOAT fields are ever scaled; a set gives O(1) membership tests.
    float_fields = {field for field in fields if self.field2type[field] == FeatureType.FLOAT}

    # Pair each frame with its source name so the error message can say
    # which frame holds the offending column.
    named_feats = (
        ('inter', self.inter_feat),
        ('user', self.user_feat),
        ('item', self.item_feat),
    )
    for source, feat in named_feats:
        if feat is None:
            continue
        for field in feat:
            if field not in float_fields:
                continue
            values = feat[field].values
            mx, mn = values.max(), values.min()
            if mx == mn:
                raise ValueError('All the same value in [{}] from [{}_feat]'.format(field, source))
            feat[field] = (values - mn) / (mx - mn)
|
||
def filter_by_inter_num(self, max_user_inter_num=None, min_user_inter_num=None, | ||
max_item_inter_num=None, min_item_inter_num=None): | ||
ban_users = self._get_illegal_ids_by_inter_num(source='user', max_num=max_user_inter_num, | ||
|
@@ -644,13 +665,20 @@ def get_item_feature(self): | |
else: | ||
return self.item_feat | ||
|
||
def inter_matrix(self, form='coo'): | ||
def inter_matrix(self, form='coo', value_field=None): | ||
if not self.uid_field or not self.iid_field: | ||
raise ValueError('dataset doesn\'t exist uid/iid, thus can not converted to sparse matrix') | ||
|
||
uids = self.inter_feat[self.uid_field].values | ||
iids = self.inter_feat[self.iid_field].values | ||
data = np.ones(len(self.inter_feat)) | ||
if value_field is None: | ||
data = np.ones(len(self.inter_feat)) | ||
else: | ||
if value_field not in self.field2source: | ||
raise ValueError('value_field [{}] not exist.'.format(value_field)) | ||
if self.field2source[value_field] != FeatureSource.INTERACTION: | ||
raise ValueError('value_field [{}] can only be one of the interaction features'.format(value_field)) | ||
data = self.inter_feat[value_field].values | ||
mat = coo_matrix((data, (uids, iids)), shape=(self.user_num, self.item_num)) | ||
|
||
if form == 'coo': | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters