param.py
# coding=utf-8
# Copied from lxmert with modifications


import argparse
import random

import numpy as np
import torch


def get_optimizer(optim):
    # Bind the optimizer class matching the given name.
    if optim == 'rms':
        print("Optimizer: Using RMSProp")
        optimizer = torch.optim.RMSprop
    elif optim == 'adam':
        print("Optimizer: Using Adam")
        optimizer = torch.optim.Adam
    elif optim == 'adamax':
        print("Optimizer: Using Adamax")
        optimizer = torch.optim.Adamax
    elif optim == 'sgd':
        print("Optimizer: Using SGD")
        optimizer = torch.optim.SGD
    elif 'bert' in optim:
        optimizer = 'bert'  # The BERT optimizer is bound later.
    else:
        assert False, "Please add your optimizer %s to the list." % optim
    return optimizer
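
# A minimal usage sketch (illustrative; `model` is assumed to be an nn.Module
# defined elsewhere). Every branch except 'bert' returns a torch.optim class
# that the caller instantiates itself, e.g.:
#   optim_cls = get_optimizer('adam')
#   optimizer = optim_cls(model.parameters(), lr=1e-4)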
def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--model", default='lxmert')

    # Data Splits
    parser.add_argument("--train", default='train,nominival')
    parser.add_argument("--valid", default='minival')
    parser.add_argument("--test", default=None)

    # Training Hyper-parameters
    parser.add_argument('--batchSize', dest='batch_size', type=int, default=32)
    parser.add_argument('--optim', default='bert')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--seed', type=int, default=9595, help='random seed')
    parser.add_argument('--max_seq_length', type=int, default=20, help='max sequence length')

    # Debugging
    parser.add_argument('--output', type=str, default='models/trained/')
    parser.add_argument("--fast", action='store_const', default=False, const=True)
    parser.add_argument("--tiny", action='store_const', default=False, const=True)
    # Note: with default=True and const=True, passing --tqdm is a no-op;
    # the flag is effectively always True as written.
    parser.add_argument("--tqdm", action='store_const', default=True, const=True)

    # Model Loading
    parser.add_argument('--load_trained', type=str, default=None,
                        help='Load the model (usually the fine-tuned model).')
    parser.add_argument('--load_pretrained', dest='load_pretrained', type=str, default=None,
                        help='Load the pre-trained LXMERT/VisualBERT/UNITER model.')
    parser.add_argument("--fromScratch", dest='from_scratch', action='store_const', default=False, const=True,
                        help='If neither --load_trained nor --load_pretrained is set, '
                             'the model is trained from scratch. If --fromScratch is '
                             'not specified, the model loads BERT pre-trained weights '
                             'by default.')

    # Optimization
    parser.add_argument("--mceLoss", dest='mce_loss', action='store_const', default=False, const=True)

    # Training configuration
    parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
    # type=int so a value given on the command line is not passed on as a string.
    parser.add_argument("--numWorkers", dest='num_workers', type=int, default=0)

    # Parse the arguments.
    args = parser.parse_args()

    # Bind the optimizer class.
    args.optimizer = get_optimizer(args.optim)

    # Set seeds (Python/NumPy/CPU torch; CUDA ops may still be nondeterministic).
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    return args


# Parsed once at import time, so other modules can simply `from param import args`.
args = parse_args()
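
# Example invocation (a sketch; the entry-script name and flag values are
# assumptions for illustration, not taken from this repo):
#   python train.py --model lxmert --batchSize 32 --optim bert --lr 5e-5 --epochs 4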