-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_cnn_lstm.py
116 lines (105 loc) · 5.04 KB
/
train_cnn_lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import Queue
import threading
from multiprocessing.pool import ThreadPool
import numpy as np
import theano.tensor as T
dtype = T.config.floatX
import argparse
import helper.config as config
import model.model_provider as model_provider
import helper.dt_utils as du
import helper.utils as u
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
def train_rnn(params):
    """Train a CNN-LSTM pose-estimation model with double-buffered batching.

    Every iteration overlaps model computation on the current batch with
    preparation of the next batch by dispatching both onto a two-worker
    ThreadPool. LSTM hidden/cell states (H, C) are threaded through
    consecutive batches and reset between the train and test phases.

    Parameters
    ----------
    params : dict
        Configuration from helper.config. Keys read here include
        'batch_size', 'n_epochs', 'n_hidden', 'run_mode', 'shufle_data'
        and, when run_mode == 1, 'mfile' (pretrained weights file).
    """
    rng = RandomStreams(seed=1234)
    # X/Y data tensors, sequence-id lists (S_*) and per-frame file lists (F_*/G_*).
    (X_train, Y_train, S_Train_list, F_list_train, G_list_train,
     X_test, Y_test, S_Test_list, F_list_test, G_list_test) = du.load_pose(params)
    params["len_train"] = Y_train.shape[0] * Y_train.shape[1]
    params["len_test"] = Y_test.shape[0] * Y_test.shape[1]
    u.start_log(params)
    index_train_list, S_Train_list = du.get_batch_indexes(S_Train_list)  # prepare batch indexes
    index_test_list, S_Test_list = du.get_batch_indexes(S_Test_list)
    batch_size = params['batch_size']
    # Floor division keeps the counts integral on Python 3 as well;
    # the original '/=' relied on Python 2 integer division.
    n_train_batches = len(index_train_list) // batch_size
    n_test_batches = len(index_test_list) // batch_size
    nb_epochs = params['n_epochs']
    print("Batch size: %i, train batch size: %i, test batch size: %i" % (batch_size, n_train_batches, n_test_batches))
    u.log_write("Model build started", params)
    if params['run_mode'] == 1:
        model = model_provider.get_model_pretrained(params, rng)
        u.log_write("Pretrained loaded: %s" % (params['mfile']), params)
    else:
        model = model_provider.get_model(params, rng)
    u.log_write("Number of parameters: %s" % (model.n_param), params)
    train_errors = np.ndarray(nb_epochs)
    u.log_write("Training started", params)
    val_counter = 0
    best_loss = 1000
    for epoch_counter in range(nb_epochs):
        batch_loss = 0.
        # Fresh zero hidden/cell state per epoch. Allocate H and C
        # separately: the original 'H=C=np.zeros(...)' bound both names to
        # the SAME array, so an in-place update of one would mutate the other.
        H = np.zeros(shape=(batch_size, params['n_hidden']), dtype=dtype)
        C = np.zeros(shape=(batch_size, params['n_hidden']), dtype=dtype)
        sid = 0
        is_train = 1
        x = []
        y = []
        for minibatch_index in range(n_train_batches):
            if minibatch_index == 0:
                # Synchronously pre-load the very first batch.
                (sid, H, C, x, y) = du.prepare_cnn_lstm_batch(index_train_list, minibatch_index, batch_size, S_Train_list, sid, H, C, F_list_train, params, Y_train, X_train)
            # Overlap training on the current batch with prefetch of the next.
            pool = ThreadPool(processes=2)
            async_t = pool.apply_async(model.train, (x, y, is_train, H, C))
            async_b = pool.apply_async(du.prepare_cnn_lstm_batch, (index_train_list, minibatch_index, batch_size, S_Train_list, sid, H, C, F_list_train, params, Y_train, X_train))
            pool.close()
            pool.join()
            (loss, H, C) = async_t.get()
            (sid, H, C, x, y) = async_b.get()
            if minibatch_index == n_train_batches - 1:
                # Consume the final prefetched batch so it is not dropped.
                loss, H, C = model.train(x, y, is_train, H, C)
            batch_loss += loss
        if params['shufle_data'] == 1:
            X_train, Y_train = du.shuffle_in_unison_inplace(X_train, Y_train)
        train_errors[epoch_counter] = batch_loss  # raw (un-normalised) epoch sum
        batch_loss /= n_train_batches
        s = 'TRAIN--> epoch %i | error %f' % (epoch_counter, batch_loss)
        u.log_write(s, params)
        if epoch_counter % 1 == 0:  # validate every epoch
            print("Model testing")
            batch_loss3d = []
            # Reset recurrent state: test sequences are independent of training ones.
            H = np.zeros(shape=(batch_size, params['n_hidden']), dtype=dtype)
            C = np.zeros(shape=(batch_size, params['n_hidden']), dtype=dtype)
            sid = 0
            # NOTE(review): is_train is still 1 during evaluation; if
            # model.predictions gates dropout/noise on this flag, test-time
            # behaviour should be confirmed.
            for minibatch_index in range(n_test_batches):
                if minibatch_index == 0:
                    (sid, H, C, x, y) = du.prepare_cnn_lstm_batch(index_test_list, minibatch_index, batch_size, S_Test_list, sid, H, C, F_list_test, params, Y_test, X_test)
                pool = ThreadPool(processes=2)
                async_t = pool.apply_async(model.predictions, (x, is_train, H, C))
                async_b = pool.apply_async(du.prepare_cnn_lstm_batch, (index_test_list, minibatch_index, batch_size, S_Test_list, sid, H, C, F_list_test, params, Y_test, X_test))
                pool.close()
                pool.join()
                (pred, H, C) = async_t.get()
                loss3d = u.get_loss(params, y, pred)
                batch_loss3d.append(loss3d)
                (sid, H, C, x, y) = async_b.get()
                # BUG FIX: the original compared against n_train_batches-1
                # here, so the last prefetched TEST batch was only evaluated
                # when the train and test batch counts happened to coincide.
                if minibatch_index == n_test_batches - 1:
                    pred, H, C = model.predictions(x, is_train, H, C)
                    loss3d = u.get_loss(params, y, pred)
                    batch_loss3d.append(loss3d)
            batch_loss3d = np.nanmean(batch_loss3d)
            if batch_loss3d < best_loss:
                best_loss = batch_loss3d
                # Persist a named snapshot of the best model so far.
                ext = str(epoch_counter) + "_" + str(batch_loss3d) + "_best.p"
                u.write_params(model.params, params, ext)
            else:
                # Alternate between two rolling checkpoint files.
                ext = str(val_counter % 2) + ".p"
                u.write_params(model.params, params, ext)
            val_counter += 1
            s = 'VAL--> epoch %i | error %f, %f' % (val_counter, batch_loss3d, n_test_batches)
            u.log_write(s, params)
if __name__ == "__main__":
    # Build default params from helper.config, allow a CLI override of the
    # model name, then launch training. The guard keeps these side effects
    # from firing when the module is imported.
    params = config.get_params()
    parser = argparse.ArgumentParser(description='Training the module')
    parser.add_argument('-m', '--model', help='Model: lstm, erd current(' + params["model"] + ')', default=params["model"])
    args = vars(parser.parse_args())
    params["model"] = args["model"]
    params = config.update_params(params)
    train_rnn(params)