This repository has been archived by the owner on Jul 18, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
/
main.lua
59 lines (50 loc) · 1.93 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
--[[
This code is part of Ultrasound-Nerve-Segmentation Program
Copyright (c) 2016, Qure.AI, Pvt. Ltd.
All rights reserved.
Main file: parses command-line options, loads the data and starts training
--]]
require 'torch'
require 'paths'
require 'optim'
require 'nn'
require 'cunn'
require 'cudnn'
tnt = require 'torchnet'
-- Process-wide torch configuration.
-- New tensors default to single-precision floats.
torch.setdefaulttensortype('torch.FloatTensor')
-- Torch's CPU thread pool is limited to a single thread.
-- NOTE(review): kept from the original as a "speed up"; presumably because the
-- heavy compute runs on the GPU (cunn/cudnn) — confirm with profiling.
torch.setnumthreads(1)
-- Command-line interface.
-- torch.CmdLine drops the leading dash of each option name, so every option
-- below becomes a field of the table returned by cmd:parse
-- (e.g. -trainSize -> opt.trainSize, which main reads).
local cmd = torch.CmdLine()
cmd:text()
-- The banner previously read 'Torch-7 context encoder training script',
-- a leftover from an unrelated project; it now names this program.
cmd:text('Ultrasound nerve segmentation training script')
cmd:text()
cmd:text('Options:')
cmd:option('-dataset','data/train.h5','Training dataset to be used')
cmd:option('-model','models/unet.lua','Path of the model to be used')
-- -1 is a sentinel for "use the whole split"; main converts it to nil.
cmd:option('-trainSize',100,'Size of the training dataset to be used, -1 if complete dataset has to be used')
cmd:option('-valSize',25,'Size of the validation dataset to be used, -1 if complete validation dataset has to be used')
cmd:option('-trainBatchSize',64,'Size of the batch to be used for training')
cmd:option('-valBatchSize',32,'Size of the batch to be used for validation')
cmd:option('-savePath','data/saved_models/','Path to save models')
cmd:option('-optimMethod','sgd','Algorithm to be used for learning - sgd | adam')
cmd:option('-maxepoch',250,'Epochs for training')
-- NOTE(review): presumably consumed by DataLoader to split train/val by
-- patient — confirm in dataloader.lua.
cmd:option('-cvParam',2,'Cross validation parameter used to segregate data based on patient number')
--- Map the command-line size sentinel -1 to nil, leaving other values as-is.
-- An explicit `if` is required here: the `x == -1 and nil or x` idiom can
-- never produce nil, because the falsy `nil` branch makes `or` fall through
-- to `x`.
-- @param size number taken from the parsed options
-- @return nil when size is -1, otherwise size unchanged
local function normalizeSize(size)
  if size == -1 then
    return nil
  end
  return size
end

--- Main execution script.
-- Loads the training and validation splits, stores them on `opt`, and hands
-- `opt` to Machine to run training.
-- @param opt table of parsed command-line options (-dataset, -trainSize,
--   -valSize, ...); it is mutated in place
function main(opt)
  -- -1 on the command line means "use the complete split"; the loader then
  -- receives nil as the size.
  -- NOTE(review): presumably DataLoader:GetData treats a nil size as "no
  -- limit" — confirm in dataloader.lua.
  opt.trainSize = normalizeSize(opt.trainSize)
  opt.valSize = normalizeSize(opt.valSize)

  -- Load the data loader; presumably this defines the global DataLoader
  -- used below.
  require 'dataloader.lua'
  local dl = DataLoader(opt)
  local trainDataset = dl:GetData('train', opt.trainSize)
  local valDataset = dl:GetData('val', opt.valSize)
  opt.trainDataset = trainDataset
  opt.valDataset = valDataset

  -- From here on opt.dataset holds only the base name of the dataset file,
  -- without the '.h5' extension (e.g. 'data/train.h5' -> 'train').
  opt.dataset = paths.basename(opt.dataset, '.h5')
  print(opt)

  -- Presumably defines the global Machine used below.
  require 'machine.lua'
  local m = Machine(opt)
  m:train()
end
-- Script entry point. The process arguments come from the global `arg`,
-- falling back to an empty table when `arg` is nil. The parsed options table
-- is passed to main.
local cliArgs = arg or {}
local opt = cmd:parse(cliArgs)
main(opt)