From d1067f2706141bf4ff4e6276f1ff8997d2d53ae0 Mon Sep 17 00:00:00 2001 From: hyp1231 Date: Thu, 23 Jul 2020 13:55:55 +0800 Subject: [PATCH] FEA: Enum Model Type --- data/utils.py | 5 +++-- model/general_recommender/bprmf.py | 3 ++- utils/__init__.py | 1 + utils/enum_type.py | 8 ++++++++ 4 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 utils/enum_type.py diff --git a/data/utils.py b/data/utils.py index a23eefce1..5de9d4b30 100644 --- a/data/utils.py +++ b/data/utils.py @@ -1,5 +1,6 @@ from .dataloader import * from config import EvalSetting +from utils import ModelType def data_preparation(config, model, dataset): es = EvalSetting(config) @@ -35,7 +36,7 @@ def data_preparation(config, model, dataset): return train_data, test_data, valid_data def dataloader_construct(name, config, eval_setting, dataset, - dl_type='general', dl_format='pointwise', + dl_type=ModelType.GENERAL, dl_format='pointwise', batch_size=1, shuffle=False): if not isinstance(dataset, list): dataset = [dataset] @@ -50,7 +51,7 @@ def dataloader_construct(name, config, eval_setting, dataset, print(eval_setting) print('batch_size = {}, shuffle = {}\n'.format(batch_size, shuffle)) - if dl_type == 'general': + if dl_type == ModelType.GENERAL: DataLoader = GeneralDataLoader else: raise NotImplementedError('dl_type [{}] has not been implemented'.format(dl_type)) diff --git a/model/general_recommender/bprmf.py b/model/general_recommender/bprmf.py index 3905acccc..0614f9118 100644 --- a/model/general_recommender/bprmf.py +++ b/model/general_recommender/bprmf.py @@ -15,6 +15,7 @@ from model.abstract_recommender import AbstractRecommender from model.loss import BPRLoss +from utils import ModelType class BPRMF(AbstractRecommender): @@ -22,7 +23,7 @@ class BPRMF(AbstractRecommender): def __init__(self, config, dataset): super(BPRMF, self).__init__() - self.type = 'general' + self.type = ModelType.GENERAL self.USER_ID = config['USER_ID_FIELD'] self.ITEM_ID = config['ITEM_ID_FIELD'] diff --git a/utils/__init__.py b/utils/__init__.py index 67235de1a..62a35414e 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,3 +1,4 @@ from .metrics import * from .logger import Logger from .utils import get_local_time, ensure_dir +from .enum_type import * diff --git a/utils/enum_type.py b/utils/enum_type.py new file mode 100644 index 000000000..67dcc6a90 --- /dev/null +++ b/utils/enum_type.py @@ -0,0 +1,8 @@ +from enum import Enum + +class ModelType(Enum): + GENERAL = 1 + SEQUENTIAL = 2 + CONTEXT = 3 + KNOWLEDGE = 4 + SOCIAL = 5