forked from tensorpack/tensorpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
common.py
127 lines (102 loc) · 3.11 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
import numpy as np
import cv2
from tensorpack.dataflow import RNGDataFlow
from tensorpack.dataflow.imgaug import transform
from tensorpack.utils import logger
import pycocotools.mask as cocomask
import config
class DataFromListOfDict(RNGDataFlow):
    """
    DataFlow that walks over a list of dicts and, for each dict, produces
    a datapoint made of the values stored under the selected keys.
    """

    def __init__(self, lst, keys, shuffle=False):
        """
        Args:
            lst (list[dict]): the records to iterate over.
            keys (list): keys to extract from every record, in output order.
            shuffle (bool): whether to shuffle ``lst`` before each pass.
        """
        self._lst, self._keys = lst, keys
        self._shuffle = shuffle
        # The length is captured once, at construction time.
        self._size = len(lst)

    def size(self):
        """Return the number of records recorded at construction time."""
        return self._size

    def get_data(self):
        """
        Yield one datapoint (a list of values, ordered like ``keys``)
        per record.
        """
        if self._shuffle:
            # Operates on the very list object given to the constructor.
            # NOTE(review): self.rng is presumably supplied by RNGDataFlow
            # and shuffles in place -- confirm against the base class.
            self.rng.shuffle(self._lst)
        for record in self._lst:
            yield [record[key] for key in self._keys]
class CustomResize(transform.TransformAugmentorBase):
    """
    Resize an image so that its shortest edge becomes ``size``, then shrink
    further if necessary so the longest edge does not exceed ``max_size``.
    """

    def __init__(self, size, max_size, interp=cv2.INTER_LINEAR):
        """
        Args:
            size (int): target length of the shortest edge.
            max_size (int): upper bound allowed for the longest edge.
            interp: interpolation flag forwarded to the resize transform.
        """
        # locals() is forwarded untouched, so no extra local variables may
        # be introduced in this method.
        # NOTE(review): _init presumably stores the arguments as attributes
        # (size / max_size / interp are read back below) -- confirm.
        self._init(locals())

    def _get_augment_params(self, img):
        """
        Work out the output resolution for ``img`` and return the
        ``transform.ResizeTransform`` that maps (h, w) to (newh, neww).
        """
        h, w = img.shape[:2]
        ratio = self.size * 1.0 / min(h, w)
        # The shorter side is pinned to ``size``; the other side is scaled.
        newh, neww = (self.size, ratio * w) if h < w else (ratio * h, self.size)
        longest = max(newh, neww)
        if longest > self.max_size:
            # Longest edge is over the cap: shrink both sides uniformly.
            shrink = self.max_size * 1.0 / longest
            newh, neww = newh * shrink, neww * shrink
        # Round to the nearest integer pixel count.
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return transform.ResizeTransform(h, w, newh, neww, self.interp)
def box_to_point8(boxes):
    """
    Expand each box into its four corner points.

    Every row (c0, c1, c2, c3) becomes the four points
    (c0, c1), (c2, c3), (c0, c3), (c2, c1), in that order.

    Args:
        boxes: nx4
    Returns:
        (nx4)x2
    """
    # Column order that lays the four corners out as consecutive pairs.
    corner_columns = [0, 1, 2, 3, 0, 3, 2, 1]
    corners = boxes[:, corner_columns]
    return corners.reshape((-1, 2))
def point8_to_box(points):
    """
    Collapse every group of four points into its bounding box.

    Args:
        points: (nx4)x2
    Returns:
        nx4 boxes (x1y1x2y2)
    """
    quads = points.reshape((-1, 4, 2))
    # Reduce over the four points of each group; both results are nx2.
    lower = quads.min(axis=1)
    upper = quads.max(axis=1)
    return np.concatenate([lower, upper], axis=1)
def segmentation_to_mask(polys, height, width):
    """
    Rasterize a group of polygons into a single binary mask.

    Args:
        polys: a list of nx2 float array
        height, width: passed to the COCO mask encoder together with
            the polygons.
    Returns:
        a binary matrix of (height, width)
    """
    # Each polygon is flattened into a plain Python list before encoding.
    flat_polys = []
    for poly in polys:
        flat_polys.append(poly.flatten().tolist())
    encoded = cocomask.frPyObjects(flat_polys, height, width)
    # Merge the per-polygon encodings into one, then decode it.
    return cocomask.decode(cocomask.merge(encoded))
def clip_boxes(boxes, shape):
    """
    Clip boxes so that every coordinate lies inside the image.

    All x coordinates (columns 0 and 2) are clamped to [0, w] and all
    y coordinates (columns 1 and 3) to [0, h]. A box lying entirely
    outside the image therefore collapses onto the border instead of
    ending up with x1 > x2 or a negative x2/y2.

    Args:
        boxes: (...)x4, float
        shape: h, w
    Returns:
        the clipped boxes, with the same shape as the input.
        The clipping is written through the reshaped array, so whenever
        numpy can reshape without copying, the input array itself is
        modified as well.
    """
    orig_shape = boxes.shape
    boxes = boxes.reshape([-1, 4])
    h, w = shape
    # Clamp both ends of every coordinate, not just the lower bound of
    # x1/y1 and the upper bound of x2/y2.
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, w)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, h)
    return boxes.reshape(orig_shape)
def print_config():
    """
    Log every attribute of the ``config`` module whose name is unchanged
    by upper-casing, one "NAME = value" line each, between two separator
    lines.
    """
    logger.info("Config: ------------------------------------------")
    # Keep names that are identical to their upper-cased form.
    setting_names = [name for name in dir(config) if name.upper() == name]
    for name in setting_names:
        logger.info("{} = {}".format(name, getattr(config, name)))
    logger.info("--------------------------------------------------")