-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
80 lines (56 loc) · 2.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import ast
import sys
from tensorflow import keras
def insert_blanks(y, blank, num_blanks_at_start=1):
    """Interleave ``blank`` into the label sequence ``y``.

    The result starts with ``num_blanks_at_start`` copies of ``blank`` and
    then contains every element of ``y`` immediately followed by one
    ``blank``.

    Args:
        y: Iterable of labels.
        blank: The symbol to insert.
        num_blanks_at_start: How many blanks to place before the first label.

    Returns:
        A new list; ``y`` itself is not modified.
    """
    interleaved = [blank] * num_blanks_at_start
    interleaved.extend(item for symbol in y for item in (symbol, blank))
    return interleaved
def read_args(files, default='configs/default.ast'):
    """Load the default configuration and layer override files on top of it.

    Every file is parsed with ``ast.literal_eval``, so it must contain a
    single Python literal (a dict at the top level).

    Only keys that already exist in the default configuration are
    considered; keys that appear solely in an override file are ignored.
    For each such key the override is merged into the existing value with
    ``.update()`` when that works, otherwise the existing value is replaced
    by the override.

    Args:
        files: Iterable of paths to override files, applied in order.
        default: Path to the file holding the default configuration.

    Returns:
        The merged configuration dict.
    """
    with open(default, 'r') as dfp:
        args = ast.literal_eval(dfp.read())

    for config_file in files:
        with open(config_file, 'r') as cfp:
            override_args = ast.literal_eval(cfp.read())

        for key in args:
            if key not in override_args:
                continue
            new_value = override_args[key]
            try:
                args[key].update(new_value)
            except (AttributeError, TypeError, ValueError):
                # AttributeError: the current value has no ``update``
                # (scalar, list, str, ...).  TypeError / ValueError: the
                # current value is a dict but the override is not something
                # a dict can be updated from (e.g. ``None`` or a number).
                # In every case the override simply replaces the value.
                args[key] = new_value

    return args
def pprint_probs(probs):
    """Print ``probs`` as a compact grid of signed, zero-padded integers.

    ``probs`` is multiplied by 10 and converted with ``.astype(int)``; each
    resulting value is printed with the ``+04d`` format (sign plus three
    digits) with no separator, one output line per row.

    Args:
        probs: Array-like supporting ``10 * probs`` and ``.astype`` whose
            rows are iterable -- presumably a 2-D numpy array.
    """
    scaled = (10 * probs).astype(int)
    for line in scaled:
        for cell in line:
            print(format(cell, '+04d'), end='')
        # Terminate the current row.
        print()
def write_dict(d, f=None, level=0):
    """Write a (possibly nested) dict as an indented ``key: value`` listing.

    A blank line is emitted first, then one line per key in sorted order.
    Nested dicts are written recursively, indented by one extra tab per
    nesting level.

    Args:
        d: The dict to write. Its keys must be mutually sortable.
        f: Text stream to write to. ``None`` (the default) means the
            ``sys.stdout`` in effect when the function is called, so
            redirected or replaced stdout is honoured.
        level: Current nesting depth; controls the number of leading tabs.
    """
    if f is None:
        # Looked up at call time rather than frozen at definition time.
        f = sys.stdout

    tabs = '\t' * level
    print(file=f)
    for k in sorted(d):
        v = d[k]
        print('{}{}: '.format(tabs, k), file=f, end='')
        if isinstance(v, dict):
            # The recursive call emits the newline that ends the key line.
            write_dict(v, f, level + 1)
        else:
            print('{}'.format(v), file=f)
def couple(a):
    """Return ``a`` unchanged if it is exactly a ``tuple``, else ``(a, a)``.

    The check is on the exact type, so tuple subclasses and other sequences
    such as lists are duplicated into a pair rather than passed through.
    """
    if type(a) is tuple:
        return a
    return (a, a)
def actv(name):
    """Translate an activation name into something usable as an activation.

    Names of the form ``relu<digits>`` (``'relu'`` followed by at least one
    more character) yield a function that applies ``keras.activations.relu``
    with ``alpha`` set to the numeric suffix divided by 100; the function
    carries the original string in its ``name`` attribute.  Any other name,
    including plain ``'relu'``, is returned unchanged.

    The suffix is parsed when the returned function is called, not when it
    is created.
    """
    has_suffix = len(name) > 4
    if not (name.startswith('relu') and has_suffix):
        return name

    def _f(x):
        slope = int(name[4:]) / 100.
        return keras.activations.relu(x, alpha=slope)

    _f.name = name
    return _f
def cp(name, num_filters, kernel_size, activation, **kwargs):
    """Build a keyword dict describing a convolution layer.

    ``kernel_size`` is passed through ``couple`` and ``activation`` through
    ``actv``; ``padding`` is set to ``'same'``.  Extra keyword arguments are
    added last and override the entries above when keys collide.
    """
    spec = dict(
        name=name,
        filters=num_filters,
        kernel_size=couple(kernel_size),
        activation=actv(activation),
        padding='same',
    )
    spec.update(kwargs)
    return spec
def pp(name, pool_sz):
    """Build a keyword dict describing a pooling layer.

    ``pool_sz`` is passed through ``couple`` before being stored under
    ``'pool_size'``.
    """
    pool_size = couple(pool_sz)
    return dict(name=name, pool_size=pool_size)
def den(name, nunits, activation, **kwargs):
    """Build a keyword dict describing a dense layer.

    ``activation`` is passed through ``actv``.  Extra keyword arguments are
    added last and override earlier entries when keys collide.
    """
    spec = dict(name=name, units=nunits, activation=actv(activation))
    spec.update(kwargs)
    return spec
def rnn(name, nunits, **kwargs):
    """Build a keyword dict describing a recurrent layer.

    Extra keyword arguments are added after ``'name'`` and ``'units'`` and
    override them when keys collide.
    """
    spec = {'name': name, 'units': nunits}
    spec.update(kwargs)
    return spec