-
Notifications
You must be signed in to change notification settings - Fork 6
/
compress.py
119 lines (100 loc) · 3.41 KB
/
compress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import pickle
import argparse
from matplotlib.pyplot import streamplot
import torch
from utils.mask_utils import percentile, dilate_binarize
import logging
import coloredlogs
import utils.compressor as compressor
from pdb import set_trace
from torch.utils.tensorboard import SummaryWriter
def main(args):
    """Smooth, binarize and compress the raw mask saved for ``args.output``.

    Steps:
      1. Unpickle the raw mask from ``{args.output}.rawmask``.
      2. Replace every group of ``args.smooth_frames`` consecutive entries
         with a single averaged mask (in place).
      3. Binarize/dilate the mask using either the explicit ``args.bound``
         or the threshold returned by ``percentile(mask, args.perc)``.
      4. Call the function named ``args.compressor`` from
         ``utils.compressor`` with the mask, args, a logger and a
         TensorBoard ``SummaryWriter``.

    Args:
        args: Parsed command-line namespace. This function reads ``output``,
            ``smooth_frames``, ``bound``, ``perc``, ``pad`` and
            ``compressor``; the whole namespace is forwarded to the
            compressor as well.

    Raises:
        AttributeError: If ``utils.compressor`` has no attribute named
            ``args.compressor``.
    """
    # define logger
    logger = logging.getLogger("compress")

    # load the raw mask
    # NOTE: pickle.load can execute arbitrary code while deserializing; only
    # feed this script .rawmask files that come from a trusted source.
    with open(f"{args.output}.rawmask", "rb") as f:
        mask = pickle.load(f)

    # smooth the mask first.
    # we only sample the mask every args.smooth_frames images.
    # presumably `mask` is a torch tensor with frames along dim 0 (the
    # four-slice indexing below needs at least 4 dims) -- TODO confirm.
    # torch's split returns views, so writing into a chunk updates `mask`.
    chunks = mask.split(args.smooth_frames)
    last_index = len(chunks) - 1
    for i, cur_slice in enumerate(chunks):
        if i < last_index:
            # Average this chunk's first entry with the next chunk's first.
            other = chunks[i + 1][0]
        else:
            # No following chunk: average first and last entry of this one.
            other = cur_slice[-1]
        cur_slice[:, :, :, :] = 0.5 * (cur_slice[0] + other).unsqueeze(0)

    # Two types of knobs: an explicit heat value threshold (--bound) or a
    # percentile-derived one (--perc).
    # Compare against None rather than truthiness so that an explicit bound
    # of 0 is honoured instead of silently falling through to --perc.
    if args.bound is not None:
        threshold = args.bound
    else:
        threshold = percentile(mask, args.perc)
    mask = dilate_binarize(mask, threshold, args.pad, cuda=False)

    # compress the video and log the raw images before encoding
    writer = SummaryWriter(f"runs/{args.output}")
    getattr(compressor, args.compressor)(mask, args, logger, writer)
if __name__ == "__main__":
    # Script entry point: configure logging, parse the CLI, run main().

    # set the format of the logger
    coloredlogs.install(
        level="INFO",
        datefmt="%H:%M:%S",
        fmt="%(asctime)s [%(levelname)s] %(name)s:%(funcName)s[%(lineno)s] -- %(message)s",
    )

    cli = argparse.ArgumentParser()

    # Output file name; main() also derives the .rawmask path from it.
    cli.add_argument(
        "-o", "--output",
        required=True,
        type=str,
        help="The output mp4 file name. Will attach a args file that contain args for decoding purpose.",
    )
    # Padding passed to dilate_binarize.
    cli.add_argument(
        "-p", "--pad",
        required=True,
        type=int,
        help="The padding size that pads extra high quality regions around existing high quality regions",
    )
    # Name of the function looked up on utils.compressor.
    cli.add_argument(
        "-c", "--compressor",
        required=True,
        type=str,
        help="The compressor used to compress the video.",
    )
    cli.add_argument(
        "-s", "--source",
        required=True,
        type=str,
        help="The source to encode the video.",
    )
    cli.add_argument(
        "--tile_size",
        default=16,
        type=int,
        help="The tile size of the mask.",
    )
    cli.add_argument(
        "--preserve",
        action="store_true",
        help="Preserve source png folders for debugging purpose.",
    )
    # Group size used by the smoothing step in main().
    cli.add_argument(
        "--smooth_frames",
        default=30,
        type=int,
        help="Proposing one single mask for smooth_frames many frames",
    )
    # NOTE(review): this help text is identical to --smooth_frames and looks
    # copy-pasted; confirm the intended description before changing it.
    cli.add_argument(
        "--visualize_step_size",
        default=100,
        type=int,
        help="Proposing one single mask for smooth_frames many frames",
    )
    cli.add_argument("--qp", required=True, type=int)

    # Exactly one of --bound / --perc must be supplied.
    threshold_group = cli.add_mutually_exclusive_group(required=True)
    threshold_group.add_argument(
        "--bound",
        type=float,
        help="The lower bound for the mask. Exclusive with --perc",
    )
    threshold_group.add_argument(
        "--perc",
        type=float,
        help="The percentage of pixels in high quality. Exclusive with --bound",
    )

    args = cli.parse_args()
    main(args)