forked from aHuiWang/CIKM2020-S3Rec
-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy path run_pretrain.py
106 lines (79 loc) · 4.44 KB
/
run_pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# -*- coding: utf-8 -*-
# @Time : 2020/4/25 16:01
# @Author : Hui Wang
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, RandomSampler
import os
import argparse
from datasets import PretrainDataset
from trainers import PretrainTrainer
from models import S3RecModel
from utils import get_user_seqs_long, get_item2attribute_json, check_path, set_seed
def main():
    """Pre-train an S3Rec model with the AAP / MIP / MAP / SP objectives.

    Parses command-line arguments, loads the user sequences and the
    item -> attributes mapping for ``--data_name`` from ``--data_dir``, builds
    an ``S3RecModel`` and a ``PretrainTrainer``, and runs ``--pre_epochs``
    epochs of pre-training.

    A checkpoint named ``<data_name>-epochs-<N>.pt`` is written to
    ``--output_dir`` every ``--ckp_freq`` epochs. One is always written after
    the final epoch, so a run whose length is not a multiple of
    ``--ckp_freq`` still persists its last weights. The resolved arguments
    are printed and appended to ``<output_dir>/<model_name>-<data_name>.txt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='./data/', type=str)
    parser.add_argument('--output_dir', default='output/', type=str)
    parser.add_argument('--data_name', default='Beauty', type=str)

    # model args
    parser.add_argument("--model_name", default='Pretrain', type=str)
    parser.add_argument("--hidden_size", type=int, default=64, help="hidden size of transformer model")
    parser.add_argument("--num_hidden_layers", type=int, default=2, help="number of layers")
    parser.add_argument('--num_attention_heads', default=2, type=int)
    parser.add_argument('--hidden_act', default="gelu", type=str)  # gelu relu
    parser.add_argument("--attention_probs_dropout_prob", type=float, default=0.5, help="attention dropout p")
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.5, help="hidden dropout p")
    parser.add_argument("--initializer_range", type=float, default=0.02)
    parser.add_argument('--max_seq_length', default=50, type=int)

    # train args
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate of adam")
    parser.add_argument("--batch_size", type=int, default=256, help="number of batch_size")
    parser.add_argument("--epochs", type=int, default=200, help="number of epochs")
    parser.add_argument("--no_cuda", action="store_true")
    parser.add_argument("--log_freq", type=int, default=1, help="per epoch print res")
    parser.add_argument("--seed", default=42, type=int)

    # pre train args
    parser.add_argument("--pre_epochs", type=int, default=300, help="number of pre_train epochs")
    parser.add_argument("--pre_batch_size", type=int, default=100)
    parser.add_argument("--mask_p", type=float, default=0.2, help="mask probability")
    parser.add_argument("--aap_weight", type=float, default=0.2, help="aap loss weight")
    parser.add_argument("--mip_weight", type=float, default=1.0, help="mip loss weight")
    parser.add_argument("--map_weight", type=float, default=1.0, help="map loss weight")
    parser.add_argument("--sp_weight", type=float, default=0.5, help="sp loss weight")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight_decay of adam")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam second beta value")
    parser.add_argument("--gpu_id", type=str, default="0", help="gpu_id")
    parser.add_argument("--ckp_freq", type=int, default=10,
                        help="save a checkpoint every N pre-train epochs "
                             "(<= 0: only after the last epoch)")
    args = parser.parse_args()

    # Restrict GPU visibility before anything can initialize CUDA. set_seed may
    # seed CUDA RNGs, and CUDA_VISIBLE_DEVICES is only honoured if it is set
    # before CUDA initializes.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    set_seed(args.seed)
    check_path(args.output_dir)
    args.cuda_condition = torch.cuda.is_available() and not args.no_cuda

    # os.path.join also works when --data_dir lacks a trailing separator.
    args.data_file = os.path.join(args.data_dir, args.data_name + '.txt')
    item2attribute_file = os.path.join(args.data_dir, args.data_name + '_item2attributes.json')

    # Concatenate all user sequences into one long sequence, from which
    # negative segments are sampled for SP.
    user_seq, max_item, long_sequence = get_user_seqs_long(args.data_file)
    item2attribute, attribute_size = get_item2attribute_json(item2attribute_file)

    # Id max_item + 1 is used as the mask id, so item_size = max_item + 2
    # covers ids 0 .. max_item + 1.
    args.item_size = max_item + 2
    args.mask_id = max_item + 1
    # NOTE(review): the +1 presumably reserves id 0 for padding -- confirm.
    args.attribute_size = attribute_size + 1

    # Save model args.
    args_str = f'{args.model_name}-{args.data_name}'
    args.log_file = os.path.join(args.output_dir, args_str + '.txt')
    print(args)
    with open(args.log_file, 'a', encoding='utf-8') as f:
        f.write(str(args) + '\n')

    # Attached only after args were printed/logged, presumably to keep the
    # (possibly large) mapping out of the logged configuration.
    args.item2attribute = item2attribute

    model = S3RecModel(args=args)
    # NOTE(review): the three None slots are presumably the train/eval/test
    # dataloaders, which pre-training does not use -- confirm in trainers.py.
    trainer = PretrainTrainer(model, None, None, None, args)

    for epoch in range(args.pre_epochs):
        # The dataset is rebuilt every epoch. NOTE(review): if the
        # PretrainDataset constructor is deterministic, this could be hoisted
        # out of the loop.
        pretrain_dataset = PretrainDataset(args, user_seq, long_sequence)
        pretrain_sampler = RandomSampler(pretrain_dataset)
        pretrain_dataloader = DataLoader(pretrain_dataset, sampler=pretrain_sampler,
                                         batch_size=args.pre_batch_size)
        trainer.pretrain(epoch, pretrain_dataloader)

        epochs_done = epoch + 1
        periodic = args.ckp_freq > 0 and epochs_done % args.ckp_freq == 0
        # Always checkpoint the final epoch, even when pre_epochs is not a
        # multiple of ckp_freq; otherwise the last epochs' training is lost.
        if periodic or epochs_done == args.pre_epochs:
            ckp = f'{args.data_name}-epochs-{epochs_done}.pt'
            checkpoint_path = os.path.join(args.output_dir, ckp)
            trainer.save(checkpoint_path)
# Run pre-training only when executed as a script, not when this module is
# imported (e.g. to reuse main() or inspect its arguments).
if __name__ == "__main__":
    main()