From 09f9460700026bee652fb02a6260c8122d7b54d6 Mon Sep 17 00:00:00 2001 From: XXXU Date: Wed, 22 Jan 2020 13:28:05 +0800 Subject: [PATCH] modularization --- common/config/__init__.py | 1 + common/dataset/__init__.py | 1 + main.py | 17 ++++++++--------- modules/recommender/__init__.py | 1 + modules/sampler/__init__.py | 1 + 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/common/config/__init__.py b/common/config/__init__.py index e69de29..ab12dde 100644 --- a/common/config/__init__.py +++ b/common/config/__init__.py @@ -0,0 +1 @@ +from .parser import parse_args diff --git a/common/dataset/__init__.py b/common/dataset/__init__.py index e69de29..82c9e32 100644 --- a/common/dataset/__init__.py +++ b/common/dataset/__init__.py @@ -0,0 +1 @@ +from .preprocess import CKGData diff --git a/main.py b/main.py index 531c72c..9d7d440 100644 --- a/main.py +++ b/main.py @@ -1,24 +1,23 @@ import os import random -from time import time -from pathlib import Path import torch import numpy as np +from time import time from tqdm import tqdm from copy import deepcopy +from pathlib import Path from prettytable import PrettyTable from common.test import test_v2 from common.utils import early_stopping, print_dict -from common.config.parser import parse_args - +from common.config import parse_args +from common.dataset import CKGData from common.dataset.build import build_loader -from common.dataset.preprocess import CKGData -from modules.recommender.MF import MF -from modules.sampler.kgpolicy import KGPolicy +from modules.sampler import KGPolicy +from modules.recommender import MF def train_one_epoch( @@ -50,9 +49,9 @@ def train_one_epoch( """Train recommender using negtive item provided by sampler""" recommender_optim.zero_grad() - users = batch_data["u_id"] neg = batch_data["neg_i_id"] pos = batch_data["pos_i_id"] + users = batch_data["u_id"] selected_neg_items_list, _ = sampler(batch_data, adj_matrix, edge_matrix) selected_neg_items = selected_neg_items_list[-1, :] @@ -98,8 +97,8 @@ def train_one_epoch( """record loss in an epoch""" loss += loss_batch - base_loss += base_loss_batch reg_loss += reg_loss_batch + base_loss += base_loss_batch avg_reward = epoch_reward / num_batch train_res = PrettyTable() diff --git a/modules/recommender/__init__.py b/modules/recommender/__init__.py index e69de29..3f6f729 100644 --- a/modules/recommender/__init__.py +++ b/modules/recommender/__init__.py @@ -0,0 +1 @@ +from .MF import MF diff --git a/modules/sampler/__init__.py b/modules/sampler/__init__.py index e69de29..72b6287 100644 --- a/modules/sampler/__init__.py +++ b/modules/sampler/__init__.py @@ -0,0 +1 @@ +from .kgpolicy import KGPolicy