Skip to content

Commit

Permalink
modularization
Browse files Browse the repository at this point in the history
  • Loading branch information
XU-YaoKun committed Jan 22, 2020
1 parent 3b41fac commit 09f9460
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 9 deletions.
1 change: 1 addition & 0 deletions common/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .parser import parse_args
1 change: 1 addition & 0 deletions common/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .preprocess import CKGData
17 changes: 8 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import os
import random
from time import time
from pathlib import Path

import torch
import numpy as np

from time import time
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
from prettytable import PrettyTable

from common.test import test_v2
from common.utils import early_stopping, print_dict
from common.config.parser import parse_args

from common.config import parse_args
from common.dataset import CKGData
from common.dataset.build import build_loader
from common.dataset.preprocess import CKGData

from modules.recommender.MF import MF
from modules.sampler.kgpolicy import KGPolicy
from modules.sampler import KGPolicy
from modules.recommender import MF


def train_one_epoch(
Expand Down Expand Up @@ -50,9 +49,9 @@ def train_one_epoch(
"""Train recommender using negtive item provided by sampler"""
recommender_optim.zero_grad()

users = batch_data["u_id"]
neg = batch_data["neg_i_id"]
pos = batch_data["pos_i_id"]
users = batch_data["u_id"]

selected_neg_items_list, _ = sampler(batch_data, adj_matrix, edge_matrix)
selected_neg_items = selected_neg_items_list[-1, :]
Expand Down Expand Up @@ -98,8 +97,8 @@ def train_one_epoch(

"""record loss in an epoch"""
loss += loss_batch
base_loss += base_loss_batch
reg_loss += reg_loss_batch
base_loss += base_loss_batch

avg_reward = epoch_reward / num_batch
train_res = PrettyTable()
Expand Down
1 change: 1 addition & 0 deletions modules/recommender/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .MF import MF
1 change: 1 addition & 0 deletions modules/sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .kgpolicy import KGPolicy

0 comments on commit 09f9460

Please sign in to comment.