'''
Author: [email protected]
Date: 2021-09-20 08:06:11
LastEditTime: 2021-09-20 23:53:24
LastEditors: fcy
Description: the training script for the OctAttention model.
See networkTool.py to set up the parameters.
Running it writes the training log file loss.log and checkpoints under the folder 'expName'.
FilePath: /compression/octAttention.py
All rights reserved.
'''
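# Usage sketch (the concrete paths and experiment settings are assumed to be
# configured in networkTool.py, as noted above):
#   python octAttention.py
# Training appends to '<expName>/loss.log', writes TensorBoard summaries under
# ./log/<expName>, and stores checkpoints in the checkpoint folder from networkTool.py.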
import math
import torch
import torch.nn as nn
import os
import datetime
from networkTool import *
from torch.utils.tensorboard import SummaryWriter
from attentionModel import TransformerLayer, TransformerModule
##########################
ntokens = 255        # size of the vocabulary
ninp = 4*(128+4+6)   # embedding dimension
nhid = 300           # dimension of the feedforward network inside each Transformer layer
nlayers = 3          # number of TransformerLayer blocks in the TransformerModule
nhead = 4            # number of heads in the multi-head attention
dropout = 0          # the dropout value
batchSize = 32
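# Note on ninp: assuming K = 4 context levels per node (matching the reshape to
# (..., 4, 6) in the training loop below), each level contributes an occupancy
# embedding (128 dims), an octant embedding (4 dims) and a level embedding (6 dims),
# which forward() concatenates into 4 * (128 + 4 + 6) = 552 features per node.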
class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerModule(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, 128)
        self.encoder1 = nn.Embedding(MAX_OCTREE_LEVEL+1, 6)
        self.encoder2 = nn.Embedding(9, 4)
        self.ninp = ninp
        self.act = nn.ReLU()
        self.decoder0 = nn.Linear(ninp, ninp)
        self.decoder1 = nn.Linear(ninp, ntoken)
        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
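    # Example: generate_square_subsequent_mask(4) returns
    #   [[0., -inf, -inf, -inf],
    #    [0.,   0., -inf, -inf],
    #    [0.,   0.,   0., -inf],
    #    [0.,   0.,   0.,   0.]]
    # so each position attends only to itself and to earlier positions.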
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data = nn.init.xavier_normal_(self.encoder.weight.data)
        self.decoder0.bias.data.zero_()
        self.decoder0.weight.data = nn.init.xavier_normal_(self.decoder0.weight.data)
        self.decoder1.bias.data.zero_()
        self.decoder1.weight.data = nn.init.xavier_normal_(self.decoder1.weight.data)

    def forward(self, src, src_mask, dataFeat):
        bptt = src.shape[0]
        batch = src.shape[1]
        oct = src[:,:,:,0]      # oct[bptt, batchSize, FeatDim(levels)], values in [0, 254]
        level = src[:,:,:,1]    # values in [0, 12]; 0 marks padding data
        octant = src[:,:,:,2]   # values in [0, 8]; 0 marks padding data
        # assert oct.min() >= 0 and oct.max() < 255
        # assert level.min() >= 0 and level.max() <= 12
        # assert octant.min() >= 0 and octant.max() <= 8
        level -= torch.clip(level[:,:,-1:] - 10, 0, None)  # the max level in the training dataset is 10
        torch.clip_(level, 0, MAX_OCTREE_LEVEL)
        aOct = self.encoder(oct.long())       # a[bptt, batchSize, FeatDim(levels), EmbeddingDim]
        aLevel = self.encoder1(level.long())
        aOctant = self.encoder2(octant.long())
        a = torch.cat((aOct, aLevel, aOctant), 3)
        a = a.reshape((bptt, batch, -1))
        # src = self.ancestor_attention(a)
        src = a.reshape((bptt, a.shape[1], self.ninp)) * math.sqrt(self.ninp)
        # src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder1(self.act(self.decoder0(output)))
        return output
######################################################################
# ``PositionalEncoding`` module
#
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
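# The buffer above implements the standard sinusoidal encoding:
#   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
# (TransformerModel builds a pos_encoder, but its call in forward() is commented out.)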
######################################################################
# Functions to generate input and target sequence
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len].clone()
    target = source[i+1:i+1+seq_len, :, -1, 0].reshape(-1)
    data[:,:,0:-1,:] = source[i+1:i+seq_len+1, :, 0:-1, :]  # this moves the features (level, octant) of the current node to the last row,
    data[:,:,-1,1:3] = source[i+1:i+seq_len+1, :, -1, 1:3]  # which will be used as known features
    return data[:,:,-levelNumK:,:], target.long(), []
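# Shape sketch for get_batch (values follow directly from the indexing above):
#   data   : [seq_len, batchSize, levelNumK, 6] -- per node, the features of its
#            levelNumK-level ancestor context, with the next node's (level, octant)
#            already filled in as known side information
#   target : [seq_len * batchSize]              -- occupancy code of the node to predict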
######################################################################
# Run the model
# -------------
#
model = TransformerModel(ntokens, ninp, nhead, nhid, nlayers, dropout).to(device)
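# Minimal forward-pass sketch (illustrative only: dummy zero-valued features, with
# bptt, levelNumK and device taken from networkTool; levelNumK is expected to be 4
# here, consistent with ninp = 4*138):
#   src = torch.zeros(bptt, batchSize, levelNumK, 3, dtype=torch.long, device=device)
#   src_mask = model.generate_square_subsequent_mask(bptt).to(device)
#   logits = model(src, src_mask, [])   # -> [bptt, batchSize, ntokens]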
if __name__ == "__main__":
    import dataset
    import torch.utils.data as data
    import time
    import os

    epochs = 8        # the number of epochs
    best_model = None
    batch_size = 128
    TreePoint = bptt*16

    train_set = dataset.DataFolder(root=trainDataRoot, TreePoint=TreePoint, transform=None, dataLenPerFile=391563.61670395226)  # run 'dataLenPerFile' in dataset.py to get this number (17456051.4)
    train_loader = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=True)  # loads TreePoint*batch_size tree nodes at a time

    # logger
    if not os.path.exists(checkpointPath):
        os.makedirs(checkpointPath)
    printl = CPrintl(expName+'/loss.log')
    writer = SummaryWriter('./log/'+expName)
    printl(datetime.datetime.now().strftime('\r\n%Y-%m-%d:%H:%M:%S'))
    # model_structure(model,printl)
    printl(expComment+' Pid: '+str(os.getpid()))
    log_interval = int(batch_size*TreePoint/batchSize/bptt)

    # learning
    criterion = nn.CrossEntropyLoss()
    lr = 1e-3  # learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
    best_val_loss = float("inf")
    idloss = 0

    # reload a saved checkpoint to resume training
    saveDic = None
    # saveDic = reload(100030,checkpointPath)
    if saveDic:
        scheduler.last_epoch = saveDic['epoch'] - 1
        idloss = saveDic['idloss']
        best_val_loss = saveDic['best_val_loss']
        model.load_state_dict(saveDic['encoder'])
    def train(epoch):
        global idloss, best_val_loss
        model.train()  # turn on the train mode
        total_loss = 0.
        start_time = time.time()
        total_loss_list = torch.zeros((1, 7))
        for Batch, d in enumerate(train_loader):  # two batch notions: each loader 'Batch' yields batch_size*TreePoint/batchSize/bptt inner 'batch'es
            batch = 0
            train_data = d[0].reshape((batchSize, -1, 4, 6)).to(device).permute(1, 0, 2, 3)  # shape: [TreePoint*batch_size/batchSize, batchSize, 4, 6]
            src_mask = model.generate_square_subsequent_mask(bptt).to(device)
            for index, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
                data, targets, dataFeat = get_batch(train_data, i)  # data: [bptt, batchSize, levelNumK, 6]
                optimizer.zero_grad()
                if data.size(0) != bptt:
                    src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
                output = model(data, src_mask, dataFeat)  # output: [bptt, batchSize, 255]
                loss = criterion(output.view(-1, ntokens), targets) / math.log(2)  # divide by ln(2) to report the cross-entropy in bits
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

                total_loss += loss.item()
                batch = batch + 1
                if batch % log_interval == 0:
                    cur_loss = total_loss / log_interval
                    elapsed = time.time() - start_time
                    total_loss_list = " - "
                    printl('| epoch {:3d} | Batch {:3d} | {:4d}/{:4d} batches | '
                           'lr {:02.2f} | ms/batch {:5.2f} | '
                           'loss {:5.2f} | losslist {} | ppl {:8.2f}'.format(
                               epoch, Batch, batch, len(train_data) // bptt, scheduler.get_last_lr()[0],
                               elapsed * 1000 / log_interval,
                               cur_loss, total_loss_list, math.exp(cur_loss)))
                    total_loss = 0
                    start_time = time.time()
                    writer.add_scalar('train_loss', cur_loss, idloss)
                    idloss += 1
            if Batch % 10 == 0:
                save(epoch*100000+Batch, saveDict={'encoder': model.state_dict(), 'idloss': idloss, 'epoch': epoch, 'best_val_loss': best_val_loss}, modelDir=checkpointPath)
    # train
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(epoch)
        printl('-' * 89)
        scheduler.step()
        printl('-' * 89)