-
Notifications
You must be signed in to change notification settings - Fork 18
/
trainer.py
executable file
·43 lines (33 loc) · 1.04 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from __future__ import absolute_import, division, print_function
import numpy as np
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import json
from utils import *
from kitti_utils import *
from layers import *
import datasets
import networks
from IPython import embed
class Trainer:
    """MonoViT optimizer / learning-rate-scheduler setup.

    These statements were originally pasted directly into the class body, with
    no enclosing method and no indentation. That cannot be imported, because
    ``self`` is undefined at class scope. They now live in
    :meth:`build_optimizer`, which training code must call after the attributes
    it reads have been populated.
    """

    def build_optimizer(self, decoder_lr=1e-4, gamma=0.9):
        """Create ``self.model_optimizer`` (AdamW) and ``self.model_lr_scheduler``.

        Two parameter groups are used, so the encoder is trained with its own
        learning rate (``self.opt.learning_rate``). Everything else uses
        ``decoder_lr``.

        Reads:
            self.parameters_to_train: iterable of parameters for the first
                group. Presumably these are the non-encoder (decoder/pose)
                networks; confirm in the caller.
            self.models["encoder"]: module whose ``parameters()`` form the
                second group.
            self.opt.learning_rate: learning rate of the encoder group.

        Args:
            decoder_lr: learning rate of the first group. Defaults to the
                previously hard-coded 1e-4.
            gamma: multiplicative decay applied to every group's lr on each
                ``model_lr_scheduler.step()``. Defaults to the previously
                hard-coded 0.9.

        Sets:
            self.params: the param-group list handed to AdamW.
            self.model_optimizer: the ``optim.AdamW`` instance.
            self.model_lr_scheduler: the ``optim.lr_scheduler.ExponentialLR``
                instance.
        """
        encoder_params = list(self.models["encoder"].parameters())

        # torch.optim raises ValueError if a parameter appears in more than one
        # group. If encoder params were also added to parameters_to_train, keep
        # them only in the encoder group so they get opt.learning_rate.
        encoder_ids = {id(p) for p in encoder_params}
        other_params = [
            p for p in self.parameters_to_train if id(p) not in encoder_ids
        ]

        # No per-group weight_decay: AdamW's default (0.01) applies to both.
        self.params = [
            {"params": other_params, "lr": decoder_lr},
            {"params": encoder_params, "lr": self.opt.learning_rate},
        ]
        self.model_optimizer = optim.AdamW(self.params)
        self.model_lr_scheduler = optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, gamma
        )