forked from xiaojunxu/SQLNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
170 lines (161 loc) · 7.83 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import json
import torch
from sqlnet.utils import *
from sqlnet.model.seq2sql import Seq2SQL
from sqlnet.model.sqlnet import SQLNet
import numpy as np
import datetime
import argparse
def build_parser():
    """Build the command-line parser for training SQLNet / Seq2SQL.

    Returns:
        An ``argparse.ArgumentParser`` with the training flags. ``--epochs``
        (default 100) sets the number of epochs for both the supervised and
        the reinforcement-learning loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--toy', action='store_true',
            help='If set, use small data; used for fast debugging.')
    parser.add_argument('--suffix', type=str, default='',
            help='The suffix at the end of saved model name.')
    parser.add_argument('--ca', action='store_true',
            help='Use conditional attention.')
    parser.add_argument('--dataset', type=int, default=0,
            help='0: original dataset, 1: re-split dataset')
    parser.add_argument('--rl', action='store_true',
            help='Use RL for Seq2SQL(requires pretrained model).')
    parser.add_argument('--baseline', action='store_true',
            help='If set, then train Seq2SQL model; default is SQLNet model.')
    parser.add_argument('--train_emb', action='store_true',
            help='Train word embedding for SQLNet(requires pretrained model).')
    parser.add_argument('--epochs', type=int, default=100,
            help='Number of training epochs.')
    return parser


def headline_acc(acc):
    """Return the headline number of an accuracy result.

    The script both %-formats ``epoch_exec_acc``'s result as one value and
    compares its ``[0]`` element. Its exact return type is not visible here
    (NOTE(review): confirm against sqlnet.utils). Indexable results yield
    their first element. Non-indexable scalars are returned unchanged.
    """
    try:
        return acc[0]
    except (TypeError, IndexError):
        return acc


def save_component(model, tag, model_path, embed_path=None):
    """torch.save the state dict of ``model.<tag>_pred`` to ``model_path``.

    If ``embed_path`` is not None, also save ``model.<tag>_embed_layer``
    there. The script passes an embed path only when --train_emb is set.
    """
    torch.save(getattr(model, tag + '_pred').state_dict(), model_path)
    if embed_path is not None:
        torch.save(getattr(model, tag + '_embed_layer').state_dict(),
                embed_path)


if __name__ == '__main__':
    parser = build_parser()
    args = parser.parse_args()
    # Reject unsupported flag combinations before any expensive loading.
    # parser.error is used instead of assert because asserts vanish under -O.
    if args.baseline and args.train_emb:
        parser.error("Seq2SQL can't train embedding.")
    if not args.baseline and args.rl:
        parser.error("SQLNet can't do reinforcement learning.")

    N_word = 300
    B_word = 42
    USE_SMALL = args.toy
    GPU = True
    BATCH_SIZE = 15 if args.toy else 64
    TRAIN_ENTRY = (True, True, True)  # (AGG, SEL, COND)
    # Component tags. Each tag names the model attributes <tag>_pred and
    # <tag>_embed_layer. Its position matches TRAIN_ENTRY and the breakdown
    # tuple returned by epoch_acc (agg, sel, where).
    COMPONENTS = ('agg', 'sel', 'cond')
    learning_rate = 1e-4 if args.rl else 1e-3

    sql_data, table_data, val_sql_data, val_table_data, \
            test_sql_data, test_table_data, \
            TRAIN_DB, DEV_DB, TEST_DB = load_dataset(
                    args.dataset, use_small=USE_SMALL)

    word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
            load_used=args.train_emb, use_small=USE_SMALL)
    if args.baseline:
        model = Seq2SQL(word_emb, N_word=N_word, gpu=GPU,
                trainable_emb=args.train_emb)
    else:
        model = SQLNet(word_emb, N_word=N_word, use_ca=args.ca,
                gpu=GPU, trainable_emb=args.train_emb)
    optimizer = torch.optim.Adam(model.parameters(),
            lr=learning_rate, weight_decay=0)

    # Paths where the best model of each component is written. With
    # --train_emb, best_model_name also returns one path per embedding layer.
    if args.train_emb:
        agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(args)
    else:
        agg_m, sel_m, cond_m = best_model_name(args)
        agg_e = sel_e = cond_e = None
    best_model_paths = (agg_m, sel_m, cond_m)
    best_embed_paths = (agg_e, sel_e, cond_e)

    if args.rl or args.train_emb:  # Load pretrained model.
        agg_lm, sel_lm, cond_lm = best_model_name(args, for_load=True)
        for tag, path in zip(COMPONENTS, (agg_lm, sel_lm, cond_lm)):
            print("Loading from %s" % path)
            getattr(model, tag + '_pred').load_state_dict(torch.load(path))

    if args.rl:
        print("Init dev acc_qm: %s\n breakdown on (agg, sel, where): %s" %
                epoch_acc(model, BATCH_SIZE, val_sql_data,
                        val_table_data, TRAIN_ENTRY))
        init_exec_acc = epoch_exec_acc(
                model, BATCH_SIZE, val_sql_data, val_table_data, DEV_DB)
        print("Init dev acc_ex: %s" % (init_exec_acc,))
        # The pretrained cond predictor is saved as the current best. Seed
        # best_acc with its score so that later epochs replace it only when
        # they actually improve on it (epoch 0 = the pretrained model).
        best_acc = headline_acc(init_exec_acc)
        best_idx = 0
        save_component(model, 'cond', cond_m)
        for i in range(args.epochs):
            epoch = i + 1
            print('Epoch %d @ %s' % (epoch, datetime.datetime.now()))
            print(' Avg reward = %s' % epoch_reinforce_train(
                    model, optimizer, BATCH_SIZE, sql_data, table_data,
                    TRAIN_DB))
            print(' dev acc_qm: %s\n breakdown result: %s' % epoch_acc(
                    model, BATCH_SIZE, val_sql_data, val_table_data,
                    TRAIN_ENTRY))
            exec_acc = epoch_exec_acc(
                    model, BATCH_SIZE, val_sql_data, val_table_data, DEV_DB)
            # Format with % (the original comma printed a literal "%s").
            print(' dev acc_ex: %s' % (exec_acc,))
            if headline_acc(exec_acc) > best_acc:
                best_acc = headline_acc(exec_acc)
                best_idx = epoch
                # Keep a per-epoch snapshot and overwrite the running best.
                save_component(model, 'cond',
                        'saved_model/epoch%d.cond_model%s' % (epoch, args.suffix))
                save_component(model, 'cond', cond_m)
            print(' Best exec acc = %s, on epoch %s' % (best_acc, best_idx))
    else:
        init_acc = epoch_acc(model, BATCH_SIZE,
                val_sql_data, val_table_data, TRAIN_ENTRY)
        # Best dev accuracy per component and the epoch that reached it
        # (0 = the initial model). Both lists are indexed like COMPONENTS.
        best_accs = [init_acc[1][k] for k in range(len(COMPONENTS))]
        best_idxs = [0] * len(COMPONENTS)
        print('Init dev acc_qm: %s\n breakdown on (agg, sel, where): %s' %
                init_acc)
        for k, tag in enumerate(COMPONENTS):
            if TRAIN_ENTRY[k]:
                save_component(model, tag, best_model_paths[k],
                        best_embed_paths[k])
        for i in range(args.epochs):
            epoch = i + 1
            print('Epoch %d @ %s' % (epoch, datetime.datetime.now()))
            print(' Loss = %s' % epoch_train(
                    model, optimizer, BATCH_SIZE,
                    sql_data, table_data, TRAIN_ENTRY))
            print(' Train acc_qm: %s\n breakdown result: %s' % epoch_acc(
                    model, BATCH_SIZE, sql_data, table_data, TRAIN_ENTRY))
            val_acc = epoch_acc(model,
                    BATCH_SIZE, val_sql_data, val_table_data, TRAIN_ENTRY)
            print(' Dev acc_qm: %s\n breakdown result: %s' % val_acc)
            for k, tag in enumerate(COMPONENTS):
                if TRAIN_ENTRY[k] and val_acc[1][k] > best_accs[k]:
                    best_accs[k] = val_acc[1][k]
                    best_idxs[k] = epoch
                    if args.train_emb:
                        snapshot_embed = 'saved_model/epoch%d.%s_embed%s' % (
                                epoch, tag, args.suffix)
                    else:
                        snapshot_embed = None
                    # Keep a per-epoch snapshot and overwrite the running best.
                    save_component(model, tag,
                            'saved_model/epoch%d.%s_model%s' % (
                                    epoch, tag, args.suffix),
                            snapshot_embed)
                    save_component(model, tag, best_model_paths[k],
                            best_embed_paths[k])
            print(' Best val acc = %s, on epoch %s individually' %
                    (tuple(best_accs), tuple(best_idxs)))