forked from state-spaces/s4
-
Notifications
You must be signed in to change notification settings - Fork 0
/
example.py
378 lines (305 loc) · 12.6 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
'''
Train an S4 model on sequential CIFAR10 / sequential MNIST with PyTorch for demonstration purposes.
This code borrows heavily from https://github.com/kuangliu/pytorch-cifar.
This file only depends on the standalone S4 layer
available in /models/s4/
* Train standard sequential CIFAR:
python -m example
* Train sequential CIFAR grayscale:
python -m example --grayscale
* Train MNIST:
python -m example --dataset mnist --d_model 256 --weight_decay 0.0
The `S4Model` class defined in this file provides a simple backbone to train S4 models.
This backbone is a good starting point for many problems, although some tasks (especially generation)
may require using other backbones.
The default CIFAR10 model trained by this file should get
89+% accuracy on the CIFAR10 test set in 80 epochs.
Each epoch takes approximately 7m20s on a T4 GPU (will be much faster on V100 / A100).
'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models.s4.s4 import S4Block as S4 # Can use full version instead of minimal S4D standalone below
from models.s4.s4d import S4D
from tqdm.auto import tqdm
def get_dropout_fn(version=None):
    """Pick the dropout class applied to the output of each S4 block.

    Args:
        version: PyTorch version string such as ``"2.1.0+cu118"``. Only the
            leading ``major.minor`` components are inspected. Defaults to the
            installed ``torch.__version__``.

    Returns:
        ``nn.Dropout`` on PyTorch 1.11, ``nn.Dropout1d`` on 1.12 and newer,
        and ``nn.Dropout2d`` on older releases.
    """
    if version is None:
        version = torch.__version__
    major_minor = tuple(int(part) for part in version.split('.')[:2])
    # Dropout broke in PyTorch 1.11, so fall back to plain nn.Dropout there.
    # This must be one if/elif chain: as two independent `if` statements the
    # trailing `else` silently overwrote the 1.11 fallback with nn.Dropout2d.
    if major_minor == (1, 11):
        print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.")
        return nn.Dropout
    if major_minor >= (1, 12):
        return nn.Dropout1d
    return nn.Dropout2d


dropout_fn = get_dropout_fn()
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
# Optimizer
parser.add_argument('--lr', default=0.01, type=float, help='Learning rate')
parser.add_argument('--weight_decay', default=0.01, type=float, help='Weight decay')
# Scheduler
# parser.add_argument('--patience', default=10, type=float, help='Patience for learning rate scheduler')
# The epoch count is used as a `range()` bound by the training loop, so it must be
# parsed as an int: with type=float, `--epochs 50` became 50.0 and range() raised.
parser.add_argument('--epochs', default=100, type=int, help='Training epochs')
# Dataset
parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'cifar10'], type=str, help='Dataset')
parser.add_argument('--grayscale', action='store_true', help='Use grayscale CIFAR10')
# Dataloader
parser.add_argument('--num_workers', default=4, type=int, help='Number of workers to use for dataloader')
parser.add_argument('--batch_size', default=64, type=int, help='Batch size')
# Model
parser.add_argument('--n_layers', default=4, type=int, help='Number of layers')
parser.add_argument('--d_model', default=128, type=int, help='Model dimension')
parser.add_argument('--dropout', default=0.1, type=float, help='Dropout')
parser.add_argument('--prenorm', action='store_true', help='Prenorm')
# General
parser.add_argument('--resume', '-r', action='store_true', help='Resume from checkpoint')

args = parser.parse_args()
# Mutable run state read and updated by the training / evaluation code.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
start_epoch = 0  # first epoch to run: 0, or the epoch stored in a resumed checkpoint
best_acc = 0  # best validation accuracy seen so far (drives checkpointing)

# Data
print(f'==> Preparing {args.dataset} data..')
def split_train_val(train, val_split, seed=42):
    """Deterministically split a dataset into train and validation subsets.

    Args:
        train: Dataset to split (must support ``len()`` and indexing).
        val_split: Fraction of ``train`` assigned to the validation subset.
        seed: Seed for the permutation generator. Calls with the same dataset
            length, ``val_split`` and ``seed`` yield the same partition. That is
            what lets the script cut the train subset and the validation subset
            from two separately-transformed copies of the training set without
            overlap.

    Returns:
        ``(train_subset, val_subset)`` as produced by
        ``torch.utils.data.random_split``.
    """
    train_len = int(len(train) * (1.0 - val_split))
    train_subset, val_subset = torch.utils.data.random_split(
        train,
        (train_len, len(train) - train_len),
        generator=torch.Generator().manual_seed(seed),
    )
    return train_subset, val_subset
def _build_splits(dataset_cls, root, transform_train, transform_test, val_split=0.1):
    """Instantiate ``dataset_cls`` under ``root`` and return ``(trainset, valset, testset)``.

    The official training set is loaded twice (once per transform) and both
    copies are cut by ``split_train_val`` with identical arguments. The train
    subset comes from the first copy and the validation subset from the second.
    """
    full_train = dataset_cls(root=root, train=True, download=True, transform=transform_train)
    train_part = split_train_val(full_train, val_split=val_split)[0]
    full_val = dataset_cls(root=root, train=True, download=True, transform=transform_test)
    val_part = split_train_val(full_val, val_split=val_split)[1]
    test_part = dataset_cls(root=root, train=False, download=True, transform=transform_test)
    return train_part, val_part, test_part


if args.dataset == 'cifar10':
    if args.grayscale:
        # One channel per pixel: each image becomes a (1024, 1) sequence.
        transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(mean=122.6 / 255.0, std=61.0 / 255.0),
            transforms.Lambda(lambda x: x.view(1, 1024).t()),
        ])
    else:
        # Three channels per pixel: each image becomes a (1024, 3) sequence.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            transforms.Lambda(lambda x: x.view(3, 1024).t()),
        ])
    # S4 is trained on sequences with no data augmentation!
    transform_train = transform_test = transform
    trainset, valset, testset = _build_splits(
        torchvision.datasets.CIFAR10, './data/cifar/', transform_train, transform_test
    )
    d_input = 1 if args.grayscale else 3
    d_output = 10
elif args.dataset == 'mnist':
    # Each image becomes a (784, 1) sequence of pixel intensities.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(1, 784).t()),
    ])
    transform_train = transform_test = transform
    trainset, valset, testset = _build_splits(
        torchvision.datasets.MNIST, './data', transform_train, transform_test
    )
    d_input = 1
    d_output = 10
else:
    raise NotImplementedError
# Dataloaders: only the training split is shuffled; val/test keep a fixed order.
def _loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the CLI batch size and worker count."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
    )


trainloader = _loader(trainset, shuffle=True)
valloader = _loader(valset, shuffle=False)
testloader = _loader(testset, shuffle=False)
class S4Model(nn.Module):
    """Residual stack of S4D blocks with a linear encoder/decoder and mean pooling.

    Maps a batch of sequences of shape ``(B, L, d_input)`` to logits of shape
    ``(B, d_output)``.
    """

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
        lr=None,
    ):
        """
        Args:
            d_input: Features per time step (1 for grayscale CIFAR10 / MNIST,
                3 for RGB CIFAR10 in this script).
            d_output: Number of output classes.
            d_model: Hidden dimension shared by every S4 layer.
            n_layers: Number of residual S4 blocks.
            dropout: Dropout rate given to each S4D layer and applied to its output.
            prenorm: If True, apply LayerNorm before each S4 block; otherwise
                after the residual addition.
            lr: Forwarded as ``S4D(lr=...)``. Presumably this is the learning
                rate for the S4D layer's special parameters; confirm against S4D.
                If None, falls back to ``min(0.001, args.lr)`` from the script's
                global CLI arguments. That was previously the only behavior, and
                it made the class unusable without the global ``args``.
        """
        super().__init__()

        self.prenorm = prenorm
        if lr is None:
            # Historical default: the global CLI learning rate, capped at 1e-3.
            lr = min(0.001, args.lr)

        # Linear encoder (d_input = 1 for grayscale and 3 for RGB)
        self.encoder = nn.Linear(d_input, d_model)

        # Stack S4 layers as residual blocks
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(
                S4D(d_model, dropout=dropout, transposed=True, lr=lr)
            )
            self.norms.append(nn.LayerNorm(d_model))
            self.dropouts.append(dropout_fn(dropout))

        # Linear decoder
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input); returns logits of shape (B, d_output).
        """
        x = self.encoder(x)  # (B, L, d_input) -> (B, L, d_model)

        x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)
        for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
            # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L)

            z = x
            if self.prenorm:
                # Prenorm: LayerNorm acts on the last dim, so move d_model there and back
                z = norm(z.transpose(-1, -2)).transpose(-1, -2)

            # Apply S4 block: we ignore the state input and output
            z, _ = layer(z)

            # Dropout on the output of the S4 block
            z = dropout(z)

            # Residual connection
            x = z + x

            if not self.prenorm:
                # Postnorm
                x = norm(x.transpose(-1, -2)).transpose(-1, -2)

        x = x.transpose(-1, -2)  # (B, d_model, L) -> (B, L, d_model)

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x
# Model
print('==> Building model..')
model = S4Model(
    d_input=d_input,
    d_output=d_output,
    d_model=args.d_model,
    n_layers=args.n_layers,
    dropout=args.dropout,
    prenorm=args.prenorm,
)

model = model.to(device)
if device == 'cuda':
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isdir('checkpoint'):
        raise FileNotFoundError('Error: no checkpoint directory found!')
    # map_location lets a checkpoint written on a GPU be resumed on a CPU-only host.
    checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    # NOTE(review): the checkpoint written by eval() holds only 'model', 'acc' and
    # 'epoch'. Optimizer and LR-scheduler state are not restored, so the cosine
    # schedule restarts from its first step on resume.
def setup_optimizer(model, lr, weight_decay, epochs):
    """
    Build the AdamW optimizer and cosine LR scheduler for an S4 model.

    S4 layers attach a ``_optim`` dict of per-parameter hyperparameters to some
    of their parameters (the A, B, C, dt parameters), typically a smaller
    learning rate (around 0.001) and no weight decay. Untagged parameters form
    the first param group, trained with ``lr`` / ``weight_decay`` (e.g. 0.004
    or 0.01). Each distinct ``_optim`` dict then gets its own extra group.

    Returns:
        ``(optimizer, scheduler)``, where the scheduler is a
        ``CosineAnnealingLR`` with ``T_max=epochs``.
    """
    every_param = list(model.parameters())
    tagged = [p for p in every_param if hasattr(p, "_optim")]
    untagged = [p for p in every_param if not hasattr(p, "_optim")]

    optimizer = optim.AdamW(untagged, lr=lr, weight_decay=weight_decay)

    # Deduplicate the special hyperparameter dicts (first occurrence kept), then sort.
    distinct = sorted({frozenset(p._optim.items()): None for p in tagged})
    special_hps = [dict(fs) for fs in distinct]
    for group_hp in special_hps:
        members = [p for p in every_param if getattr(p, "_optim", None) == group_hp]
        optimizer.add_param_group({"params": members, **group_hp})

    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=patience, factor=0.2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Report each param group together with its special hyperparameter values.
    hp_keys = sorted({key for group_hp in special_hps for key in group_hp})
    for idx, group in enumerate(optimizer.param_groups):
        fields = [f"Optimizer group {idx}", f"{len(group['params'])} tensors"]
        fields.extend(f"{key} {group.get(key, None)}" for key in hp_keys)
        print(' | '.join(fields))

    return optimizer, scheduler
# S4-aware optimizer (separate param groups) and cosine LR schedule, plus the
# classification loss used by train() and eval().
optimizer, scheduler = setup_optimizer(
    model,
    lr=args.lr,
    weight_decay=args.weight_decay,
    epochs=args.epochs,
)
criterion = nn.CrossEntropyLoss()
###############################################################################
# Everything after this point is standard PyTorch training!
###############################################################################
# Training
def train():
    """Run one training epoch over the global ``trainloader``.

    Updates the module-level ``model`` in place using the module-level
    ``optimizer``, ``criterion`` and ``device``.

    Returns:
        ``(avg_loss, acc)``: mean per-batch loss and accuracy in percent over
        the epoch, or ``(0.0, 0.0)`` if the loader yielded no batches.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    # enumerate() has no len(), so pass total explicitly to get a real progress bar / ETA.
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        num_batches = batch_idx + 1
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (batch_idx, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)
        )

    if total == 0:
        return 0.0, 0.0
    return train_loss / num_batches, 100. * correct / total
def eval(epoch, dataloader, checkpoint=False):
    """Evaluate the global ``model`` on ``dataloader`` and return its accuracy.

    Args:
        epoch: Current epoch number; stored in the checkpoint when one is saved.
        dataloader: Loader to evaluate on (validation or test split).
        checkpoint: If True and the accuracy beats the global ``best_acc``, save
            ``{'model', 'acc', 'epoch'}`` to ``./checkpoint/ckpt.pth`` and update
            ``best_acc``.

    Returns:
        Accuracy in percent, for every call (not only when ``checkpoint`` is
        True). Returns 0.0 for a loader that yields no samples.
    """
    # NOTE: this shadows the builtin eval(); the name is kept for existing callers.
    global best_acc
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        # enumerate() has no len(), so pass total explicitly for a real progress bar.
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total)
            )

    # Computed unconditionally (guarding an empty loader against division by zero)
    # so the function has a well-defined return value whether or not it checkpoints.
    acc = 100.*correct/total if total else 0.0

    # Save checkpoint.
    if checkpoint and acc > best_acc:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

    return acc
# val_acc stays None until this run's first validation pass. Testing it (rather
# than `epoch == 0`) avoids a NameError when resuming with start_epoch > 0.
val_acc = None
# int() because `--epochs` may be parsed as a float, and range() rejects floats.
pbar = tqdm(range(start_epoch, int(args.epochs)))
for epoch in pbar:
    if val_acc is None:
        pbar.set_description('Epoch: %d' % (epoch))
    else:
        pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))
    train()
    val_acc = eval(epoch, valloader, checkpoint=True)
    eval(epoch, testloader)
    scheduler.step()
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")