-
Notifications
You must be signed in to change notification settings - Fork 0
/
oid_hierarchical_labels_expansion.py
199 lines (170 loc) · 6.78 KB
/
oid_hierarchical_labels_expansion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""An executable to expand hierarchically image-level labels and boxes.
Example usage:
python models/research/object_detection/dataset_tools/\
oid_hierarchical_labels_expansion.py \
--json_hierarchy_file=<path to JSON hierarchy> \
--input_annotations=<input csv file> \
--output_annotations=<output csv file> \
--annotation_type=<1 (for boxes) or 2 (for image-level labels)>
"""
import argparse
import json
def _update_dict(initial_dict, update):
"""Updates dictionary with update content.
Args:
initial_dict: initial dictionary.
update: updated dictionary.
"""
for key, value_list in update.iteritems():
if key in initial_dict:
initial_dict[key].extend(value_list)
else:
initial_dict[key] = value_list
def _build_plain_hierarchy(hierarchy, skip_root=False):
"""Expands tree hierarchy representation to parent-child dictionary.
Args:
hierarchy: labels hierarchy as JSON file.
skip_root: if true skips root from the processing (done for the case when all
classes under hierarchy are collected under virtual node).
Returns:
keyed_parent - dictionary of parent - all its children nodes.
keyed_child - dictionary of children - all its parent nodes
children - all children of the current node.
"""
all_children = []
all_keyed_parent = {}
all_keyed_child = {}
if 'Subcategory' in hierarchy:
for node in hierarchy['Subcategory']:
keyed_parent, keyed_child, children = _build_plain_hierarchy(node)
# Update is not done through dict.update() since some children have multi-
# ple parents in the hiearchy.
_update_dict(all_keyed_parent, keyed_parent)
_update_dict(all_keyed_child, keyed_child)
all_children.extend(children)
if not skip_root:
all_keyed_parent[hierarchy['LabelName']] = all_children
all_children = [hierarchy['LabelName']] + all_children
for child, _ in all_keyed_child.iteritems():
all_keyed_child[child].append(hierarchy['LabelName'])
all_keyed_child[hierarchy['LabelName']] = []
return all_keyed_parent, all_keyed_child, all_children
class OIDHierarchicalLabelsExpansion(object):
""" Main class to perform labels hierachical expansion."""
def __init__(self, hierarchy):
"""Constructor.
Args:
hierarchy: labels hierarchy as JSON object.
"""
self._hierarchy_keyed_parent, self._hierarchy_keyed_child, _ = (
_build_plain_hierarchy(hierarchy, skip_root=True))
def expand_boxes_from_csv(self, csv_row):
"""Expands a row containing bounding boxes from CSV file.
Args:
csv_row: a single row of Open Images released groundtruth file.
Returns:
a list of strings (including the initial row) corresponding to the ground
truth expanded to multiple annotation for evaluation with Open Images
Challenge 2018 metric.
"""
# Row header is expected to be exactly:
# ImageID,Source,LabelName,Confidence,XMin,XMax,YMin,YMax,IsOccluded,
# IsTruncated,IsGroupOf,IsDepiction,IsInside
cvs_row_splitted = csv_row.split(',')
assert len(cvs_row_splitted) == 13
result = [csv_row]
assert cvs_row_splitted[2] in self._hierarchy_keyed_child
parent_nodes = self._hierarchy_keyed_child[cvs_row_splitted[2]]
for parent_node in parent_nodes:
cvs_row_splitted[2] = parent_node
result.append(','.join(cvs_row_splitted))
return result
def expand_labels_from_csv(self, csv_row):
"""Expands a row containing bounding boxes from CSV file.
Args:
csv_row: a single row of Open Images released groundtruth file.
Returns:
a list of strings (including the initial row) corresponding to the ground
truth expanded to multiple annotation for evaluation with Open Images
Challenge 2018 metric.
"""
# Row header is expected to be exactly:
# ImageID,Source,LabelName,Confidence
cvs_row_splited = csv_row.split(',')
assert len(cvs_row_splited) == 4
result = [csv_row]
if int(cvs_row_splited[3]) == 1:
assert cvs_row_splited[2] in self._hierarchy_keyed_child
parent_nodes = self._hierarchy_keyed_child[cvs_row_splited[2]]
for parent_node in parent_nodes:
cvs_row_splited[2] = parent_node
result.append(','.join(cvs_row_splited))
else:
assert cvs_row_splited[2] in self._hierarchy_keyed_parent
child_nodes = self._hierarchy_keyed_parent[cvs_row_splited[2]]
for child_node in child_nodes:
cvs_row_splited[2] = child_node
result.append(','.join(cvs_row_splited))
return result
def main(parsed_args):
with open(parsed_args.json_hierarchy_file) as f:
hierarchy = json.load(f)
expansion_generator = OIDHierarchicalLabelsExpansion(hierarchy)
labels_file = False
if parsed_args.annotation_type == 2:
labels_file = True
elif parsed_args.annotation_type != 1:
print '--annotation_type expected value is 1 or 2.'
return -1
with open(parsed_args.input_annotations, 'r') as source:
with open(parsed_args.output_annotations, 'w') as target:
header = None
for line in source:
if not header:
header = line
continue
if labels_file:
expanded_lines = expansion_generator.expand_labels_from_csv(line)
else:
expanded_lines = expansion_generator.expand_boxes_from_csv(line)
expanded_lines = [header] + expanded_lines
target.writelines(expanded_lines)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Hierarchically expand annotations (excluding root node).')
parser.add_argument(
'--json_hierarchy_file',
required=True,
help='Path to the file containing label hierarchy in JSON format.')
parser.add_argument(
'--input_annotations',
required=True,
help="""Path to Open Images annotations file (either bounding boxes or
image-level labels).""")
parser.add_argument(
'--output_annotations',
required=True,
help="""Path to the output file.""")
parser.add_argument(
'--annotation_type',
type=int,
required=True,
help="""Type of the input annotations: 1 - boxes, 2 - image-level
labels"""
)
args = parser.parse_args()
main(args)