"""Training entry point (train.py).

Builds a ``LitModel`` and the ``lit_custom_data`` data object, then fits the
model with a PyTorch Lightning ``Trainer``.
"""
from __future__ import print_function
from models.litmodel import LitModel
from models.fcn32s import FCN32s
from models.vgg import VGGNet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models
from torchvision.models.vgg import VGG
import pytorch_lightning as pl
from torch.optim.rmsprop import RMSprop
from dataloader import CustomDataset, lit_custom_data
from pytorch_lightning import loggers
from configs import Configs
import os
# os.environ["OPENBLAS_MAIN_FREE"] = '1'
def main():
    """Create the model and data objects and run the Lightning fit loop."""
    # Hyperparameters handed to LitModel; only the learning rate is set here.
    hyper_params = {'lr': 0.0019054607179632484}
    lit_model = LitModel(hyper_params)

    data = lit_custom_data()

    # Request one GPU and stop after at most 120 epochs.
    trainer_options = {'gpus': 1, 'max_epochs': 120}
    runner = pl.Trainer(**trainer_options)

    # Optional tuning pass, left disabled as in the original script:
    # runner.tune(lit_model, data)
    runner.fit(lit_model, data)


if __name__ == '__main__':
    main()