diff --git a/docs/source/user_guide/config/environment_settings.rst b/docs/source/user_guide/config/environment_settings.rst index dc4bb07b8..51b311823 100644 --- a/docs/source/user_guide/config/environment_settings.rst +++ b/docs/source/user_guide/config/environment_settings.rst @@ -8,6 +8,8 @@ Environment settings are designed to set basic parameters of running environment - ``seed (int)`` : Random seed. Defaults to ``2020``. - ``state (str)`` : Logging level. Defaults to ``'INFO'``. Range in ``['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']``. +- ``encoding (str)``: Encoding to use for reading atomic files. Defaults to ``'utf-8'``. + The available encoding can be found in `here `__. - ``reproducibility (bool)`` : If True, the tool will use deterministic convolution algorithms, which makes the result reproducible. If False, the tool will benchmark multiple convolution algorithms and select the fastest one, diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py index fe169db9c..f69997097 100644 --- a/recbole/data/dataloader/abstract_dataloader.py +++ b/recbole/data/dataloader/abstract_dataloader.py @@ -108,6 +108,7 @@ class NegSampleDataLoader(AbstractDataLoader): sampler (Sampler): The sampler of dataloader. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. """ + def __init__(self, config, dataset, sampler, shuffle=True): super().__init__(config, dataset, sampler, shuffle=shuffle) diff --git a/recbole/data/dataloader/general_dataloader.py b/recbole/data/dataloader/general_dataloader.py index f9ab55121..aaefdb6e5 100644 --- a/recbole/data/dataloader/general_dataloader.py +++ b/recbole/data/dataloader/general_dataloader.py @@ -73,6 +73,7 @@ class NegSampleEvalDataLoader(NegSampleDataLoader): sampler (Sampler): The sampler of dataloader. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. """ + def __init__(self, config, dataset, sampler, shuffle=False): self._set_neg_sample_args(config, dataset, InputType.POINTWISE, config['eval_neg_sample_args']) if self.neg_sample_args['strategy'] == 'by': @@ -193,7 +194,7 @@ def _set_user_property(self, uid, used_item, positive_item): if uid is None: return history_item = used_item - positive_item - self.uid2positive_item[uid] = torch.tensor(list(positive_item), dtype=torch.int64) + self.uid2positive_item[uid] = torch.tensor(list(positive_item), dtype=torch.int64) self.uid2items_num[uid] = len(positive_item) self.uid2history_item[uid] = torch.tensor(list(history_item), dtype=torch.int64) @@ -222,7 +223,7 @@ def _next_batch_data(self): if not self.is_sequential: user_df = self.user_df[self.pr:self.pr + self.step] uid_list = list(user_df[self.uid_field]) - + history_item = self.uid2history_item[uid_list] positive_item = self.uid2positive_item[uid_list] @@ -241,4 +242,4 @@ def _next_batch_data(self): positive_i = interaction[self.iid_field] self.pr += self.step - return interaction, None, positive_u, positive_i + return interaction, None, positive_u, positive_i diff --git a/recbole/data/dataset/customized_dataset.py b/recbole/data/dataset/customized_dataset.py index 0b93e4306..f002c79d2 100644 --- a/recbole/data/dataset/customized_dataset.py +++ b/recbole/data/dataset/customized_dataset.py @@ -52,6 +52,7 @@ class DIENDataset(SequentialDataset): neg_item_list_field (str): Field name for negative item sequence. neg_item_list (torch.tensor): all users' negative item history sequence. """ + def __init__(self, config): super().__init__(config) diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index d9e967595..f016ab2fb 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -209,8 +209,10 @@ def _get_download_url(self, url_file, allow_none=False): elif allow_none: return None else: - raise ValueError(f'Neither [{self.dataset_path}] exists in the device' - f'nor [{self.dataset_name}] a known dataset name.') + raise ValueError( + f'Neither [{self.dataset_path}] exists in the device' + f'nor [{self.dataset_name}] a known dataset name.' + ) def _download(self): url = self._get_download_url('url') @@ -404,7 +406,8 @@ def _load_feat(self, filepath, source): columns = [] usecols = [] dtype = {} - with open(filepath, 'r') as f: + encoding = self.config['encoding'] + with open(filepath, 'r', encoding=encoding) as f: head = f.readline()[:-1] for field_type in head.split(field_separator): field, ftype = field_type.split(':') @@ -429,7 +432,9 @@ def _load_feat(self, filepath, source): self.logger.warning(f'No columns has been loaded from [{source}]') return None - df = pd.read_csv(filepath, delimiter=self.config['field_separator'], usecols=usecols, dtype=dtype) + df = pd.read_csv( + filepath, delimiter=self.config['field_separator'], usecols=usecols, dtype=dtype, encoding=encoding + ) df.columns = columns seq_separator = self.config['seq_separator'] @@ -462,15 +467,19 @@ def _init_alias(self): if alias_name_1 != alias_name_2: intersect = np.intersect1d(alias_1, alias_2, assume_unique=True) if len(intersect) > 0: - raise ValueError(f'`alias_of_{alias_name_1}` and `alias_of_{alias_name_2}` ' - f'should not have the same field {list(intersect)}.') + raise ValueError( + f'`alias_of_{alias_name_1}` and `alias_of_{alias_name_2}` ' + f'should not have the same field {list(intersect)}.' + ) self._rest_fields = self.token_like_fields for alias_name, alias in self.alias.items(): isin = np.isin(alias, self._rest_fields, assume_unique=True) if isin.all() is False: - raise ValueError(f'`alias_of_{alias_name}` should not contain ' - f'non-token-like field {list(alias[~isin])}.') + raise ValueError( + f'`alias_of_{alias_name}` should not contain ' + f'non-token-like field {list(alias[~isin])}.' + ) self._rest_fields = np.setdiff1d(self._rest_fields, alias, assume_unique=True) def _user_item_feat_preparation(self): @@ -484,7 +493,7 @@ def _user_item_feat_preparation(self): if self.item_feat is not None: new_item_df = pd.DataFrame({self.iid_field: np.arange(self.item_num)}) self.item_feat = pd.merge(new_item_df, self.item_feat, on=self.iid_field, how='left') - self.logger.debug(set_color('ordering item features by user id.', 'green')) + self.logger.debug(set_color('ordering item features by item id.', 'green')) def _preload_weight_matrix(self): """Transfer preload weight features into :class:`numpy.ndarray` with shape ``[id_token_length]`` @@ -592,6 +601,7 @@ def _normalize(self): for field in fields: for feat in self.field2feats(field): + def norm(arr): mx, mn = max(arr), min(arr) if mx == mn: @@ -675,14 +685,18 @@ def _filter_by_inter_num(self): item_inter_num = Counter(self.inter_feat[self.iid_field].values) if item_inter_num_interval else Counter() while True: - ban_users = self._get_illegal_ids_by_inter_num(field=self.uid_field, - feat=self.user_feat, - inter_num=user_inter_num, - inter_interval=user_inter_num_interval) - ban_items = self._get_illegal_ids_by_inter_num(field=self.iid_field, - feat=self.item_feat, - inter_num=item_inter_num, - inter_interval=item_inter_num_interval) + ban_users = self._get_illegal_ids_by_inter_num( + field=self.uid_field, + feat=self.user_feat, + inter_num=user_inter_num, + inter_interval=user_inter_num_interval + ) + ban_items = self._get_illegal_ids_by_inter_num( + field=self.iid_field, + feat=self.item_feat, + inter_num=item_inter_num, + inter_interval=item_inter_num_interval + ) if len(ban_users) == 0 and len(ban_items) == 0: break @@ -722,7 +736,8 @@ def _get_illegal_ids_by_inter_num(self, field, feat, inter_num, inter_interval=N set: illegal ids, whose inter num out of inter_intervals. """ self.logger.debug( - set_color('get_illegal_ids_by_inter_num', 'blue') + f': field=[{field}], inter_interval=[{inter_interval}]') + set_color('get_illegal_ids_by_inter_num', 'blue') + f': field=[{field}], inter_interval=[{inter_interval}]' + ) if inter_interval is not None: if len(inter_interval) > 1: diff --git a/recbole/data/dataset/sequential_dataset.py b/recbole/data/dataset/sequential_dataset.py index 45327c123..82146b123 100644 --- a/recbole/data/dataset/sequential_dataset.py +++ b/recbole/data/dataset/sequential_dataset.py @@ -170,7 +170,7 @@ def inter_matrix(self, form='coo', value_field=None): for field in l1_inter_dict: if field != self.uid_field and field + list_suffix in l1_inter_dict: candidate_field_set.add(field) - new_dict[field] = torch.cat([self.inter_feat[field], l1_inter_dict[field + list_suffix][:,0]]) + new_dict[field] = torch.cat([self.inter_feat[field], l1_inter_dict[field + list_suffix][:, 0]]) elif (not field.endswith(list_suffix)) and (field != self.item_list_length_field): new_dict[field] = torch.cat([self.inter_feat[field], l1_inter_dict[field]]) local_inter_feat = Interaction(new_dict) diff --git a/recbole/data/utils.py b/recbole/data/utils.py index e6ebdf037..8ff65afdd 100644 --- a/recbole/data/utils.py +++ b/recbole/data/utils.py @@ -112,12 +112,12 @@ def data_preparation(config, dataset, save=False): test_data = get_dataloader(config, 'evaluation')(config, test_dataset, test_sampler, shuffle=False) logger.info( set_color('[Training]: ', 'pink') + set_color('train_batch_size', 'cyan') + ' = ' + - set_color(f'[{config["train_batch_size"]}]', 'yellow') + set_color(' negative sampling', 'cyan') + ': '+ + set_color(f'[{config["train_batch_size"]}]', 'yellow') + set_color(' negative sampling', 'cyan') + ': ' + set_color(f'[{config["neg_sampling"]}]', 'yellow') ) logger.info( set_color('[Evaluation]: ', 'pink') + set_color('eval_batch_size', 'cyan') + ' = ' + - set_color(f'[{config["eval_batch_size"]}]', 'yellow') + set_color(' eval_args', 'cyan') + ': '+ + set_color(f'[{config["eval_batch_size"]}]', 'yellow') + set_color(' eval_args', 'cyan') + ': ' + set_color(f'[{config["eval_args"]}]', 'yellow') ) if save: diff --git a/recbole/evaluator/collector.py b/recbole/evaluator/collector.py index 44b7bf186..9f856a12d 100644 --- a/recbole/evaluator/collector.py +++ b/recbole/evaluator/collector.py @@ -16,7 +16,9 @@ import torch import copy + class DataStruct(object): + def __init__(self): self._data_dict = {} @@ -64,6 +66,7 @@ class Collector(object): This class is only used in Trainer. """ + def __init__(self, config): self.config = config self.data_struct = DataStruct() @@ -123,7 +126,9 @@ def _average_rank(self, scores): return avg_rank - def eval_batch_collect(self, scores_tensor: torch.Tensor, interaction, positive_u: torch.Tensor, positive_i: torch.Tensor): + def eval_batch_collect( + self, scores_tensor: torch.Tensor, interaction, positive_u: torch.Tensor, positive_i: torch.Tensor + ): """ Collect the evaluation resource from batched eval data and batched model output. Args: scores_tensor (Torch.Tensor): the output tensor of model with the shape of `(N, )` @@ -173,7 +178,6 @@ def eval_batch_collect(self, scores_tensor: torch.Tensor, interaction, positive_ self.data_struct.update_tensor('data.label', interaction[self.label_field].to(self.device)) def model_collect(self, model: torch.nn.Module): - """ Collect the evaluation resource from model. Args: model (nn.Module): the trained recommendation model. diff --git a/recbole/evaluator/evaluator.py b/recbole/evaluator/evaluator.py index 08eeffc46..d83dda333 100644 --- a/recbole/evaluator/evaluator.py +++ b/recbole/evaluator/evaluator.py @@ -39,4 +39,3 @@ def evaluate(self, dataobject: DataStruct): metric_val = self.metric_class[metric].calculate_metric(dataobject) result_dict.update(metric_val) return result_dict - diff --git a/recbole/evaluator/metrics.py b/recbole/evaluator/metrics.py index 65cc7d007..e6867dc36 100644 --- a/recbole/evaluator/metrics.py +++ b/recbole/evaluator/metrics.py @@ -49,6 +49,7 @@ class Hit(TopkMetric): :math:`\delta(·)` is an indicator function. :math:`\delta(b)` = 1 if :math:`b` is true and 0 otherwise. :math:`\emptyset` denotes the empty set. """ + def __init__(self, config): super().__init__(config) @@ -74,6 +75,7 @@ class MRR(TopkMetric): :math:`{rank}_{u}^{*}` is the rank position of the first relevant item found by an algorithm for a user :math:`u`. """ + def __init__(self, config): super().__init__(config) @@ -110,6 +112,7 @@ class MAP(TopkMetric): :math:`\hat{R}_{j}(u)` is the j-th item in the recommendation list of \hat R (u)). """ + def __init__(self, config): super().__init__(config) self.config = config @@ -143,6 +146,7 @@ class Recall(TopkMetric): :math:`|R(u)|` represents the item count of :math:`R(u)`. """ + def __init__(self, config): super().__init__(config) @@ -169,6 +173,7 @@ class NDCG(TopkMetric): :math:`\delta(·)` is an indicator function. """ + def __init__(self, config): super().__init__(config) @@ -208,6 +213,7 @@ class Precision(TopkMetric): :math:`|\hat R(u)|` represents the item count of :math:`\hat R(u)`. """ + def __init__(self, config): super().__init__(config) @@ -223,6 +229,7 @@ def metric_info(self, pos_index): # CTR Metrics + class GAUC(AbstractMetric): r"""GAUC (also known as Grouped Area Under Curve) is used to evaluate the two-class model, referring to the area under the ROC curve grouped by user. We weighted the index of each user :math:`u` by the number of positive @@ -321,6 +328,7 @@ class AUC(LossMetric): :math:`N` denotes the total number of user-item interactions. :math:`rank_i` denotes the descending rank of the i-th positive item. """ + def __init__(self, config): super().__init__(config) @@ -357,6 +365,7 @@ def metric_info(self, preds, trues): # Loss-based Metrics + class MAE(LossMetric): r"""MAE_ (also known as Mean Absolute Error regression loss) is used to evaluate the difference between the score predicted by the model and the actual behavior of the user. diff --git a/recbole/evaluator/register.py b/recbole/evaluator/register.py index fc56c380d..fbf041878 100644 --- a/recbole/evaluator/register.py +++ b/recbole/evaluator/register.py @@ -39,8 +39,9 @@ def cluster_info(module_name): """ smaller_m = [] m_dict, m_info, m_types = {}, {}, {} - metric_class = inspect.getmembers(sys.modules[module_name], - lambda x: inspect.isclass(x) and x.__module__ == module_name) + metric_class = inspect.getmembers( + sys.modules[module_name], lambda x: inspect.isclass(x) and x.__module__ == module_name + ) for name, metric_cls in metric_class: name = name.lower() m_dict[name] = metric_cls @@ -66,6 +67,7 @@ class Register(object): It is a member of DataCollector. The DataCollector collect the resource that need for Evaluator under the guidance of Register """ + def __init__(self, config): self.config = config @@ -88,4 +90,3 @@ def need(self, key: str): if hasattr(self, key): return getattr(self, key) return False - diff --git a/recbole/model/abstract_recommender.py b/recbole/model/abstract_recommender.py index 677ddb925..2aaae77a6 100644 --- a/recbole/model/abstract_recommender.py +++ b/recbole/model/abstract_recommender.py @@ -169,11 +169,15 @@ class ContextRecommender(AbstractRecommender): def __init__(self, config, dataset): super(ContextRecommender, self).__init__() - self.field_names = dataset.fields(source=[ - FeatureSource.INTERACTION, - FeatureSource.USER, FeatureSource.USER_ID, - FeatureSource.ITEM, FeatureSource.ITEM_ID, - ]) + self.field_names = dataset.fields( + source=[ + FeatureSource.INTERACTION, + FeatureSource.USER, + FeatureSource.USER_ID, + FeatureSource.ITEM, + FeatureSource.ITEM_ID, + ] + ) self.LABEL = config['LABEL_FIELD'] self.embedding_size = config['embedding_size'] self.device = config['device'] diff --git a/recbole/model/general_recommender/macridvae.py b/recbole/model/general_recommender/macridvae.py index eadc8a359..f57fa0812 100644 --- a/recbole/model/general_recommender/macridvae.py +++ b/recbole/model/general_recommender/macridvae.py @@ -4,7 +4,7 @@ # @Email : gyihong@hotmail.com # UPDATE -# @Time : 2021/6/30, +# @Time : 2021/6/30, # @Author : Xingyu Pan # @email : xy_pan@foxmail.com diff --git a/recbole/model/layers.py b/recbole/model/layers.py index 76ae95207..c09fb6758 100644 --- a/recbole/model/layers.py +++ b/recbole/model/layers.py @@ -918,11 +918,15 @@ class FMFirstOrderLinear(nn.Module): def __init__(self, config, dataset, output_dim=1): super(FMFirstOrderLinear, self).__init__() - self.field_names = dataset.fields(source=[ - FeatureSource.INTERACTION, - FeatureSource.USER, FeatureSource.USER_ID, - FeatureSource.ITEM, FeatureSource.ITEM_ID, - ]) + self.field_names = dataset.fields( + source=[ + FeatureSource.INTERACTION, + FeatureSource.USER, + FeatureSource.USER_ID, + FeatureSource.ITEM, + FeatureSource.ITEM_ID, + ] + ) self.LABEL = config['LABEL_FIELD'] self.device = config['device'] self.token_field_names = [] diff --git a/recbole/sampler/sampler.py b/recbole/sampler/sampler.py index 4ba338d48..805226b3d 100644 --- a/recbole/sampler/sampler.py +++ b/recbole/sampler/sampler.py @@ -20,18 +20,16 @@ import torch from collections import Counter + class AbstractSampler(object): """:class:`AbstractSampler` is a abstract class, all sampler should inherit from it. This sampler supports returning a certain number of random value_ids according to the input key_id, and it also supports to prohibit - certain key-value pairs by setting used_ids. Besides, in order to improve efficiency, we use :attr:`random_pr` - to move around the :attr:`random_list` to generate random numbers, so we need to implement the - :meth:`get_random_list` method in the subclass. + certain key-value pairs by setting used_ids. Args: distribution (str): The string of distribution, which is used for subclass. Attributes: - random_list (list or numpy.ndarray): The shuffled result of :meth:`get_random_list`. used_ids (numpy.ndarray): The result of :meth:`get_used_ids`. """ @@ -85,18 +83,18 @@ def _build_alias_table(self): large_q.append(i) elif self.prob[i] < 1: small_q.append(i) - - while(len(large_q)!=0 and len(small_q)!=0): + + while len(large_q) != 0 and len(small_q) != 0: l = large_q.pop(0) s = small_q.pop(0) - self.alias[s] = l + self.alias[s] = l self.prob[l] = self.prob[l] - (1 - self.prob[s]) if self.prob[l] < 1: small_q.append(l) elif self.prob[l] > 1: large_q.append(l) - def _pop_sampling(self, sample_num): + def _pop_sampling(self, sample_num): """Sample [sample_num] items in the popularity-biased distribution. Args: @@ -116,7 +114,7 @@ def _pop_sampling(self, sample_num): final_random_list.append(keys[idx]) else: final_random_list.append(self.alias[keys[idx]]) - + return np.array(final_random_list) def sampling(self, sample_num): @@ -128,7 +126,7 @@ def sampling(self, sample_num): Returns: sample_list (np.array): a list of samples and the len is [sample_num]. """ - if self.distribution =='uniform': + if self.distribution == 'uniform': return self._uni_sampling(sample_num) elif self.distribution == 'popularity': return self._pop_sampling(sample_num) @@ -165,7 +163,7 @@ def sample_by_key_ids(self, key_ids, num): value_ids = self.sampling(total_num) check_list = np.arange(total_num)[np.isin(value_ids, used)] while len(check_list) > 0: - value_ids[check_list] = value = self.sampling(len(check_list)) + value_ids[check_list] = value = self.sampling(len(check_list)) mask = np.isin(value, used) check_list = check_list[mask] else: @@ -215,13 +213,12 @@ def __init__(self, phases, datasets, distribution='uniform'): super().__init__(distribution=distribution) - def _get_candidates_list(self): candidates_list = [] for dataset in self.datasets: candidates_list.extend(dataset.inter_feat[self.iid_field].numpy()) return candidates_list - + def _uni_sampling(self, sample_num): return np.random.randint(1, self.item_num, sample_num) @@ -386,18 +383,6 @@ def _uni_sampling(self, sample_num): def _get_candidates_list(self): return list(self.dataset.inter_feat[self.iid_field].numpy()) - def get_random_list(self): - """ - Returns: - numpy.ndarray or list: Random list of item_id. - """ - if self.distribution == 'uniform': - return np.arange(1, self.item_num) - elif self.distribution == 'popularity': - return self.dataset.inter_feat[self.iid_field].numpy() - else: - raise NotImplementedError(f'Distribution [{self.distribution}] has not been implemented.') - def get_used_ids(self): """ Returns: diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index b3316d52d..0b0a3cb7c 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -275,14 +275,12 @@ def _add_hparam_to_tensorboard(self, best_valid_result): # unrecorded parameter unrecorded_parameter = { parameter - for parameters in self.config.parameters.values() - for parameter in parameters + for parameters in self.config.parameters.values() for parameter in parameters }.union({'model', 'dataset', 'config_files', 'device'}) # other model-specific hparam hparam_dict.update({ para: val - for para, val in self.config.final_config_dict.items() - if para not in unrecorded_parameter + for para, val in self.config.final_config_dict.items() if para not in unrecorded_parameter }) for k in hparam_dict: if hparam_dict[k] is not None and not isinstance(hparam_dict[k], (bool, str, float, int)): @@ -896,8 +894,10 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre elif self.model.train_stage == 'finetune': return super().fit(train_data, valid_data, verbose, saved, show_progress, callback_fn) else: - raise ValueError("Please make sure that the 'train_stage' is " - "'actor_pretrain', 'critic_pretrain' or 'finetune'!") + raise ValueError( + "Please make sure that the 'train_stage' is " + "'actor_pretrain', 'critic_pretrain' or 'finetune'!" + ) class lightgbmTrainer(DecisionTreeTrainer): diff --git a/recbole/utils/__init__.py b/recbole/utils/__init__.py index 062f250ed..5b6ef50b3 100644 --- a/recbole/utils/__init__.py +++ b/recbole/utils/__init__.py @@ -6,7 +6,7 @@ __all__ = [ 'init_logger', 'get_local_time', 'ensure_dir', 'get_model', 'get_trainer', 'early_stopping', - 'calculate_valid_score', 'dict2str', 'Enum', 'ModelType', 'KGDataLoaderState', 'EvaluatorType', - 'InputType', 'FeatureType', 'FeatureSource', 'init_seed', 'general_arguments', 'training_arguments', - 'evaluation_arguments', 'dataset_arguments', 'get_tensorboard', 'set_color', 'get_gpu_usage' + 'calculate_valid_score', 'dict2str', 'Enum', 'ModelType', 'KGDataLoaderState', 'EvaluatorType', 'InputType', + 'FeatureType', 'FeatureSource', 'init_seed', 'general_arguments', 'training_arguments', 'evaluation_arguments', + 'dataset_arguments', 'get_tensorboard', 'set_color', 'get_gpu_usage' ] diff --git a/recbole/utils/logger.py b/recbole/utils/logger.py index 8f4f70355..d5d5d52ca 100644 --- a/recbole/utils/logger.py +++ b/recbole/utils/logger.py @@ -30,6 +30,7 @@ class RemoveColorFilter(logging.Filter): + def filter(self, record): if record: ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') diff --git a/recbole/utils/url.py b/recbole/utils/url.py index 52553b19b..cd7b99f8b 100644 --- a/recbole/utils/url.py +++ b/recbole/utils/url.py @@ -14,13 +14,12 @@ from tqdm import tqdm - GBFACTOR = float(1 << 30) def decide_download(url): d = ur.urlopen(url) - size = int(d.info()['Content-Length'])/GBFACTOR + size = int(d.info()['Content-Length']) / GBFACTOR ### confirm if larger than 1GB if size > 1: @@ -60,8 +59,8 @@ def download_url(url, folder): size = int(data.info()['Content-Length']) - chunk_size = 1024*1024 - num_iter = int(size/chunk_size) + 2 + chunk_size = 1024 * 1024 + num_iter = int(size / chunk_size) + 2 downloaded_size = 0 @@ -71,14 +70,13 @@ def download_url(url, folder): for i in pbar: chunk = data.read(chunk_size) downloaded_size += len(chunk) - pbar.set_description('Downloaded {:.2f} GB'.format(float(downloaded_size)/GBFACTOR)) + pbar.set_description('Downloaded {:.2f} GB'.format(float(downloaded_size) / GBFACTOR)) f.write(chunk) except: if os.path.exists(path): - os.remove(path) + os.remove(path) raise RuntimeError('Stopped downloading due to interruption.') - return path @@ -109,9 +107,8 @@ def rename_atomic_files(folder, old_name, new_name): if not old_name in base: continue assert suf in {'.inter', '.user', '.item'} - os.rename( - os.path.join(folder, f), - os.path.join(folder, base.replace(old_name, new_name) + suf)) + os.rename(os.path.join(folder, f), os.path.join(folder, base.replace(old_name, new_name) + suf)) + if __name__ == '__main__': pass diff --git a/recbole/utils/utils.py b/recbole/utils/utils.py index 53df843c4..02c4714df 100644 --- a/recbole/utils/utils.py +++ b/recbole/utils/utils.py @@ -233,4 +233,3 @@ def get_gpu_usage(device=None): total = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3 return '{:.2f} G/{:.2f} G'.format(reserved, total) -