-
Notifications
You must be signed in to change notification settings - Fork 4
/
argparser.py
77 lines (67 loc) · 2.09 KB
/
argparser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
## Copyright (C) 2019, Huan Zhang <[email protected]>
## Hongge Chen <[email protected]>
## Chaowei Xiao <[email protected]>
##
## This program is licenced under the BSD 2-Clause License,
## contained in the LICENCE file in this directory.
##
import os
import torch
import random
import numpy as np
import argparse
import ast
def isfloat(value):
    """Return True if ``float(value)`` succeeds, False otherwise.

    Strings such as ``"1.5"``, ``"1e5"``, ``"nan"`` and ``"inf"`` all count
    as floats.  Inputs that ``float`` rejects for any reason -- unparseable
    text (ValueError), unsupported types such as ``None`` or lists
    (TypeError), or ints too large to represent (OverflowError) -- yield
    False instead of raising, so this is safe to use as a predicate.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
def isint(value):
    """Return True if ``int(value)`` succeeds, False otherwise.

    For strings this means a base-10 integer literal (``"3"``, ``"-7"``,
    ``"1_000"``); ``"3.0"`` and ``"1e5"`` are rejected.  Note that numeric
    inputs are accepted whenever ``int`` can truncate them (``isint(3.7)``
    is True).  Inputs ``int`` rejects for any reason -- unparseable text
    (ValueError), unsupported types such as ``None`` (TypeError), or
    infinities (OverflowError) -- yield False instead of raising.
    """
    try:
        int(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
def _parse_override_value(text):
    """Convert the right-hand side of a ``key=value`` override to a Python value.

    Tried in order: the literals ``true``/``false`` (lowercase) become
    booleans, then ``int(text)``, then ``float(text)``, and finally
    ``ast.literal_eval(text)`` for anything else (quoted strings, lists,
    dicts, ``None``, ``True``...).

    Raises:
        ValueError or SyntaxError: if ``text`` is not valid for any of the
            conversions above (e.g. an unquoted string).
    """
    if text == "true":
        return True
    if text == "false":
        return False
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return float(text)
    except ValueError:
        pass
    return ast.literal_eval(text)


def argparser(seed=2019, argv=None):
    """Parse command-line options, seed the RNGs, and collect config overrides.

    Recognized options:
        --config        path of the config file (default "UNSPECIFIED.json")
        --model_subset  indices of the models in the config file to use
        --path_prefix   override for the path prefix (default "")
        --seed          RNG seed (default: the ``seed`` argument)
        overrides       positional ``key=value`` items; ``:`` in the key
                        addresses nested dicts, e.g. ``training:lr=0.1``

    Side effects: seeds ``torch``, ``torch.cuda``, ``random`` and
    ``numpy.random`` with ``args.seed``, and globally sets
    ``np.seterr(divide='ignore')``.

    Args:
        seed: default value for ``--seed``.
        argv: argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        The ``argparse.Namespace`` with an extra ``overrides_dict`` attribute
        holding the parsed overrides as a nested dict.  Later overrides win
        over earlier ones for the same key.

    Exits (via ``parser.error``, status 2) when an override is malformed:
    missing ``=``, a nested key whose parent already holds a non-dict value,
    or a value that cannot be parsed.
    """
    parser = argparse.ArgumentParser()
    # configure file
    parser.add_argument('--config', default="UNSPECIFIED.json")
    parser.add_argument('--model_subset', type=int, nargs='+',
            help='Use only a subset of models in config file. Pass a list of numbers starting with 0, like --model_subset 0 1 3 5')
    parser.add_argument('--path_prefix', type=str, default="", help="override path prefix")
    parser.add_argument('--seed', type=int, default=seed)
    parser.add_argument('overrides', type=str, nargs='*',
            help='overriding config dict')
    args = parser.parse_args(argv)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    # for dual norm computation, we will have 1 / 0.0 = inf
    np.seterr(divide='ignore')

    overrides_dict = {}
    for override in args.overrides:
        # Split on the FIRST '=' only, so values may themselves contain '='.
        key, sep, raw_val = override.strip().partition("=")
        if not sep:
            parser.error(f"override {override!r} must have the form key=value")
        # 'a:b:c' addresses overrides_dict['a']['b']['c'].
        *parents, last_key = key.split(":")
        node = overrides_dict
        for part in parents:
            node = node.setdefault(part, {})
            if not isinstance(node, dict):
                parser.error(f"override {override!r}: {part!r} already holds a non-dict value")
        try:
            value = _parse_override_value(raw_val)
        except (ValueError, SyntaxError):
            parser.error(
                f"cannot parse value {raw_val!r} in override {override!r}; "
                "strings must be quoted Python literals, e.g. key=\"'text'\""
            )
        node[last_key] = value
    args.overrides_dict = overrides_dict
    return args