diff --git a/common/config/__init__.py b/common/config/__init__.py index e69de29..ab12dde 100644 --- a/common/config/__init__.py +++ b/common/config/__init__.py @@ -0,0 +1 @@ +from .parser import parse_args diff --git a/common/dataset/__init__.py b/common/dataset/__init__.py index e69de29..82c9e32 100644 --- a/common/dataset/__init__.py +++ b/common/dataset/__init__.py @@ -0,0 +1 @@ +from .preprocess import CKGData diff --git a/main.py b/main.py index 531c72c..9d7d440 100644 --- a/main.py +++ b/main.py @@ -1,24 +1,23 @@ import os import random -from time import time -from pathlib import Path import torch import numpy as np +from time import time from tqdm import tqdm from copy import deepcopy +from pathlib import Path from prettytable import PrettyTable from common.test import test_v2 from common.utils import early_stopping, print_dict -from common.config.parser import parse_args - +from common.config import parse_args +from common.dataset import CKGData from common.dataset.build import build_loader -from common.dataset.preprocess import CKGData -from modules.recommender.MF import MF -from modules.sampler.kgpolicy import KGPolicy +from modules.sampler import KGPolicy +from modules.recommender import MF def train_one_epoch( @@ -50,9 +49,9 @@ def train_one_epoch( """Train recommender using negtive item provided by sampler""" recommender_optim.zero_grad() - users = batch_data["u_id"] neg = batch_data["neg_i_id"] pos = batch_data["pos_i_id"] + users = batch_data["u_id"] selected_neg_items_list, _ = sampler(batch_data, adj_matrix, edge_matrix) selected_neg_items = selected_neg_items_list[-1, :] @@ -98,8 +97,8 @@ def train_one_epoch( """record loss in an epoch""" loss += loss_batch - base_loss += base_loss_batch reg_loss += reg_loss_batch + base_loss += base_loss_batch avg_reward = epoch_reward / num_batch train_res = PrettyTable() diff --git a/modules/recommender/__init__.py b/modules/recommender/__init__.py index e69de29..3f6f729 100644 --- a/modules/recommender/__init__.py +++ b/modules/recommender/__init__.py @@ -0,0 +1 @@ +from .MF import MF diff --git a/modules/sampler/__init__.py b/modules/sampler/__init__.py index e69de29..72b6287 100644 --- a/modules/sampler/__init__.py +++ b/modules/sampler/__init__.py @@ -0,0 +1 @@ +from .kgpolicy import KGPolicy