forked from cvat-ai/cvat
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Datumaro] Add generic accuracy checker model launcher (cvat-ai#1661)
* Refactor inference wrapper * Add accuracy checker launcher wrapper * t * rename method * Add importer for openvino launcher * Move openvino plugin to iecore * add generic AC launcher * Implement cli for AC launcher * move ac plugin dir * prevent tf reimport * Fix outputs conversion * t * add pytorch model example * Require config path in launcher * Clear extra whitespace
- Loading branch information
1 parent
020a8bf
commit 29ffc69
Showing
9 changed files
with
296 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
116 changes: 116 additions & 0 deletions
116
datumaro/datumaro/plugins/accuracy_checker_plugin/details/ac.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
|
||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from datumaro.util.tf_util import import_tf | ||
import_tf() # prevent TF loading and potential interpeter crash | ||
|
||
from itertools import groupby | ||
|
||
from accuracy_checker.adapters import create_adapter | ||
from accuracy_checker.data_readers import DataRepresentation | ||
from accuracy_checker.launcher import InputFeeder, create_launcher | ||
from accuracy_checker.postprocessor import PostprocessingExecutor | ||
from accuracy_checker.preprocessor import PreprocessingExecutor | ||
from accuracy_checker.utils import extract_image_representations | ||
|
||
from datumaro.components.extractor import AnnotationType, LabelCategories | ||
|
||
from .representation import import_predictions | ||
|
||
|
||
class _FakeDataset: | ||
def __init__(self, metadata=None): | ||
self.metadata = metadata or {} | ||
|
||
class GenericAcLauncher: | ||
@staticmethod | ||
def from_config(config): | ||
launcher_config = config['launcher'] | ||
launcher = create_launcher(launcher_config) | ||
|
||
dataset = _FakeDataset() | ||
adapter_config = config.get('adapter') or launcher_config.get('adapter') | ||
label_config = adapter_config.get('labels') \ | ||
if isinstance(adapter_config, dict) else None | ||
if label_config: | ||
assert isinstance(label_config, (list, dict)) | ||
if isinstance(label_config, list): | ||
label_config = dict(enumerate(label_config)) | ||
|
||
dataset.metadata = {'label_map': { | ||
int(key): label for key, label in label_config.items() | ||
}} | ||
adapter = create_adapter(adapter_config, launcher, dataset) | ||
|
||
preproc_config = config.get('preprocessing') | ||
preproc = None | ||
if preproc_config: | ||
preproc = PreprocessingExecutor(preproc_config, | ||
dataset_meta=dataset.metadata, | ||
input_shapes=launcher.inputs_info_for_meta() | ||
) | ||
|
||
postproc_config = config.get('postprocessing') | ||
postproc = None | ||
if postproc_config: | ||
postproc = PostprocessingExecutor(postproc_config, | ||
dataset_meta=dataset.metadata, | ||
) | ||
|
||
return __class__(launcher, | ||
adapter=adapter, preproc=preproc, postproc=postproc) | ||
|
||
def __init__(self, launcher, adapter=None, | ||
preproc=None, postproc=None, input_feeder=None): | ||
self._launcher = launcher | ||
self._input_feeder = input_feeder or InputFeeder( | ||
launcher.config.get('inputs', []), launcher.inputs, | ||
launcher.fit_to_input, launcher.default_layout | ||
) | ||
self._adapter = adapter | ||
self._preproc = preproc | ||
self._postproc = postproc | ||
|
||
self._categories = self._init_categories() | ||
|
||
def launch_raw(self, inputs): | ||
ids = range(len(inputs)) | ||
inputs = [DataRepresentation(inp, identifier=id) | ||
for id, inp in zip(ids, inputs)] | ||
_, batch_meta = extract_image_representations(inputs) | ||
|
||
if self._preproc: | ||
inputs = self._preproc.process(inputs) | ||
|
||
inputs = self._input_feeder.fill_inputs(inputs) | ||
outputs = self._launcher.predict(inputs, batch_meta) | ||
|
||
if self._adapter: | ||
outputs = self._adapter.process(outputs, ids, batch_meta) | ||
|
||
if self._postproc: | ||
outputs = self._postproc.process(outputs) | ||
|
||
return outputs | ||
|
||
def launch(self, inputs): | ||
outputs = self.launch_raw(inputs) | ||
return [import_predictions(g) for _, g in | ||
groupby(outputs, key=lambda o: o.identifier)] | ||
|
||
def categories(self): | ||
return self._categories | ||
|
||
def _init_categories(self): | ||
if self._adapter is None or self._adapter.label_map is None: | ||
return None | ||
|
||
label_map = sorted(self._adapter.label_map.items(), key=lambda e: e[0]) | ||
|
||
label_cat = LabelCategories() | ||
for _, label in label_map: | ||
label_cat.add(label) | ||
|
||
return { AnnotationType.label: label_cat } |
62 changes: 62 additions & 0 deletions
62
datumaro/datumaro/plugins/accuracy_checker_plugin/details/representation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
|
||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from datumaro.util.tf_util import import_tf | ||
import_tf() # prevent TF loading and potential interpeter crash | ||
|
||
import accuracy_checker.representation as ac | ||
|
||
import datumaro.components.extractor as dm | ||
from datumaro.util.annotation_tools import softmax | ||
|
||
def import_predictions(predictions): | ||
# Convert Accuracy checker predictions to Datumaro annotations | ||
|
||
anns = [] | ||
|
||
for pred in predictions: | ||
anns.extend(import_prediction(pred)) | ||
|
||
return anns | ||
|
||
def import_prediction(pred): | ||
if isinstance(pred, ac.ClassificationPrediction): | ||
scores = softmax(pred.scores) | ||
return (dm.Label(label_id, attributes={'score': float(score)}) | ||
for label_id, score in enumerate(scores)) | ||
elif isinstance(pred, ac.ArgMaxClassificationPrediction): | ||
return (dm.Label(int(pred.label)), ) | ||
elif isinstance(pred, ac.CharacterRecognitionPrediction): | ||
return (dm.Label(int(pred.label)), ) | ||
elif isinstance(pred, (ac.DetectionPrediction, ac.ActionDetectionPrediction)): | ||
return (dm.Bbox(x0, y0, x1 - x0, y1 - y0, int(label_id), | ||
attributes={'score': float(score)}) | ||
for label, score, x0, y0, x1, y1 in zip(pred.labels, pred.scores, | ||
pred.x_mins, pred.y_mins, pred.x_maxs, pred.y_maxs) | ||
) | ||
elif isinstance(pred, ac.DepthEstimationPrediction): | ||
return (dm.Mask(pred.depth_map), ) # 2d floating point mask | ||
# elif isinstance(pred, ac.HitRatioPrediction): | ||
# - | ||
elif isinstance(pred, ac.ImageInpaintingPrediction): | ||
return (dm.Mask(pred.value), ) # an image | ||
# elif isinstance(pred, ac.MultiLabelRecognitionPrediction): | ||
# - | ||
# elif isinstance(pred, ac.MachineTranslationPrediction): | ||
# - | ||
# elif isinstance(pred, ac.QuestionAnsweringPrediction): | ||
# - | ||
# elif isinstance(pred, ac.PoseEstimation3dPrediction): | ||
# - | ||
# elif isinstance(pred, ac.PoseEstimationPrediction): | ||
# - | ||
# elif isinstance(pred, ac.RegressionPrediction): | ||
# - | ||
else: | ||
raise NotImplementedError("Can't convert %s" % type(pred)) | ||
|
||
|
||
|
||
|
37 changes: 37 additions & 0 deletions
37
datumaro/datumaro/plugins/accuracy_checker_plugin/launcher.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
|
||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import os.path as osp | ||
import yaml | ||
|
||
from datumaro.components.cli_plugin import CliPlugin | ||
from datumaro.components.launcher import Launcher | ||
|
||
from .details.ac import GenericAcLauncher as _GenericAcLauncher | ||
|
||
|
||
class AcLauncher(Launcher, CliPlugin): | ||
""" | ||
Generic model launcher with Accuracy Checker backend. | ||
""" | ||
|
||
@classmethod | ||
def build_cmdline_parser(cls, **kwargs): | ||
parser = super().build_cmdline_parser(**kwargs) | ||
parser.add_argument('-c', '--config', type=osp.abspath, required=True, | ||
help="Path to the launcher configuration file (.yml)") | ||
return parser | ||
|
||
def __init__(self, config, model_dir=None): | ||
model_dir = model_dir or '' | ||
with open(osp.join(model_dir, config), 'r') as f: | ||
config = yaml.safe_load(f) | ||
self._launcher = _GenericAcLauncher.from_config(config) | ||
|
||
def launch(self, inputs): | ||
return self._launcher.launch(inputs) | ||
|
||
def categories(self): | ||
return self._launcher.categories() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
launcher: | ||
framework: pytorch | ||
module: samplenet.SampLeNet | ||
python_path: '.' | ||
checkpoint: 'samplenet.pth' | ||
|
||
# launcher returns raw result, so it should be converted | ||
# to an appropriate representation with adapter | ||
adapter: | ||
type: classification | ||
labels: | ||
- label1 | ||
- label2 | ||
- label3 | ||
- label4 | ||
- label5 | ||
- label6 | ||
- label7 | ||
- label8 | ||
- label9 | ||
- label10 | ||
|
||
# list of preprocessing, applied to each image during validation | ||
# order of entries matters | ||
preprocessing: | ||
# resize input image to topology input size | ||
# you may specify size to which image should be resized | ||
# via dst_width, dst_height fields | ||
- type: resize | ||
size: 32 | ||
# topology is trained on RGB images, but Datumaro reads in BGR | ||
# so it must be converted to RGB | ||
- type: bgr_to_rgb | ||
# dataset mean and standard deviation | ||
- type: normalization | ||
mean: (125.307, 122.961, 113.8575) | ||
std: (51.5865, 50.847, 51.255) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Copyright (c) 2019 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class SampLeNet(nn.Module): | ||
def __init__(self): | ||
super(SampLeNet, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | ||
|
||
def forward(self, x): | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 5 * 5) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
x = self.fc3(x) | ||
return x |