-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtensorstat
executable file
·96 lines (78 loc) · 2.2 KB
/
tensorstat
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/python3
import argparse
import itertools
import sys
import time
from math import inf

import numpy as np
import tensorcom
# Command-line interface. The help text is passed as ``description``: the
# first positional parameter of ArgumentParser is ``prog``, so passing the
# paragraph positionally made it the program name in usage/--help output.
parser = argparse.ArgumentParser(
    description="""
Compute statistics over a tensor in a tensorcom input stream.
Each item in a tensorcom input stream is usually a list of
tensors, each representing a batch.
"""
)
parser.add_argument(
    "input",
    nargs="*",
    help="tensorcom input URLs (default: zsub://127.0.0.1:7880)",
)
parser.add_argument(
    "-b",
    "--unbatched",
    action="store_true",
    help="accepted for compatibility; not used by this script",
)
parser.add_argument(
    "-c",
    "--count",
    type=int,
    default=20,
    help="number of items to read from the input (default: %(default)s)",
)
parser.add_argument(
    "-r",
    "--raw",
    action="store_true",
    help="open the connection with raw=True and only report source statistics",
)
args = parser.parse_args()
if not args.input:
    args.input = ["zsub://127.0.0.1:7880"]

# NOTE(review): device=None / raw=... are forwarded to tensorcom.Connection
# unchanged; their exact semantics live in tensorcom — confirm there.
source = tensorcom.Connection(args.input, device=None, raw=args.raw)
print("reading batches...\n")
class Stats(object):
    """Running min/max/mean/std over a stream of arrays.

    Each ``stats += array`` folds one array into the totals. Sums are kept in
    float64 so long streams and large values neither lose precision nor
    overflow (squaring float32 values above ~1.8e19 gives inf, which made the
    reported std nan).
    """

    def __init__(self):
        self.count = 0  # number of arrays folded in (including empty ones)
        self.lo = inf  # smallest element seen so far
        self.hi = -inf  # largest element seen so far
        self.sx = 0.0  # float64 sum of all elements
        self.sx2 = 0.0  # float64 sum of squared elements
        self.n = 0  # total number of elements seen

    def __iadd__(self, x):
        """Fold array-like ``x`` into the statistics and return ``self``.

        Empty arrays are counted but contribute no elements; ``np.amin`` /
        ``np.amax`` would raise ValueError on them.
        """
        x = np.asarray(x, dtype=np.float64)
        self.count += 1
        if x.size == 0:
            return self
        self.lo = min(self.lo, float(np.amin(x)))
        self.hi = max(self.hi, float(np.amax(x)))
        self.sx += float(np.sum(x))
        self.sx2 += float(np.sum(np.square(x)))
        self.n += x.size
        return self

    def summary(self):
        """Return ``"count [lo hi] mean=.. std=.. n=.."`` as a string.

        The std is the population standard deviation. Mean and std are nan
        when no elements have been seen.
        """
        if self.n == 0:
            mean = std = float("nan")
        else:
            mean = self.sx / self.n
            # E[x^2] - E[x]^2 can dip slightly below zero from rounding;
            # clamp so the square root stays real.
            var = max(self.sx2 / self.n - mean * mean, 0.0)
            std = var ** 0.5
        return "{:d} [{:.3g} {:.3g}] mean={:.3g} std={:.3g} n={:d}".format(
            self.count,
            self.lo,
            self.hi,
            mean,
            std,
            self.n,
        )
# Per-position accumulators: shapes[j] collects (dtype, *shape) tuples and
# stats[j] the value statistics of the j-th tensor of each item. They grow
# on demand, so items with any number of tensors are handled (fixed lists of
# 10 raised IndexError for longer items).
shapes = []
stats = []
ninputs = 0
start = time.time()
# islice consumes exactly args.count items; enumerate + "if i >= count: break"
# pulled one extra item from the source before noticing the limit.
for batch in itertools.islice(source.items(), max(args.count, 0)):
    if args.raw:
        continue
    ninputs = max(ninputs, len(batch))
    while len(stats) < ninputs:
        shapes.append(set())
        stats.append(Stats())
    for j, a in enumerate(batch):
        if not isinstance(a, np.ndarray):
            continue
        shapes[j].add((str(a.dtype),) + tuple(a.shape))
        # float64 keeps int64/float64 data from losing precision in the cast.
        stats[j] += a.astype(np.float64)
finish = time.time()

if args.raw:
    print(source.stats.summary())
    sys.exit(0)

print("Source:")
print(source.stats.summary())
print()

for j in range(ninputs):
    print("=== Input {} ===\n".format(j))
    if stats[j].count == 0:
        print("not a tensor")
    else:
        print(stats[j].summary())
        print(shapes[j])
    print()