-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
59 lines (54 loc) · 3.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import time
import torch
import torch.nn as nn
import os
import sys
from tqdm import tqdm
#from utils_mean import *
from config import Config
from model import CSRNet
from dataset import create_train_dataloader,create_test_dataloader
from utils import denormalize
if __name__=="__main__":
cfg = Config() # configuration
model = CSRNet().to(cfg.device) # model
criterion = nn.MSELoss(size_average=False) # objective
optimizer = torch.optim.Adam(model.parameters(),lr=cfg.lr) # optimizer
train_dataloader = create_train_dataloader(cfg.dataset_root, use_flip=True, batch_size=cfg.batch_size)
test_dataloader = create_test_dataloader(cfg.dataset_root) # dataloader
min_mae = sys.maxsize
min_mae_epoch = -1
#train_loss = 0.0
for epoch in range(1, cfg.epochs): # start training
model.train()
epoch_loss = 0.0
for i, data in enumerate(tqdm(train_dataloader)):
image = data['image'].to(cfg.device)
gt_densitymap = data['densitymap'].to(cfg.device)
et_densitymap = model(image) # forward propagation
loss = criterion(et_densitymap,gt_densitymap) # calculate loss
epoch_loss += loss.item()
optimizer.zero_grad()
loss.backward() # back propagation
optimizer.step() # update network parameters
cfg.writer.add_scalar('Train_Loss', epoch_loss/len(train_dataloader), epoch)
#train_loss = epoch_loss/len(train_dataloader)
model.eval()
with torch.no_grad():
epoch_mae = 0.0
for i, data in enumerate(tqdm(test_dataloader)):
image = data['image'].to(cfg.device)
gt_densitymap = data['densitymap'].to(cfg.device)
et_densitymap = model(image).detach() # forward propagation
mae = abs(et_densitymap.data.sum()-gt_densitymap.data.sum())
epoch_mae += mae.item()
epoch_mae /= len(test_dataloader)
if epoch_mae < min_mae:
min_mae, min_mae_epoch = epoch_mae, epoch
torch.save(model.state_dict(), os.path.join(cfg.checkpoints,str(epoch)+".pth")) # save checkpoints
print('Epoch ', epoch, ' MAE: ', epoch_mae, ' Min MAE: ', min_mae, ' Min Epoch: ', min_mae_epoch) # print information
cfg.writer.add_scalar('Val_MAE', epoch_mae, epoch)
cfg.writer.add_image(str(epoch)+'/Image', denormalize(image[0].cpu()))
cfg.writer.add_image(str(epoch)+'/Estimate density count:'+ str('%.2f'%(et_densitymap[0].cpu().sum())), et_densitymap[0]/torch.max(et_densitymap[0]))
cfg.writer.add_image(str(epoch)+'/Ground Truth count:'+ str('%.2f'%(gt_densitymap[0].cpu().sum())), gt_densitymap[0]/torch.max(gt_densitymap[0]))