-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path print_2d_tensor_with_ids.py
59 lines (46 loc) · 1.5 KB
/
print_2d_tensor_with_ids.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
def _matrix_rows(text):
    """Extract the printed matrix lines from ``str(tensor)`` of a 2-D tensor.

    Anything before ``tensor([`` (e.g. the ``Parameter containing:`` line) and
    everything after the closing ``]]`` (torch's ``)`` plus optional suffixes
    such as ``dtype=...`` / ``requires_grad=...``) is dropped.

    Returns:
        One string per printed line, e.g. ``["[1, 2],", "[3, 4]"]``, or
        ``None`` when ``text`` does not contain a ``tensor([`` ... ``]]`` span
        (for example the ``tensor([], size=...)`` form of an empty tensor).
    """
    prefix = "tensor(["
    start = text.find(prefix)
    if start < 0:
        return None
    # The first "]]" closes the last row and the matrix itself. Any suffix
    # torch appends after it may even sit on its own line, so cut there
    # rather than trimming a fixed number of characters from the last line.
    end = text.find("]]", start)
    if end < 0:
        return None
    # Continuation lines are indented by len(prefix) spaces, so one slice
    # strips either the "tensor([" prefix (first line) or the indent (others).
    return [line[len(prefix):] for line in text[start:end + 1].split("\n")]


def print_2d_tensor_with_ids(tensor):
    """Print a 2-D tensor with row ids on the left and column ids on top.

    Prints exactly the text returned by ``repr_2d_tensor_with_ids(tensor)``;
    see that function for the layout, caveats and edge cases.

    Args:
        tensor: a 2-D ``torch.Tensor``.

    Raises:
        AssertionError: if ``tensor`` is not 2-D.
    """
    print(repr_2d_tensor_with_ids(tensor))


def repr_2d_tensor_with_ids(tensor) -> str:
    """Return ``str(tensor)`` of a 2-D tensor annotated with row and column ids.

    The first line is a header of zero-padded column ids, each centred over
    its column. Every following line is one printed line of the matrix,
    prefixed with its zero-padded index. Torch's ``tensor(`` wrapper, the
    closing brackets and any trailing ``dtype=`` / ``device=`` /
    ``requires_grad=`` suffix are stripped, so only the rows remain.

    NOTE: ids label *printed lines*. They match row indices only when every
    row fits on one line and torch does not summarize the tensor (see the
    ``linewidth`` / ``threshold`` arguments of ``torch.set_printoptions``).

    Args:
        tensor: a 2-D ``torch.Tensor``.

    Returns:
        The annotated text as a newline-joined string. Tensors with a zero
        dimension are returned as plain ``str(tensor)``, since there is
        nothing to label.

    Raises:
        AssertionError: if ``tensor`` is not 2-D.
    """
    assert len(tensor.shape) == 2, f"expected a 2-D tensor, got shape {tuple(tensor.shape)}"
    m = tensor.shape[1]
    text = str(tensor)
    rows = _matrix_rows(text)
    if rows is None or m == 0:
        # Empty tensor (or unexpected formatting): nothing to annotate, and
        # the per-column width below would divide by zero.
        return text

    # Ids are at least two digits wide (as before) and grow when needed, so
    # 100+ lines/columns stay aligned.
    row_w = max(2, len(str(len(rows) - 1)))
    col_w = max(2, len(str(m - 1)))

    # Average printed width of one column in the first line; the id is
    # centred in whatever space is left after the id itself.
    pad = len(rows[0]) // m - col_w
    left = pad // 2
    right = pad - left

    header = " " * (row_w + 1) + "".join(
        " " * left + f"{j:0{col_w}d}" + " " * right for j in range(m)
    )
    body = [f"{i:0{row_w}d} {row}" for i, row in enumerate(rows)]
    return "\n".join([header, *body])
def main():
    """Demo: print a random 12x16 integer matrix annotated with row/column ids."""
    # Entries are drawn from [0, 91), so every value is at most two digits.
    sample = torch.randint(0, 91, size=(12, 16))
    print_2d_tensor_with_ids(sample)
if __name__ == "__main__":
    # Fixed seed so the demo prints the same matrix on every run.
    torch.manual_seed(111)
    # Wider lines make row wrapping less likely; the printed ids label
    # printed lines, so they only match rows when each row fits on one line.
    torch.set_printoptions(linewidth=160)
    main()