-
Notifications
You must be signed in to change notification settings - Fork 0
/
ensemble.py
108 lines (88 loc) · 3.12 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import mmcv
import torch
import torch.nn.functional as F
from argparse import ArgumentParser
# Paths of the named pipes (created with os.mkfifo in main) over which this
# script exchanges data with the STCN side and the AOT side respectively.
# NOTE(review): the '.pkl' suffix presumably makes mmcv.load/dump use pickle.
pipe_stcn, pipe_aot = '/tmp/stcn.pkl', '/tmp/aot.pkl'
def mean_merge_back(aot, stcn):
    """Average STCN and AOT predictions and hand a merged copy back to each.

    ``stcn`` is bilinearly resized (``align_corners=True``) to the last two
    dims of ``aot``. It is then averaged with the first ``N`` entries of
    ``aot`` along dim 0, where ``N = stcn.shape[0]``. The merged result is
    returned in each model's own layout.

    Args:
        aot: numpy array with at least 2 dims; its last two dims (H, W) are
            the merge resolution. Presumably float logits with dim 0 holding
            at least N entries — confirm against the AOT side.
        stcn: 4-D numpy array (N, C, h, w) that must broadcast against
            ``aot[:N]`` once resized. Presumably float logits as well.

    Returns:
        Tuple ``(send_aot, send_stcn)`` of numpy arrays:
        * ``send_aot`` has ``aot``'s shape. Its first N entries hold the
          merged values and every remaining entry is -1e4 (presumably a
          "suppress this" logit — confirm).
        * ``send_stcn`` is the merged tensor resized back to (h, w).
    """
    def _resize(tensor, size):
        # Same interpolation settings are used in both directions.
        return F.interpolate(tensor, size, mode='bilinear', align_corners=True)

    aot_t = torch.from_numpy(aot)
    stcn_t = torch.from_numpy(stcn)

    n_stcn, _, stcn_h, stcn_w = stcn_t.shape
    aot_h, aot_w = aot_t.shape[-2:]

    merged = (_resize(stcn_t, (aot_h, aot_w)) + aot_t[:n_stcn]) / 2

    # Start from a constant -1e4 everywhere, then overwrite the merged slice.
    send_aot = torch.ones_like(aot_t).mul(-1e+4)
    send_aot[:n_stcn] = merged

    send_stcn = _resize(merged, (stcn_h, stcn_w))
    return send_aot.numpy(), send_stcn.numpy()
# def just_get_merge_output(aot, stcn, output_dir):
# get_aot = torch.from_numpy(aot)
# get_stcn = torch.from_numpy(stcn)
# N, _, h, w = get_stcn.shape
# *_, H, W = get_aot.shape
# get_stcn = F.interpolate(get_stcn, (H, W), mode='bilinear', align_corners=True)
# merge_logit = (get_stcn + get_aot[:N]) / 2
# return aot, stcn
def _remove_if_present(path):
    """Best-effort delete of ``path``; any OSError (e.g. missing file) is ignored."""
    try:
        os.remove(path)
    except OSError:
        # Deliberately best-effort, but unlike a bare ``except:`` this no
        # longer swallows KeyboardInterrupt / SystemExit.
        pass


def _make_fresh_fifo(path):
    """Create a named pipe at ``path``, replacing any file already there."""
    try:
        os.mkfifo(path)
    except FileExistsError:
        os.remove(path)
        os.mkfifo(path)


def _last_key(frames):
    """Return the last item yielded by iterating ``frames`` (its last key for a dict).

    Raises:
        ValueError: if ``frames`` is empty. The previous ``for k in d: pass``
            idiom would instead leave the key unbound, or silently stale.
    """
    keys = list(frames)
    if not keys:
        raise ValueError('received an empty frame dict from the pipe')
    return keys[-1]


def _receive(pipe):
    """Load one object from ``pipe`` and return it together with its last key."""
    frames = mmcv.load(pipe)
    return frames, _last_key(frames)


def _bounce_back(name, frames, frame, pipe):
    """Send ``frames[frame]`` back unchanged on ``pipe``, then receive the next message."""
    print(f'{name} repeat at {frame}', flush=True)
    mmcv.dump(frames[frame], pipe)
    return _receive(pipe)


def main():
    """Relay and ensemble STCN / AOT predictions exchanged over two FIFOs.

    With ``--off``: remove both pipe files (best effort) and return.

    Otherwise: (re)create both FIFOs, then loop forever. Each iteration holds
    the latest object received from each side. That object is indexed by frame
    key, and the side's current frame is its last key. The loop then does one
    of the following:

    * Both sides are still at the last merged frame: echo each side's own
      entry back to it and wait for its next message.
    * Both sides are at the same new frame: merge the two entries with
      ``mean_merge_back`` and send each side its merged copy.
    * One side is behind (compared with ``>`` on the keys): echo that side's
      own entry back and wait for its next message. The other side gets no
      reply this iteration.

    The loop never returns. ``--output`` is parsed but currently unused.
    """
    parser = ArgumentParser()
    parser.add_argument('--off', action='store_true')
    parser.add_argument('--output')  # NOTE(review): parsed but never read below
    args = parser.parse_args()

    if args.off:
        _remove_if_present(pipe_stcn)
        _remove_if_present(pipe_aot)
        return

    _make_fresh_fifo(pipe_stcn)
    _make_fresh_fifo(pipe_aot)

    # NOTE(review): mmcv.load on a '.pkl' path presumably unpickles, and
    # unpickling runs arbitrary code. Only trusted processes must be able to
    # write these predictable /tmp FIFOs.
    get_stcn, frame_s = _receive(pipe_stcn)
    get_aot, frame_a = _receive(pipe_aot)
    f_last = ''  # key of the most recently merged frame ('' = none yet)
    while True:
        if frame_a == f_last and frame_s == f_last:
            get_stcn, frame_s = _bounce_back('stcn', get_stcn, frame_s, pipe_stcn)
            get_aot, frame_a = _bounce_back('aot', get_aot, frame_a, pipe_aot)
        elif frame_a == frame_s:
            f_last = frame_a
            send_aot, send_stcn = mean_merge_back(get_aot[f_last], get_stcn[f_last])
            mmcv.dump(send_aot, pipe_aot)
            mmcv.dump(send_stcn, pipe_stcn)
            get_stcn, frame_s = _receive(pipe_stcn)
            get_aot, frame_a = _receive(pipe_aot)
        elif frame_a > frame_s:
            # NOTE(review): assumes frame keys sort in temporal order
            # (e.g. zero-padded frame names) — confirm.
            get_stcn, frame_s = _bounce_back('stcn', get_stcn, frame_s, pipe_stcn)
        else:
            get_aot, frame_a = _bounce_back('aot', get_aot, frame_a, pipe_aot)


if __name__ == '__main__':
    main()