Skip to content

Commit

Permalink
Merge pull request #73 from hyp1231/master
Browse files Browse the repository at this point in the history
FEA: Enum Model Type
  • Loading branch information
hyp1231 authored Jul 23, 2020
2 parents 87d8f96 + d1067f2 commit ea951b3
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
5 changes: 3 additions & 2 deletions data/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .dataloader import *
from config import EvalSetting
from utils import ModelType

def data_preparation(config, model, dataset):
es = EvalSetting(config)
Expand Down Expand Up @@ -35,7 +36,7 @@ def data_preparation(config, model, dataset):
return train_data, test_data, valid_data

def dataloader_construct(name, config, eval_setting, dataset,
dl_type='general', dl_format='pointwise',
dl_type=ModelType.GENERAL, dl_format='pointwise',
batch_size=1, shuffle=False):
if not isinstance(dataset, list):
dataset = [dataset]
Expand All @@ -50,7 +51,7 @@ def dataloader_construct(name, config, eval_setting, dataset,
print(eval_setting)
print('batch_size = {}, shuffle = {}\n'.format(batch_size, shuffle))

if dl_type == 'general':
if dl_type == ModelType.GENERAL:
DataLoader = GeneralDataLoader
else:
raise NotImplementedError('dl_type [{}] has not been implemented'.format(dl_type))
Expand Down
3 changes: 2 additions & 1 deletion model/general_recommender/bprmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

from model.abstract_recommender import AbstractRecommender
from model.loss import BPRLoss
from utils import ModelType


class BPRMF(AbstractRecommender):

def __init__(self, config, dataset):
super(BPRMF, self).__init__()

self.type = 'general'
self.type = ModelType.GENERAL

self.USER_ID = config['USER_ID_FIELD']
self.ITEM_ID = config['ITEM_ID_FIELD']
Expand Down
1 change: 1 addition & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .metrics import *
from .logger import Logger
from .utils import get_local_time, ensure_dir
from .enum_type import *
8 changes: 8 additions & 0 deletions utils/enum_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import Enum

class ModelType(Enum):
GENERAL = 1
SEQUENTIAL = 2
CONTEXT = 3
KNOWLEDGE = 4
SOCIAL = 5

0 comments on commit ea951b3

Please sign in to comment.