-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
81 lines (64 loc) · 3.13 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import wandb
import torch
import torchvision
from config import set_params
from kws.utils import set_random_seed, transforms
from kws.utils.data import SpeechCommandsDataset, load_data, split_data
from kws.model import treasure_net
from kws.train import train
def main():
    """Run the full keyword-spotting training pipeline.

    Loads hyperparameters, prepares train/validation dataloaders with audio
    augmentations on the training split only, builds the model and optimizer
    (optionally resuming from a checkpoint), and launches training.
    """
    # set parameters and random seed
    params = set_params()
    set_random_seed(params['random_seed'])
    params['device'] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if params['verbose']:
        print('Using device', params['device'])
    # load and split data
    data = load_data(params['data_root'])
    train_data, valid_data = split_data(data, params['valid_ratio'])
    if params['verbose']:
        print('Data loaded and split')
    # create dataloaders; augmentations (volume, pitch shift, one of two noise
    # types) are applied to the training split only so validation stays clean
    train_transform = torchvision.transforms.Compose([
        transforms.RandomVolume(gain_db=params['gain_db']),
        transforms.RandomPitchShift(sample_rate=params['sample_rate'],
                                    pitch_shift=params['pitch_shift']),
        torchvision.transforms.RandomChoice([
            transforms.GaussianNoise(scale=params['noise_scale']),
            transforms.AudioNoise(scale=params['audio_scale'],
                                  sample_rate=params['sample_rate']),
        ]),
    ])
    train_dataset = SpeechCommandsDataset(root=params['data_root'], labels=train_data,
                                          keywords=params['keywords'], audio_seconds=params['audio_seconds'],
                                          sample_rate=params['sample_rate'], transform=train_transform)
    valid_dataset = SpeechCommandsDataset(root=params['data_root'], labels=valid_data,
                                          keywords=params['keywords'], audio_seconds=params['audio_seconds'],
                                          sample_rate=params['sample_rate'])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
                                               num_workers=params['num_workers'], shuffle=True)
    # fix: validation must not be shuffled — order is irrelevant for metrics
    # and shuffle=True only adds nondeterminism across runs
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=params['batch_size'],
                                               num_workers=params['num_workers'], shuffle=False)
    if params['verbose']:
        print('Data loaders prepared')
    # initialize model and optimizer
    model = treasure_net(params).to(params['device'])
    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
    if params['load_model']:
        # fix: map_location lets a GPU-saved checkpoint load on a CPU-only host
        checkpoint = torch.load(params['model_checkpoint'], map_location=params['device'])
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
    if params['verbose']:
        print('Model and optimizer initialized')
    # create checkpoints folder (fix: makedirs is race-free and creates parents)
    os.makedirs(params['checkpoint_dir'], exist_ok=True)
    # initialize wandb experiment tracking, if enabled
    if params['use_wandb']:
        wandb.init(project=params['wandb_project'])
        wandb.watch(model)
    # train
    train(model, optimizer, train_loader, valid_loader, params)
# Entry point: run the training pipeline when executed as a script.
if __name__ == '__main__':
    main()