From 5cd4abb18e1a40b86c5bf3d92c7c9217c810812d Mon Sep 17 00:00:00 2001 From: Emily Chun Date: Sat, 19 Jun 2021 15:04:26 +0900 Subject: [PATCH] Introduce Validator plugin type (#299) * Introduce Validator plugin type --- datumaro/cli/contexts/project/__init__.py | 34 +- datumaro/components/cli_plugin.py | 2 +- datumaro/components/environment.py | 8 +- datumaro/components/validator.py | 1264 +------------------- datumaro/plugins/validators.py | 1275 +++++++++++++++++++++ docs/user_manual.md | 26 +- tests/test_validator.py | 30 +- 7 files changed, 1349 insertions(+), 1290 deletions(-) create mode 100644 datumaro/plugins/validators.py diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index ff4dfb10bd69..e0c88520a1dd 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -18,7 +18,7 @@ from datumaro.components.project import \ PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project -from datumaro.components.validator import validate_annotations, TaskType +from datumaro.components.validator import Validator, TaskType from datumaro.util import error_rollback from ...util import (CliException, MultilineFormatter, add_subparser, @@ -801,8 +801,7 @@ def build_validate_parser(parser_ctor=argparse.ArgumentParser): """, formatter_class=MultilineFormatter) - parser.add_argument('task_type', - choices=[task_type.name for task_type in TaskType], + parser.add_argument('-t', '--task_type', choices=[task_type.name for task_type in TaskType], help="Task type for validation") parser.add_argument('-s', '--subset', dest='subset_name', default=None, help="Subset to validate (default: None)") @@ -816,19 +815,24 @@ def build_validate_parser(parser_ctor=argparse.ArgumentParser): def validate_command(args): project = load_project(args.project_dir) - task_type = args.task_type - subset_name = args.subset_name - dst_file_name = f'validation_results-{task_type}' + dst_file_name = f'report-{args.task_type}' dataset = project.make_dataset() - if subset_name is not None: - dataset = dataset.get_subset(subset_name) - dst_file_name += f'-{subset_name}' + if args.subset_name is not None: + dataset = dataset.get_subset(args.subset_name) + dst_file_name += f'-{args.subset_name}' + + try: + validator_type = project.env.validators[args.task_type] + except KeyError: + raise CliException("Validator type '%s' is not found" % args.task_type) extra_args = {} - from datumaro.components.validator import _Validator - extra_args = _Validator.parse_cmdline(args.extra_args) - validation_results = validate_annotations(dataset, task_type, **extra_args) + if hasattr(validator_type, 'parse_cmdline'): + extra_args = validator_type.parse_cmdline(args.extra_args) + + validator = validator_type(**extra_args) + report = validator.validate(dataset) def numpy_encoder(obj): if isinstance(obj, np.generic): @@ -843,12 +847,12 @@ def _make_serializable(d): if isinstance(val, dict): _make_serializable(val) - _make_serializable(validation_results) + _make_serializable(report) dst_file = generate_next_file_name(dst_file_name, ext='.json') log.info("Writing project validation results to '%s'" % dst_file) with open(dst_file, 'w') as f: - json.dump(validation_results, f, indent=4, sort_keys=True, + json.dump(report, f, indent=4, sort_keys=True, default=numpy_encoder) def build_parser(parser_ctor=argparse.ArgumentParser): @@ -875,4 +879,4 @@ def build_parser(parser_ctor=argparse.ArgumentParser): 
add_subparser(subparsers, 'stats', build_stats_parser) add_subparser(subparsers, 'validate', build_validate_parser) - return parser + return parser \ No newline at end of file diff --git a/datumaro/components/cli_plugin.py b/datumaro/components/cli_plugin.py index 702158aa7077..33375bf514a1 100644 --- a/datumaro/components/cli_plugin.py +++ b/datumaro/components/cli_plugin.py @@ -45,6 +45,6 @@ def parse_cmdline(cls, args=None): return args def remove_plugin_type(s): - for t in {'transform', 'extractor', 'converter', 'launcher', 'importer'}: + for t in {'transform', 'extractor', 'converter', 'launcher', 'importer', 'validator'}: s = s.replace('_' + t, '') return s diff --git a/datumaro/components/environment.py b/datumaro/components/environment.py index c27131a84143..1cee291b5d7b 100644 --- a/datumaro/components/environment.py +++ b/datumaro/components/environment.py @@ -149,6 +149,7 @@ def __init__(self, config=None): from datumaro.components.extractor import (Importer, Extractor, Transform) from datumaro.components.launcher import Launcher + from datumaro.components.validator import Validator self.extractors = PluginRegistry( builtin=select(builtin, Extractor), local=select(custom, Extractor) @@ -172,6 +173,10 @@ def __init__(self, config=None): builtin=select(builtin, Transform), local=select(custom, Transform) ) + self.validators = PluginRegistry( + builtin=select(builtin, Validator), + local=select(custom, Validator) + ) @staticmethod def _find_plugins(plugins_dir): @@ -262,7 +267,8 @@ def _load_plugins2(cls, plugins_dir): from datumaro.components.extractor import (Extractor, Importer, Transform) from datumaro.components.launcher import Launcher - types = [Extractor, Converter, Importer, Launcher, Transform] + from datumaro.components.validator import Validator + types = [Extractor, Converter, Importer, Launcher, Transform, Validator] return cls._load_plugins(plugins_dir, types) diff --git a/datumaro/components/validator.py b/datumaro/components/validator.py index 654dbc122165..b0db512a6e89 100644 --- a/datumaro/components/validator.py +++ b/datumaro/components/validator.py @@ -2,1274 +2,34 @@ # # SPDX-License-Identifier: MIT -from copy import deepcopy from enum import Enum, auto -from typing import Union - -import numpy as np +from typing import Dict, List from datumaro.components.dataset import IDataset -from datumaro.components.errors import (MissingLabelCategories, - MissingAnnotation, MultiLabelAnnotations, MissingAttribute, - UndefinedLabel, UndefinedAttribute, LabelDefinedButNotFound, - AttributeDefinedButNotFound, OnlyOneLabel, FewSamplesInLabel, - FewSamplesInAttribute, ImbalancedLabels, ImbalancedAttribute, - ImbalancedDistInLabel, ImbalancedDistInAttribute, - NegativeLength, InvalidValue, FarFromLabelMean, - FarFromAttrMean, OnlyOneAttributeValue) -from datumaro.components.extractor import AnnotationType, LabelCategories -from datumaro.components.cli_plugin import CliPlugin -from datumaro.util import parse_str_enum_value class Severity(Enum): warning = auto() error = auto() + class TaskType(Enum): classification = auto() detection = auto() segmentation = auto() -class _Validator(CliPlugin): - # statistics templates - numerical_stat_template = { - 'items_far_from_mean': {}, - 'mean': None, - 'stdev': None, - 'min': None, - 'max': None, - 'median': None, - 'histogram': { - 'bins': [], - 'counts': [], - }, - 'distribution': np.array([]) - } - - """ - A base class for task-specific validators. - - Attributes - ---------- - task_type : str or TaskType - task type (ie. 
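For reference, a minimal sketch (not part of the patch) of the programmatic flow the updated `validate` command now follows. The registry key 'detection' is an assumption based on remove_plugin_type() now stripping '_validator' from plugin class names, and the project path is hypothetical; the other calls (make_dataset, env.validators, parse_cmdline, validate) appear in the diff above.

# Sketch only: mirrors validate_command() above.
from datumaro.components.project import Project

project = Project.load('./my_project')          # hypothetical project path
dataset = project.make_dataset()

validator_type = project.env.validators['detection']   # plugin lookup
extra_args = {}
if hasattr(validator_type, 'parse_cmdline'):
    extra_args = validator_type.parse_cmdline([])       # keep the CLI defaults

validator = validator_type(**extra_args)
report = validator.validate(dataset)
print(report['summary'])                        # {'errors': ..., 'warnings': ...}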
classification, detection, segmentation) - - Methods - ------- - compute_statistics(dataset): - Computes various statistics of the dataset based on task type. - generate_reports(stats): - Abstract method that must be implemented in a subclass. - """ - - @classmethod - def build_cmdline_parser(cls, **kwargs): - parser = super().build_cmdline_parser(**kwargs) - parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, - help="Threshold for giving a warning for minimum number of" - "samples per class") - parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, - help="Threshold for giving data imbalance warning;" - "IR(imbalance ratio) = majority/minority") - parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, - help="Threshold for giving a warning that data is far from mean;" - "A constant used to define mean +/- k * standard deviation;") - parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, - help="Threshold for giving a warning for bounding box imbalance;" - "Dominace_ratio = ratio of Top-k bin to total in histogram;") - parser.add_argument('-k', '--topk_bins', default=0.1, type=float, - help="Ratio of bins with the highest number of data" - "to total bins in the histogram; [0, 1]; 0.1 = 10%;") - return parser - - def __init__(self, task_type, few_samples_thr=None, - imbalance_ratio_thr=None, far_from_mean_thr=None, - dominance_ratio_thr=None, topk_bins=None): - """ - Validator - - Parameters - --------------- - few_samples_thr: int - minimum number of samples per class - warn user when samples per class is less than threshold - imbalance_ratio_thr: int - ratio of majority attribute to minority attribute - warn user when annotations are unevenly distributed - far_from_mean_thr: float - constant used to define mean +/- m * stddev - warn user when there are too big or small values - dominance_ratio_thr: float - ratio of Top-k bin to total - warn user when dominance ratio is over threshold - topk_bins: float - ratio of selected bins with most item number to total bins - warn user when values are not evenly distributed - """ - self.task_type = parse_str_enum_value(task_type, TaskType, - default=TaskType.classification) - - if self.task_type == TaskType.classification: - self.ann_types = {AnnotationType.label} - self.str_ann_type = "label" - elif self.task_type == TaskType.detection: - self.ann_types = {AnnotationType.bbox} - self.str_ann_type = "bounding box" - elif self.task_type == TaskType.segmentation: - self.ann_types = {AnnotationType.mask, AnnotationType.polygon} - self.str_ann_type = "mask or polygon" - - self.few_samples_thr = few_samples_thr - self.imbalance_ratio_thr = imbalance_ratio_thr - self.far_from_mean_thr = far_from_mean_thr - self.dominance_thr = dominance_ratio_thr - self.topk_bins_ratio = topk_bins - - def _compute_common_statistics(self, dataset): - defined_attr_template = { - 'items_missing_attribute': [], - 'distribution': {} - } - undefined_attr_template = { - 'items_with_undefined_attr': [], - 'distribution': {} - } - undefined_label_template = { - 'count': 0, - 'items_with_undefined_label': [], - } - - stats = { - 'label_distribution': { - 'defined_labels': {}, - 'undefined_labels': {}, - }, - 'attribute_distribution': { - 'defined_attributes': {}, - 'undefined_attributes': {} - }, - } - stats['total_ann_count'] = 0 - stats['items_missing_annotation'] = [] - - label_dist = stats['label_distribution'] - attr_dist = stats['attribute_distribution'] - defined_label_dist = 
label_dist['defined_labels'] - defined_attr_dist = attr_dist['defined_attributes'] - undefined_label_dist = label_dist['undefined_labels'] - undefined_attr_dist = attr_dist['undefined_attributes'] - - label_categories = dataset.categories().get(AnnotationType.label, - LabelCategories()) - base_valid_attrs = label_categories.attributes - - for category in label_categories: - defined_label_dist[category.name] = 0 - - filtered_anns = [] - for item in dataset: - item_key = (item.id, item.subset) - annotations = [] - for ann in item.annotations: - if ann.type in self.ann_types: - annotations.append(ann) - ann_count = len(annotations) - filtered_anns.append((item_key, annotations)) - - if ann_count == 0: - stats['items_missing_annotation'].append(item_key) - stats['total_ann_count'] += ann_count - - for ann in annotations: - if not 0 <= ann.label < len(label_categories): - label_name = ann.label - - label_stats = undefined_label_dist.setdefault( - ann.label, deepcopy(undefined_label_template)) - label_stats['items_with_undefined_label'].append( - item_key) - - label_stats['count'] += 1 - valid_attrs = set() - missing_attrs = set() - else: - label_name = label_categories[ann.label].name - defined_label_dist[label_name] += 1 - - defined_attr_stats = defined_attr_dist.setdefault( - label_name, {}) - - valid_attrs = base_valid_attrs.union( - label_categories[ann.label].attributes) - ann_attrs = getattr(ann, 'attributes', {}).keys() - missing_attrs = valid_attrs.difference(ann_attrs) - - for attr in valid_attrs: - defined_attr_stats.setdefault( - attr, deepcopy(defined_attr_template)) - - for attr in missing_attrs: - attr_dets = defined_attr_stats[attr] - attr_dets['items_missing_attribute'].append( - item_key) - - for attr, value in ann.attributes.items(): - if attr not in valid_attrs: - undefined_attr_stats = \ - undefined_attr_dist.setdefault( - label_name, {}) - attr_dets = undefined_attr_stats.setdefault( - attr, deepcopy(undefined_attr_template)) - attr_dets['items_with_undefined_attr'].append( - item_key) - else: - attr_dets = defined_attr_stats[attr] - - attr_dets['distribution'].setdefault(str(value), 0) - attr_dets['distribution'][str(value)] += 1 - - return stats, filtered_anns - - @staticmethod - def _update_prop_distributions(curr_prop_stats, target_stats): - for prop, val in curr_prop_stats.items(): - prop_stats = target_stats[prop] - prop_dist = prop_stats['distribution'] - prop_stats['distribution'] = np.append(prop_dist, val) - - @staticmethod - def _compute_prop_stats_from_dist(dist_by_label, dist_by_attr): - for label_name, stats in dist_by_label.items(): - prop_stats_list = list(stats.values()) - attr_label = dist_by_attr.get(label_name, {}) - for vals in attr_label.values(): - for val_stats in vals.values(): - prop_stats_list += list(val_stats.values()) - - for prop_stats in prop_stats_list: - prop_dist = prop_stats.pop('distribution', []) - if len(prop_dist) > 0: - prop_stats['mean'] = np.mean(prop_dist) - prop_stats['stdev'] = np.std(prop_dist) - prop_stats['min'] = np.min(prop_dist) - prop_stats['max'] = np.max(prop_dist) - prop_stats['median'] = np.median(prop_dist) - - counts, bins = np.histogram(prop_dist) - prop_stats['histogram']['bins'] = bins.tolist() - prop_stats['histogram']['counts'] = counts.tolist() - - def _compute_far_from_mean(self, prop_stats, val, item_key, ann): - def _far_from_mean(val, mean, stdev): - thr = self.far_from_mean_thr - return val > mean + (thr * stdev) or val < mean - (thr * stdev) - - mean = prop_stats['mean'] - stdev = prop_stats['stdev'] - - 
if _far_from_mean(val, mean, stdev): - items_far_from_mean = prop_stats['items_far_from_mean'] - far_from_mean = items_far_from_mean.setdefault( - item_key, {}) - far_from_mean[ann.id] = val - - def compute_statistics(self, dataset): - """ - Computes statistics of the dataset based on task type. - - Parameters - ---------- - dataset : IDataset object - - Returns - ------- - stats (dict): A dict object containing statistics of the dataset. - """ - return NotImplementedError - - def _check_missing_label_categories(self, stats): - validation_reports = [] - - if len(stats['label_distribution']['defined_labels']) == 0: - validation_reports += self._generate_validation_report( - MissingLabelCategories, Severity.error) - - return validation_reports - - def _check_missing_annotation(self, stats): - validation_reports = [] - - items_missing = stats['items_missing_annotation'] - for item_id, item_subset in items_missing: - validation_reports += self._generate_validation_report( - MissingAnnotation, Severity.warning, item_id, item_subset, - self.str_ann_type) - - return validation_reports - - def _check_missing_attribute(self, label_name, attr_name, attr_dets): - validation_reports = [] - - items_missing_attr = attr_dets['items_missing_attribute'] - for item_id, item_subset in items_missing_attr: - details = (item_subset, label_name, attr_name) - validation_reports += self._generate_validation_report( - MissingAttribute, Severity.warning, item_id, *details) - - return validation_reports - - def _check_undefined_label(self, label_name, label_stats): - validation_reports = [] - - items_with_undefined_label = label_stats['items_with_undefined_label'] - for item_id, item_subset in items_with_undefined_label: - details = (item_subset, label_name) - validation_reports += self._generate_validation_report( - UndefinedLabel, Severity.error, item_id, *details) - - return validation_reports - - def _check_undefined_attribute(self, label_name, attr_name, attr_dets): - validation_reports = [] - - items_with_undefined_attr = attr_dets['items_with_undefined_attr'] - for item_id, item_subset in items_with_undefined_attr: - details = (item_subset, label_name, attr_name) - validation_reports += self._generate_validation_report( - UndefinedAttribute, Severity.error, item_id, *details) - - return validation_reports - - def _check_label_defined_but_not_found(self, stats): - validation_reports = [] - count_by_defined_labels = stats['label_distribution']['defined_labels'] - labels_not_found = [label_name - for label_name, count in count_by_defined_labels.items() - if count == 0] - - for label_name in labels_not_found: - validation_reports += self._generate_validation_report( - LabelDefinedButNotFound, Severity.warning, label_name) - - return validation_reports - - def _check_attribute_defined_but_not_found(self, label_name, attr_stats): - validation_reports = [] - attrs_not_found = [attr_name - for attr_name, attr_dets in attr_stats.items() - if len(attr_dets['distribution']) == 0] - - for attr_name in attrs_not_found: - details = (label_name, attr_name) - validation_reports += self._generate_validation_report( - AttributeDefinedButNotFound, Severity.warning, *details) - - return validation_reports - - def _check_only_one_label(self, stats): - validation_reports = [] - count_by_defined_labels = stats['label_distribution']['defined_labels'] - labels_found = [label_name - for label_name, count in count_by_defined_labels.items() - if count > 0] - - if len(labels_found) == 1: - validation_reports += 
self._generate_validation_report( - OnlyOneLabel, Severity.warning, labels_found[0]) - - return validation_reports - - def _check_only_one_attribute_value(self, label_name, attr_name, attr_dets): - validation_reports = [] - values = list(attr_dets['distribution'].keys()) - - if len(values) == 1: - details = (label_name, attr_name, values[0]) - validation_reports += self._generate_validation_report( - OnlyOneAttributeValue, Severity.warning, *details) - - return validation_reports - - def _check_few_samples_in_label(self, stats): - validation_reports = [] - thr = self.few_samples_thr - - defined_label_dist = stats['label_distribution']['defined_labels'] - labels_with_few_samples = [(label_name, count) - for label_name, count in defined_label_dist.items() - if 0 < count <= thr] - - for label_name, count in labels_with_few_samples: - validation_reports += self._generate_validation_report( - FewSamplesInLabel, Severity.warning, label_name, count) - - return validation_reports - - def _check_few_samples_in_attribute(self, label_name, - attr_name, attr_dets): - validation_reports = [] - thr = self.few_samples_thr - - attr_values_with_few_samples = [(attr_value, count) - for attr_value, count in attr_dets['distribution'].items() - if count <= thr] - - for attr_value, count in attr_values_with_few_samples: - details = (label_name, attr_name, attr_value, count) - validation_reports += self._generate_validation_report( - FewSamplesInAttribute, Severity.warning, *details) - - return validation_reports - - def _check_imbalanced_labels(self, stats): - validation_reports = [] - thr = self.imbalance_ratio_thr - - defined_label_dist = stats['label_distribution']['defined_labels'] - count_by_defined_labels = [count - for label, count in defined_label_dist.items()] - - if len(count_by_defined_labels) == 0: - return validation_reports - - count_max = np.max(count_by_defined_labels) - count_min = np.min(count_by_defined_labels) - balance = count_max / count_min if count_min > 0 else float('inf') - if balance >= thr: - validation_reports += self._generate_validation_report( - ImbalancedLabels, Severity.warning) - - return validation_reports - - def _check_imbalanced_attribute(self, label_name, attr_name, attr_dets): - validation_reports = [] - thr = self.imbalance_ratio_thr - - count_by_defined_attr = list(attr_dets['distribution'].values()) - if len(count_by_defined_attr) == 0: - return validation_reports - - count_max = np.max(count_by_defined_attr) - count_min = np.min(count_by_defined_attr) - balance = count_max / count_min if count_min > 0 else float('inf') - if balance >= thr: - validation_reports += self._generate_validation_report( - ImbalancedAttribute, Severity.warning, label_name, attr_name) - - return validation_reports - - def _check_imbalanced_dist_in_label(self, label_name, label_stats): - validation_reports = [] - thr = self.dominance_thr - topk_ratio = self.topk_bins_ratio - - for prop, prop_stats in label_stats.items(): - value_counts = prop_stats['histogram']['counts'] - n_bucket = len(value_counts) - if n_bucket < 2: - continue - topk = max(1, int(np.around(n_bucket * topk_ratio))) - - if topk > 0: - topk_values = np.sort(value_counts)[-topk:] - ratio = np.sum(topk_values) / np.sum(value_counts) - if ratio >= thr: - details = (label_name, f"{self.str_ann_type} {prop}") - validation_reports += self._generate_validation_report( - ImbalancedDistInLabel, Severity.warning, *details) - - return validation_reports - - def _check_imbalanced_dist_in_attr(self, label_name, attr_name, attr_stats): - 
validation_reports = [] - thr = self.dominance_thr - topk_ratio = self.topk_bins_ratio - - for attr_value, value_stats in attr_stats.items(): - for prop, prop_stats in value_stats.items(): - value_counts = prop_stats['histogram']['counts'] - n_bucket = len(value_counts) - if n_bucket < 2: - continue - topk = max(1, int(np.around(n_bucket * topk_ratio))) - - if topk > 0: - topk_values = np.sort(value_counts)[-topk:] - ratio = np.sum(topk_values) / np.sum(value_counts) - if ratio >= thr: - details = (label_name, attr_name, attr_value, - f"{self.str_ann_type} {prop}") - validation_reports += self._generate_validation_report( - ImbalancedDistInAttribute, - Severity.warning, - *details - ) - - return validation_reports - - def _check_invalid_value(self, stats): - validation_reports = [] - - items_w_invalid_val = stats['items_with_invalid_value'] - for item_dets, anns_w_invalid_val in items_w_invalid_val.items(): - item_id, item_subset = item_dets - for ann_id, props in anns_w_invalid_val.items(): - for prop in props: - details = (item_subset, ann_id, - f"{self.str_ann_type} {prop}") - validation_reports += self._generate_validation_report( - InvalidValue, Severity.error, item_id, *details) - - return validation_reports - - def _check_far_from_label_mean(self, label_name, label_stats): - validation_reports = [] - - for prop, prop_stats in label_stats.items(): - items_far_from_mean = prop_stats['items_far_from_mean'] - if prop_stats['mean'] is not None: - mean = round(prop_stats['mean'], 2) - - for item_dets, anns_far in items_far_from_mean.items(): - item_id, item_subset = item_dets - for ann_id, val in anns_far.items(): - val = round(val, 2) - details = (item_subset, label_name, ann_id, - f"{self.str_ann_type} {prop}", mean, val) - validation_reports += self._generate_validation_report( - FarFromLabelMean, Severity.warning, item_id, *details) - - return validation_reports - - def _check_far_from_attr_mean(self, label_name, attr_name, attr_stats): - validation_reports = [] - - for attr_value, value_stats in attr_stats.items(): - for prop, prop_stats in value_stats.items(): - items_far_from_mean = prop_stats['items_far_from_mean'] - if prop_stats['mean'] is not None: - mean = round(prop_stats['mean'], 2) - - for item_dets, anns_far in items_far_from_mean.items(): - item_id, item_subset = item_dets - for ann_id, val in anns_far.items(): - val = round(val, 2) - details = (item_subset, label_name, ann_id, attr_name, - attr_value, f"{self.str_ann_type} {prop}", - mean, val) - validation_reports += self._generate_validation_report( - FarFromAttrMean, - Severity.warning, - item_id, - *details - ) - - return validation_reports - - def generate_reports(self, stats): - raise NotImplementedError('Should be implemented in a subclass.') - - def _generate_validation_report(self, error, *args, **kwargs): - return [error(*args, **kwargs)] - - -class ClassificationValidator(_Validator): - """ - A validator class for classification tasks. 
- """ - - def __init__(self, few_samples_thr, imbalance_ratio_thr, - far_from_mean_thr, dominance_ratio_thr, topk_bins): - super().__init__(task_type=TaskType.classification, - few_samples_thr=few_samples_thr, - imbalance_ratio_thr=imbalance_ratio_thr, - far_from_mean_thr=far_from_mean_thr, - dominance_ratio_thr=dominance_ratio_thr, topk_bins=topk_bins) - - def _check_multi_label_annotations(self, stats): - validation_reports = [] - - items_with_multiple_labels = stats['items_with_multiple_labels'] - for item_id, item_subset in items_with_multiple_labels: - validation_reports += self._generate_validation_report( - MultiLabelAnnotations, Severity.error, item_id, item_subset) - - return validation_reports - - def compute_statistics(self, dataset): - """ - Computes statistics of the dataset for the classification task. - - Parameters - ---------- - dataset : IDataset object - - Returns - ------- - stats (dict): A dict object containing statistics of the dataset. - """ - - stats, filtered_anns = self._compute_common_statistics(dataset) - - stats['items_with_multiple_labels'] = [] - - for item_key, anns in filtered_anns: - ann_count = len(anns) - if ann_count > 1: - stats['items_with_multiple_labels'].append(item_key) - - return stats - - def generate_reports(self, stats): - """ - Validates the dataset for classification tasks based on its statistics. - - Parameters - ---------- - dataset : IDataset object - stats: Dict object - - Returns - ------- - reports (list): List of validation reports (DatasetValidationError). - """ - - reports = [] - - reports += self._check_missing_label_categories(stats) - reports += self._check_missing_annotation(stats) - reports += self._check_multi_label_annotations(stats) - reports += self._check_label_defined_but_not_found(stats) - reports += self._check_only_one_label(stats) - reports += self._check_few_samples_in_label(stats) - reports += self._check_imbalanced_labels(stats) - - label_dist = stats['label_distribution'] - attr_dist = stats['attribute_distribution'] - defined_attr_dist = attr_dist['defined_attributes'] - undefined_label_dist = label_dist['undefined_labels'] - undefined_attr_dist = attr_dist['undefined_attributes'] - - defined_labels = defined_attr_dist.keys() - for label_name in defined_labels: - attr_stats = defined_attr_dist[label_name] - - reports += self._check_attribute_defined_but_not_found( - label_name, attr_stats) - - for attr_name, attr_dets in attr_stats.items(): - reports += self._check_few_samples_in_attribute( - label_name, attr_name, attr_dets) - reports += self._check_imbalanced_attribute( - label_name, attr_name, attr_dets) - reports += self._check_only_one_attribute_value( - label_name, attr_name, attr_dets) - reports += self._check_missing_attribute( - label_name, attr_name, attr_dets) - - for label_name, label_stats in undefined_label_dist.items(): - reports += self._check_undefined_label(label_name, label_stats) - - for label_name, attr_stats in undefined_attr_dist.items(): - for attr_name, attr_dets in attr_stats.items(): - reports += self._check_undefined_attribute( - label_name, attr_name, attr_dets) - - return reports - - -class DetectionValidator(_Validator): - """ - A validator class for detection tasks. 
- """ - def __init__(self, few_samples_thr, imbalance_ratio_thr, - far_from_mean_thr, dominance_ratio_thr, topk_bins): - super().__init__(task_type=TaskType.detection, - few_samples_thr=few_samples_thr, - imbalance_ratio_thr=imbalance_ratio_thr, - far_from_mean_thr=far_from_mean_thr, - dominance_ratio_thr=dominance_ratio_thr, topk_bins=topk_bins) - - def _check_negative_length(self, stats): - validation_reports = [] - - items_w_neg_len = stats['items_with_negative_length'] - for item_dets, anns_w_neg_len in items_w_neg_len.items(): - item_id, item_subset = item_dets - for ann_id, props in anns_w_neg_len.items(): - for prop, val in props.items(): - val = round(val, 2) - details = (item_subset, ann_id, - f"{self.str_ann_type} {prop}", val) - validation_reports += self._generate_validation_report( - NegativeLength, Severity.error, item_id, *details) - - return validation_reports - - def compute_statistics(self, dataset): - """ - Computes statistics of the dataset for the detection task. - - Parameters - ---------- - dataset : IDataset object - - Returns - ------- - stats (dict): A dict object containing statistics of the dataset. - """ - - stats, filtered_anns = self._compute_common_statistics(dataset) - - # detection-specific - bbox_template = { - 'width': deepcopy(self.numerical_stat_template), - 'height': deepcopy(self.numerical_stat_template), - 'area(wxh)': deepcopy(self.numerical_stat_template), - 'ratio(w/h)': deepcopy(self.numerical_stat_template), - 'short': deepcopy(self.numerical_stat_template), - 'long': deepcopy(self.numerical_stat_template) - } - - stats['items_with_negative_length'] = {} - stats['items_with_invalid_value'] = {} - stats['bbox_distribution_in_label'] = {} - stats['bbox_distribution_in_attribute'] = {} - stats['bbox_distribution_in_dataset_item'] = {} - - dist_by_label = stats['bbox_distribution_in_label'] - dist_by_attr = stats['bbox_distribution_in_attribute'] - bbox_dist_in_item = stats['bbox_distribution_in_dataset_item'] - items_w_neg_len = stats['items_with_negative_length'] - items_w_invalid_val = stats['items_with_invalid_value'] - - def _generate_ann_bbox_info(_x, _y, _w, _h, area, - ratio, _short, _long): - return { - 'x': _x, - 'y': _y, - 'width': _w, - 'height': _h, - 'area(wxh)': area, - 'ratio(w/h)': ratio, - 'short': _short, - 'long': _long, - } - - def _update_bbox_stats_by_label(item_key, ann, bbox_label_stats): - bbox_has_error = False - - _x, _y, _w, _h = ann.get_bbox() - area = ann.get_area() - - if _h != 0 and _h != float('inf'): - ratio = _w / _h - else: - ratio = float('nan') - - _short = _w if _w < _h else _h - _long = _w if _w > _h else _h - - ann_bbox_info = _generate_ann_bbox_info( - _x, _y, _w, _h, area, ratio, _short, _long) - - for prop, val in ann_bbox_info.items(): - if val == float('inf') or np.isnan(val): - bbox_has_error = True - anns_w_invalid_val = items_w_invalid_val.setdefault( - item_key, {}) - invalid_props = anns_w_invalid_val.setdefault( - ann.id, []) - invalid_props.append(prop) - - for prop in ['width', 'height']: - val = ann_bbox_info[prop] - if val < 1: - bbox_has_error = True - anns_w_neg_len = items_w_neg_len.setdefault( - item_key, {}) - neg_props = anns_w_neg_len.setdefault(ann.id, {}) - neg_props[prop] = val - - if not bbox_has_error: - ann_bbox_info.pop('x') - ann_bbox_info.pop('y') - self._update_prop_distributions(ann_bbox_info, bbox_label_stats) - - return ann_bbox_info, bbox_has_error - - label_categories = dataset.categories().get(AnnotationType.label, - LabelCategories()) - base_valid_attrs = 
label_categories.attributes - - for item_key, annotations in filtered_anns: - ann_count = len(annotations) - - bbox_dist_in_item[item_key] = ann_count - - for ann in annotations: - if not 0 <= ann.label < len(label_categories): - label_name = ann.label - valid_attrs = set() - else: - label_name = label_categories[ann.label].name - valid_attrs = base_valid_attrs.union( - label_categories[ann.label].attributes) - - bbox_label_stats = dist_by_label.setdefault( - label_name, deepcopy(bbox_template)) - ann_bbox_info, bbox_has_error = \ - _update_bbox_stats_by_label( - item_key, ann, bbox_label_stats) - - for attr, value in ann.attributes.items(): - if attr in valid_attrs: - bbox_attr_label = dist_by_attr.setdefault( - label_name, {}) - bbox_attr_stats = bbox_attr_label.setdefault( - attr, {}) - bbox_val_stats = bbox_attr_stats.setdefault( - str(value), deepcopy(bbox_template)) - - if not bbox_has_error: - self._update_prop_distributions( - ann_bbox_info, bbox_val_stats) - - # Compute prop stats from distribution - self._compute_prop_stats_from_dist(dist_by_label, dist_by_attr) - - def _is_valid_ann(item_key, ann): - has_defined_label = 0 <= ann.label < len(label_categories) - if not has_defined_label: - return False - - bbox_has_neg_len = ann.id in items_w_neg_len.get( - item_key, {}) - bbox_has_invalid_val = ann.id in items_w_invalid_val.get( - item_key, {}) - return not (bbox_has_neg_len or bbox_has_invalid_val) - - def _update_props_far_from_mean(item_key, ann): - valid_attrs = base_valid_attrs.union( - label_categories[ann.label].attributes) - label_name = label_categories[ann.label].name - bbox_label_stats = dist_by_label[label_name] - - _x, _y, _w, _h = ann.get_bbox() - area = ann.get_area() - ratio = _w / _h - _short = _w if _w < _h else _h - _long = _w if _w > _h else _h - - ann_bbox_info = _generate_ann_bbox_info( - _x, _y, _w, _h, area, ratio, _short, _long) - ann_bbox_info.pop('x') - ann_bbox_info.pop('y') - - for prop, val in ann_bbox_info.items(): - prop_stats = bbox_label_stats[prop] - self._compute_far_from_mean(prop_stats, val, item_key, ann) - - for attr, value in ann.attributes.items(): - if attr in valid_attrs: - bbox_attr_stats = dist_by_attr[label_name][attr] - bbox_val_stats = bbox_attr_stats[str(value)] - - for prop, val in ann_bbox_info.items(): - prop_stats = bbox_val_stats[prop] - self._compute_far_from_mean(prop_stats, val, - item_key, ann) - - for item_key, annotations in filtered_anns: - for ann in annotations: - if _is_valid_ann(item_key, ann): - _update_props_far_from_mean(item_key, ann) - - return stats - - def generate_reports(self, stats): - """ - Validates the dataset for detection tasks based on its statistics. - - Parameters - ---------- - dataset : IDataset object - stats : Dict object - - Returns - ------- - reports (list): List of validation reports (DatasetValidationError). 
- """ - - reports = [] - - reports += self._check_missing_label_categories(stats) - reports += self._check_missing_annotation(stats) - reports += self._check_label_defined_but_not_found(stats) - reports += self._check_only_one_label(stats) - reports += self._check_few_samples_in_label(stats) - reports += self._check_imbalanced_labels(stats) - reports += self._check_negative_length(stats) - reports += self._check_invalid_value(stats) - - label_dist = stats['label_distribution'] - attr_dist = stats['attribute_distribution'] - defined_attr_dist = attr_dist['defined_attributes'] - undefined_label_dist = label_dist['undefined_labels'] - undefined_attr_dist = attr_dist['undefined_attributes'] - - dist_by_label = stats['bbox_distribution_in_label'] - dist_by_attr = stats['bbox_distribution_in_attribute'] - - defined_labels = defined_attr_dist.keys() - for label_name in defined_labels: - attr_stats = defined_attr_dist[label_name] - - reports += self._check_attribute_defined_but_not_found( - label_name, attr_stats) - - for attr_name, attr_dets in attr_stats.items(): - reports += self._check_few_samples_in_attribute( - label_name, attr_name, attr_dets) - reports += self._check_imbalanced_attribute( - label_name, attr_name, attr_dets) - reports += self._check_only_one_attribute_value( - label_name, attr_name, attr_dets) - reports += self._check_missing_attribute( - label_name, attr_name, attr_dets) - - bbox_label_stats = dist_by_label[label_name] - bbox_attr_label = dist_by_attr.get(label_name, {}) - - reports += self._check_far_from_label_mean( - label_name, bbox_label_stats) - reports += self._check_imbalanced_dist_in_label( - label_name, bbox_label_stats) - - for attr_name, bbox_attr_stats in bbox_attr_label.items(): - reports += self._check_far_from_attr_mean( - label_name, attr_name, bbox_attr_stats) - reports += self._check_imbalanced_dist_in_attr( - label_name, attr_name, bbox_attr_stats) - - for label_name, label_stats in undefined_label_dist.items(): - reports += self._check_undefined_label(label_name, label_stats) - - for label_name, attr_stats in undefined_attr_dist.items(): - for attr_name, attr_dets in attr_stats.items(): - reports += self._check_undefined_attribute( - label_name, attr_name, attr_dets) - - return reports - - -class SegmentationValidator(_Validator): - """ - A validator class for (instance) segmentation tasks. - """ - - def __init__(self, few_samples_thr, imbalance_ratio_thr, - far_from_mean_thr, dominance_ratio_thr, topk_bins): - super().__init__(task_type=TaskType.segmentation, - few_samples_thr=few_samples_thr, - imbalance_ratio_thr=imbalance_ratio_thr, - far_from_mean_thr=far_from_mean_thr, - dominance_ratio_thr=dominance_ratio_thr, topk_bins=topk_bins) - - def compute_statistics(self, dataset): - """ - Computes statistics of the dataset for the segmentation task. - - Parameters - ---------- - dataset : IDataset object - - Returns - ------- - stats (dict): A dict object containing statistics of the dataset. 
- """ - - stats, filtered_anns = self._compute_common_statistics(dataset) - - # segmentation-specific - mask_template = { - 'area': deepcopy(self.numerical_stat_template), - 'width': deepcopy(self.numerical_stat_template), - 'height': deepcopy(self.numerical_stat_template) - } - - stats['items_with_invalid_value'] = {} - stats['mask_distribution_in_label'] = {} - stats['mask_distribution_in_attribute'] = {} - stats['mask_distribution_in_dataset_item'] = {} - - dist_by_label = stats['mask_distribution_in_label'] - dist_by_attr = stats['mask_distribution_in_attribute'] - mask_dist_in_item = stats['mask_distribution_in_dataset_item'] - items_w_invalid_val = stats['items_with_invalid_value'] - - def _generate_ann_mask_info(area, _w, _h): - return { - 'area': area, - 'width': _w, - 'height': _h, - } - - def _update_mask_stats_by_label(item_key, ann, mask_label_stats): - mask_has_error = False - - _x, _y, _w, _h = ann.get_bbox() - - # Detete the following block when #226 is resolved - # https://github.com/openvinotoolkit/datumaro/issues/226 - if ann.type == AnnotationType.mask: - _w += 1 - _h += 1 - - area = ann.get_area() - - ann_mask_info = _generate_ann_mask_info(area, _w, _h) - - for prop, val in ann_mask_info.items(): - if val == float('inf') or np.isnan(val): - mask_has_error = True - anns_w_invalid_val = items_w_invalid_val.setdefault( - item_key, {}) - invalid_props = anns_w_invalid_val.setdefault( - ann.id, []) - invalid_props.append(prop) - - if not mask_has_error: - self._update_prop_distributions(ann_mask_info, mask_label_stats) - - return ann_mask_info, mask_has_error - - label_categories = dataset.categories().get(AnnotationType.label, - LabelCategories()) - base_valid_attrs = label_categories.attributes - - for item_key, annotations in filtered_anns: - ann_count = len(annotations) - mask_dist_in_item[item_key] = ann_count - - for ann in annotations: - if not 0 <= ann.label < len(label_categories): - label_name = ann.label - valid_attrs = set() - else: - label_name = label_categories[ann.label].name - valid_attrs = base_valid_attrs.union( - label_categories[ann.label].attributes) - - mask_label_stats = dist_by_label.setdefault( - label_name, deepcopy(mask_template)) - ann_mask_info, mask_has_error = \ - _update_mask_stats_by_label( - item_key, ann, mask_label_stats) - - for attr, value in ann.attributes.items(): - if attr in valid_attrs: - mask_attr_label = dist_by_attr.setdefault( - label_name, {}) - mask_attr_stats = mask_attr_label.setdefault( - attr, {}) - mask_val_stats = mask_attr_stats.setdefault( - str(value), deepcopy(mask_template)) - - if not mask_has_error: - self._update_prop_distributions( - ann_mask_info, mask_val_stats) - - # compute prop stats from dist. 
- self._compute_prop_stats_from_dist(dist_by_label, dist_by_attr) - - def _is_valid_ann(item_key, ann): - has_defined_label = 0 <= ann.label < len(label_categories) - if not has_defined_label: - return False - - mask_has_invalid_val = ann.id in items_w_invalid_val.get( - item_key, {}) - return not mask_has_invalid_val - - def _update_props_far_from_mean(item_key, ann): - valid_attrs = base_valid_attrs.union( - label_categories[ann.label].attributes) - label_name = label_categories[ann.label].name - mask_label_stats = dist_by_label[label_name] - - _x, _y, _w, _h = ann.get_bbox() - - # Detete the following block when #226 is resolved - # https://github.com/openvinotoolkit/datumaro/issues/226 - if ann.type == AnnotationType.mask: - _w += 1 - _h += 1 - area = ann.get_area() - - ann_mask_info = _generate_ann_mask_info(area, _w, _h) - - for prop, val in ann_mask_info.items(): - prop_stats = mask_label_stats[prop] - self._compute_far_from_mean(prop_stats, val, item_key, ann) - - for attr, value in ann.attributes.items(): - if attr in valid_attrs: - mask_attr_stats = dist_by_attr[label_name][attr] - mask_val_stats = mask_attr_stats[str(value)] - - for prop, val in ann_mask_info.items(): - prop_stats = mask_val_stats[prop] - self._compute_far_from_mean(prop_stats, val, - item_key, ann) - - for item_key, annotations in filtered_anns: - for ann in annotations: - if _is_valid_ann(item_key, ann): - _update_props_far_from_mean(item_key, ann) - - return stats - - def generate_reports(self, stats): - """ - Validates the dataset for segmentation tasks based on its statistics. - - Parameters - ---------- - dataset : IDataset object - stats : Dict object - - Returns - ------- - reports (list): List of validation reports (DatasetValidationError). - """ - - reports = [] - - reports += self._check_missing_label_categories(stats) - reports += self._check_missing_annotation(stats) - reports += self._check_label_defined_but_not_found(stats) - reports += self._check_only_one_label(stats) - reports += self._check_few_samples_in_label(stats) - reports += self._check_imbalanced_labels(stats) - reports += self._check_invalid_value(stats) - - label_dist = stats['label_distribution'] - attr_dist = stats['attribute_distribution'] - defined_attr_dist = attr_dist['defined_attributes'] - undefined_label_dist = label_dist['undefined_labels'] - undefined_attr_dist = attr_dist['undefined_attributes'] - - dist_by_label = stats['mask_distribution_in_label'] - dist_by_attr = stats['mask_distribution_in_attribute'] - - defined_labels = defined_attr_dist.keys() - for label_name in defined_labels: - attr_stats = defined_attr_dist[label_name] - - reports += self._check_attribute_defined_but_not_found( - label_name, attr_stats) - - for attr_name, attr_dets in attr_stats.items(): - reports += self._check_few_samples_in_attribute( - label_name, attr_name, attr_dets) - reports += self._check_imbalanced_attribute( - label_name, attr_name, attr_dets) - reports += self._check_only_one_attribute_value( - label_name, attr_name, attr_dets) - reports += self._check_missing_attribute( - label_name, attr_name, attr_dets) - - mask_label_stats = dist_by_label[label_name] - mask_attr_label = dist_by_attr.get(label_name, {}) - - reports += self._check_far_from_label_mean( - label_name, mask_label_stats) - reports += self._check_imbalanced_dist_in_label( - label_name, mask_label_stats) - - for attr_name, mask_attr_stats in mask_attr_label.items(): - reports += self._check_far_from_attr_mean( - label_name, attr_name, mask_attr_stats) - reports 
+= self._check_imbalanced_dist_in_attr( - label_name, attr_name, mask_attr_stats) - - for label_name, label_stats in undefined_label_dist.items(): - reports += self._check_undefined_label(label_name, label_stats) - - for label_name, attr_stats in undefined_attr_dist.items(): - for attr_name, attr_dets in attr_stats.items(): - reports += self._check_undefined_attribute( - label_name, attr_name, attr_dets) - - return reports - - -def validate_annotations(dataset: IDataset, task_type: Union[str, TaskType], **extra_args): - """ - Returns the validation results of a dataset based on task type. - - Args: - dataset (IDataset): Dataset to be validated - task_type (str or TaskType): Type of the task - (classification, detection, segmentation) - - Raises: - ValueError - - Returns: - validation_results (dict): - Dict with validation statistics, reports and summary. - - """ - - few_samples_thr = extra_args['few_samples_thr'] - imbalance_ratio_thr = extra_args['imbalance_ratio_thr'] - far_from_mean_thr = extra_args['far_from_mean_thr'] - dominance_ratio_thr = extra_args['dominance_ratio_thr'] - topk_bins = extra_args['topk_bins'] - - validation_results = {} - - task_type = parse_str_enum_value(task_type, TaskType) - if task_type == TaskType.classification: - validator = ClassificationValidator(few_samples_thr=few_samples_thr, - imbalance_ratio_thr=imbalance_ratio_thr, - far_from_mean_thr=far_from_mean_thr, - dominance_ratio_thr=dominance_ratio_thr, - topk_bins=topk_bins) - elif task_type == TaskType.detection: - validator = DetectionValidator(few_samples_thr=few_samples_thr, - imbalance_ratio_thr=imbalance_ratio_thr, - far_from_mean_thr=far_from_mean_thr, - dominance_ratio_thr=dominance_ratio_thr, - topk_bins=topk_bins) - elif task_type == TaskType.segmentation: - validator = SegmentationValidator(few_samples_thr=few_samples_thr, - imbalance_ratio_thr=imbalance_ratio_thr, - far_from_mean_thr=far_from_mean_thr, - dominance_ratio_thr=dominance_ratio_thr, - topk_bins=topk_bins) - - if not isinstance(dataset, IDataset): - raise TypeError("Invalid dataset type '%s'" % type(dataset)) - - # generate statistics - stats = validator.compute_statistics(dataset) - validation_results['statistics'] = stats +class IValidator: + def validate(self, dataset: IDataset) -> Dict: + raise NotImplementedError() - # generate validation reports and summary - reports = validator.generate_reports(stats) - reports = list(map(lambda r: r.to_dict(), reports)) - summary = { - 'errors': sum(map(lambda r: r['severity'] == 'error', reports)), - 'warnings': sum(map(lambda r: r['severity'] == 'warning', reports)) - } +class Validator(IValidator): + def validate(self, dataset: IDataset) -> Dict: + raise NotImplementedError() - validation_results['validation_reports'] = reports - validation_results['summary'] = summary + def compute_statistics(self, dataset: IDataset) -> Dict: + raise NotImplementedError() - return validation_results + def generate_reports(self, stats: Dict) -> List[Dict]: + raise NotImplementedError() diff --git a/datumaro/plugins/validators.py b/datumaro/plugins/validators.py new file mode 100644 index 000000000000..ce171c8208b2 --- /dev/null +++ b/datumaro/plugins/validators.py @@ -0,0 +1,1275 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from copy import deepcopy +from typing import Dict, List + +import json +import logging as log + +import numpy as np + +from datumaro.components.validator import (Severity, TaskType, Validator) +from datumaro.components.cli_plugin import CliPlugin 
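The plugins module below re-implements the task validators on top of this new interface. As an illustration of what the plugin type expects, a minimal custom validator could look like the following sketch; the class name and the check it performs are hypothetical, only the Validator base class, CliPlugin mixin, and method signatures come from this patch.

# Sketch of a user-defined Validator plugin against the new interface.
from typing import Dict, List

from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset import IDataset
from datumaro.components.validator import Validator


class EmptyItemValidator(Validator, CliPlugin):
    """Flags dataset items that carry no annotations at all (illustrative)."""

    def compute_statistics(self, dataset: IDataset) -> Dict:
        empty = [(item.id, item.subset)
                 for item in dataset if not item.annotations]
        return {'items_without_annotations': empty}

    def generate_reports(self, stats: Dict) -> List[Dict]:
        return [{'anomaly_type': 'MissingAnnotation', 'severity': 'warning',
                 'item_id': item_id, 'subset': subset}
                for item_id, subset in stats['items_without_annotations']]

    def validate(self, dataset: IDataset) -> Dict:
        stats = self.compute_statistics(dataset)
        reports = self.generate_reports(stats)
        return {'statistics': stats,
                'validation_reports': reports,
                'summary': {'errors': 0, 'warnings': len(reports)}}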
+from datumaro.components.dataset import IDataset +from datumaro.components.errors import (MissingLabelCategories, + MissingAnnotation, MultiLabelAnnotations, MissingAttribute, + UndefinedLabel, UndefinedAttribute, LabelDefinedButNotFound, + AttributeDefinedButNotFound, OnlyOneLabel, FewSamplesInLabel, + FewSamplesInAttribute, ImbalancedLabels, ImbalancedAttribute, + ImbalancedDistInLabel, ImbalancedDistInAttribute, + NegativeLength, InvalidValue, FarFromLabelMean, + FarFromAttrMean, OnlyOneAttributeValue) +from datumaro.components.extractor import AnnotationType, LabelCategories +from datumaro.util import parse_str_enum_value + + +class _TaskValidator(Validator): + # statistics templates + numerical_stat_template = { + 'items_far_from_mean': {}, + 'mean': None, + 'stdev': None, + 'min': None, + 'max': None, + 'median': None, + 'histogram': { + 'bins': [], + 'counts': [], + }, + 'distribution': np.array([]) + } + + """ + A base class for task-specific validators. + + Attributes + ---------- + task_type : str or TaskType + task type (ie. classification, detection, segmentation) + + Methods + ------- + validate(dataset): + Validate annotations based on task type. + compute_statistics(dataset): + Computes various statistics of the dataset based on task type. + generate_reports(stats): + Abstract method that must be implemented in a subclass. + """ + + def __init__(self, task_type, few_samples_thr=None, + imbalance_ratio_thr=None, far_from_mean_thr=None, + dominance_ratio_thr=None, topk_bins=None): + """ + Validator + + Parameters + --------------- + few_samples_thr: int + minimum number of samples per class + warn user when samples per class is less than threshold + imbalance_ratio_thr: int + ratio of majority attribute to minority attribute + warn user when annotations are unevenly distributed + far_from_mean_thr: float + constant used to define mean +/- m * stddev + warn user when there are too big or small values + dominance_ratio_thr: float + ratio of Top-k bin to total + warn user when dominance ratio is over threshold + topk_bins: float + ratio of selected bins with most item number to total bins + warn user when values are not evenly distributed + """ + self.task_type = parse_str_enum_value(task_type, TaskType, + default=TaskType.classification) + + if self.task_type == TaskType.classification: + self.ann_types = {AnnotationType.label} + self.str_ann_type = "label" + elif self.task_type == TaskType.detection: + self.ann_types = {AnnotationType.bbox} + self.str_ann_type = "bounding box" + elif self.task_type == TaskType.segmentation: + self.ann_types = {AnnotationType.mask, AnnotationType.polygon} + self.str_ann_type = "mask or polygon" + + self.few_samples_thr = few_samples_thr + self.imbalance_ratio_thr = imbalance_ratio_thr + self.far_from_mean_thr = far_from_mean_thr + self.dominance_thr = dominance_ratio_thr + self.topk_bins_ratio = topk_bins + + def validate(self, dataset: IDataset): + """ + Returns the validation results of a dataset based on task type. + Args: + dataset (IDataset): Dataset to be validated + task_type (str or TaskType): Type of the task + (classification, detection, segmentation) + Raises: + ValueError + Returns: + validation_results (dict): + Dict with validation statistics, reports and summary. 
+ """ + validation_results = {} + if not isinstance(dataset, IDataset): + raise TypeError("Invalid dataset type '%s'" % type(dataset)) + + # generate statistics + stats = self.compute_statistics(dataset) + validation_results['statistics'] = stats + + # generate validation reports and summary + reports = self.generate_reports(stats) + reports = list(map(lambda r: r.to_dict(), reports)) + + summary = { + 'errors': sum(map(lambda r: r['severity'] == 'error', reports)), + 'warnings': sum(map(lambda r: r['severity'] == 'warning', reports)) + } + + validation_results['validation_reports'] = reports + validation_results['summary'] = summary + + return validation_results + + def _compute_common_statistics(self, dataset): + defined_attr_template = { + 'items_missing_attribute': [], + 'distribution': {} + } + undefined_attr_template = { + 'items_with_undefined_attr': [], + 'distribution': {} + } + undefined_label_template = { + 'count': 0, + 'items_with_undefined_label': [], + } + + stats = { + 'label_distribution': { + 'defined_labels': {}, + 'undefined_labels': {}, + }, + 'attribute_distribution': { + 'defined_attributes': {}, + 'undefined_attributes': {} + }, + } + stats['total_ann_count'] = 0 + stats['items_missing_annotation'] = [] + + label_dist = stats['label_distribution'] + attr_dist = stats['attribute_distribution'] + defined_label_dist = label_dist['defined_labels'] + defined_attr_dist = attr_dist['defined_attributes'] + undefined_label_dist = label_dist['undefined_labels'] + undefined_attr_dist = attr_dist['undefined_attributes'] + + label_categories = dataset.categories().get(AnnotationType.label, + LabelCategories()) + base_valid_attrs = label_categories.attributes + + for category in label_categories: + defined_label_dist[category.name] = 0 + + filtered_anns = [] + for item in dataset: + item_key = (item.id, item.subset) + annotations = [] + for ann in item.annotations: + if ann.type in self.ann_types: + annotations.append(ann) + ann_count = len(annotations) + filtered_anns.append((item_key, annotations)) + + if ann_count == 0: + stats['items_missing_annotation'].append(item_key) + stats['total_ann_count'] += ann_count + + for ann in annotations: + if not 0 <= ann.label < len(label_categories): + label_name = ann.label + + label_stats = undefined_label_dist.setdefault( + ann.label, deepcopy(undefined_label_template)) + label_stats['items_with_undefined_label'].append( + item_key) + + label_stats['count'] += 1 + valid_attrs = set() + missing_attrs = set() + else: + label_name = label_categories[ann.label].name + defined_label_dist[label_name] += 1 + + defined_attr_stats = defined_attr_dist.setdefault( + label_name, {}) + + valid_attrs = base_valid_attrs.union( + label_categories[ann.label].attributes) + ann_attrs = getattr(ann, 'attributes', {}).keys() + missing_attrs = valid_attrs.difference(ann_attrs) + + for attr in valid_attrs: + defined_attr_stats.setdefault( + attr, deepcopy(defined_attr_template)) + + for attr in missing_attrs: + attr_dets = defined_attr_stats[attr] + attr_dets['items_missing_attribute'].append( + item_key) + + for attr, value in ann.attributes.items(): + if attr not in valid_attrs: + undefined_attr_stats = \ + undefined_attr_dist.setdefault( + label_name, {}) + attr_dets = undefined_attr_stats.setdefault( + attr, deepcopy(undefined_attr_template)) + attr_dets['items_with_undefined_attr'].append( + item_key) + else: + attr_dets = defined_attr_stats[attr] + + attr_dets['distribution'].setdefault(str(value), 0) + attr_dets['distribution'][str(value)] += 1 + + 
return stats, filtered_anns + + @staticmethod + def _update_prop_distributions(curr_prop_stats, target_stats): + for prop, val in curr_prop_stats.items(): + prop_stats = target_stats[prop] + prop_dist = prop_stats['distribution'] + prop_stats['distribution'] = np.append(prop_dist, val) + + @staticmethod + def _compute_prop_stats_from_dist(dist_by_label, dist_by_attr): + for label_name, stats in dist_by_label.items(): + prop_stats_list = list(stats.values()) + attr_label = dist_by_attr.get(label_name, {}) + for vals in attr_label.values(): + for val_stats in vals.values(): + prop_stats_list += list(val_stats.values()) + + for prop_stats in prop_stats_list: + prop_dist = prop_stats.pop('distribution', []) + if len(prop_dist) > 0: + prop_stats['mean'] = np.mean(prop_dist) + prop_stats['stdev'] = np.std(prop_dist) + prop_stats['min'] = np.min(prop_dist) + prop_stats['max'] = np.max(prop_dist) + prop_stats['median'] = np.median(prop_dist) + + counts, bins = np.histogram(prop_dist) + prop_stats['histogram']['bins'] = bins.tolist() + prop_stats['histogram']['counts'] = counts.tolist() + + def _compute_far_from_mean(self, prop_stats, val, item_key, ann): + def _far_from_mean(val, mean, stdev): + thr = self.far_from_mean_thr + return val > mean + (thr * stdev) or val < mean - (thr * stdev) + + mean = prop_stats['mean'] + stdev = prop_stats['stdev'] + + if _far_from_mean(val, mean, stdev): + items_far_from_mean = prop_stats['items_far_from_mean'] + far_from_mean = items_far_from_mean.setdefault( + item_key, {}) + far_from_mean[ann.id] = val + + def compute_statistics(self, dataset: IDataset): + """ + Computes statistics of the dataset based on task type. + + Parameters + ---------- + dataset : IDataset object + + Returns + ------- + stats (dict): A dict object containing statistics of the dataset. 
+ """ + return NotImplementedError + + def _check_missing_label_categories(self, stats): + validation_reports = [] + + if len(stats['label_distribution']['defined_labels']) == 0: + validation_reports += self._generate_validation_report( + MissingLabelCategories, Severity.error) + + return validation_reports + + def _check_missing_annotation(self, stats): + validation_reports = [] + + items_missing = stats['items_missing_annotation'] + for item_id, item_subset in items_missing: + validation_reports += self._generate_validation_report( + MissingAnnotation, Severity.warning, item_id, item_subset, + self.str_ann_type) + + return validation_reports + + def _check_missing_attribute(self, label_name, attr_name, attr_dets): + validation_reports = [] + + items_missing_attr = attr_dets['items_missing_attribute'] + for item_id, item_subset in items_missing_attr: + details = (item_subset, label_name, attr_name) + validation_reports += self._generate_validation_report( + MissingAttribute, Severity.warning, item_id, *details) + + return validation_reports + + def _check_undefined_label(self, label_name, label_stats): + validation_reports = [] + + items_with_undefined_label = label_stats['items_with_undefined_label'] + for item_id, item_subset in items_with_undefined_label: + details = (item_subset, label_name) + validation_reports += self._generate_validation_report( + UndefinedLabel, Severity.error, item_id, *details) + + return validation_reports + + def _check_undefined_attribute(self, label_name, attr_name, attr_dets): + validation_reports = [] + + items_with_undefined_attr = attr_dets['items_with_undefined_attr'] + for item_id, item_subset in items_with_undefined_attr: + details = (item_subset, label_name, attr_name) + validation_reports += self._generate_validation_report( + UndefinedAttribute, Severity.error, item_id, *details) + + return validation_reports + + def _check_label_defined_but_not_found(self, stats): + validation_reports = [] + count_by_defined_labels = stats['label_distribution']['defined_labels'] + labels_not_found = [label_name + for label_name, count in count_by_defined_labels.items() + if count == 0] + + for label_name in labels_not_found: + validation_reports += self._generate_validation_report( + LabelDefinedButNotFound, Severity.warning, label_name) + + return validation_reports + + def _check_attribute_defined_but_not_found(self, label_name, attr_stats): + validation_reports = [] + attrs_not_found = [attr_name + for attr_name, attr_dets in attr_stats.items() + if len(attr_dets['distribution']) == 0] + + for attr_name in attrs_not_found: + details = (label_name, attr_name) + validation_reports += self._generate_validation_report( + AttributeDefinedButNotFound, Severity.warning, *details) + + return validation_reports + + def _check_only_one_label(self, stats): + validation_reports = [] + count_by_defined_labels = stats['label_distribution']['defined_labels'] + labels_found = [label_name + for label_name, count in count_by_defined_labels.items() + if count > 0] + + if len(labels_found) == 1: + validation_reports += self._generate_validation_report( + OnlyOneLabel, Severity.warning, labels_found[0]) + + return validation_reports + + def _check_only_one_attribute_value(self, label_name, attr_name, attr_dets): + validation_reports = [] + values = list(attr_dets['distribution'].keys()) + + if len(values) == 1: + details = (label_name, attr_name, values[0]) + validation_reports += self._generate_validation_report( + OnlyOneAttributeValue, Severity.warning, *details) + + return 
validation_reports + + def _check_few_samples_in_label(self, stats): + validation_reports = [] + thr = self.few_samples_thr + + defined_label_dist = stats['label_distribution']['defined_labels'] + labels_with_few_samples = [(label_name, count) + for label_name, count in defined_label_dist.items() + if 0 < count <= thr] + + for label_name, count in labels_with_few_samples: + validation_reports += self._generate_validation_report( + FewSamplesInLabel, Severity.warning, label_name, count) + + return validation_reports + + def _check_few_samples_in_attribute(self, label_name, + attr_name, attr_dets): + validation_reports = [] + thr = self.few_samples_thr + + attr_values_with_few_samples = [(attr_value, count) + for attr_value, count in attr_dets['distribution'].items() + if count <= thr] + + for attr_value, count in attr_values_with_few_samples: + details = (label_name, attr_name, attr_value, count) + validation_reports += self._generate_validation_report( + FewSamplesInAttribute, Severity.warning, *details) + + return validation_reports + + def _check_imbalanced_labels(self, stats): + validation_reports = [] + thr = self.imbalance_ratio_thr + + defined_label_dist = stats['label_distribution']['defined_labels'] + count_by_defined_labels = [count + for label, count in defined_label_dist.items()] + + if len(count_by_defined_labels) == 0: + return validation_reports + + count_max = np.max(count_by_defined_labels) + count_min = np.min(count_by_defined_labels) + balance = count_max / count_min if count_min > 0 else float('inf') + if balance >= thr: + validation_reports += self._generate_validation_report( + ImbalancedLabels, Severity.warning) + + return validation_reports + + def _check_imbalanced_attribute(self, label_name, attr_name, attr_dets): + validation_reports = [] + thr = self.imbalance_ratio_thr + + count_by_defined_attr = list(attr_dets['distribution'].values()) + if len(count_by_defined_attr) == 0: + return validation_reports + + count_max = np.max(count_by_defined_attr) + count_min = np.min(count_by_defined_attr) + balance = count_max / count_min if count_min > 0 else float('inf') + if balance >= thr: + validation_reports += self._generate_validation_report( + ImbalancedAttribute, Severity.warning, label_name, attr_name) + + return validation_reports + + def _check_imbalanced_dist_in_label(self, label_name, label_stats): + validation_reports = [] + thr = self.dominance_thr + topk_ratio = self.topk_bins_ratio + + for prop, prop_stats in label_stats.items(): + value_counts = prop_stats['histogram']['counts'] + n_bucket = len(value_counts) + if n_bucket < 2: + continue + topk = max(1, int(np.around(n_bucket * topk_ratio))) + + if topk > 0: + topk_values = np.sort(value_counts)[-topk:] + ratio = np.sum(topk_values) / np.sum(value_counts) + if ratio >= thr: + details = (label_name, f"{self.str_ann_type} {prop}") + validation_reports += self._generate_validation_report( + ImbalancedDistInLabel, Severity.warning, *details) + + return validation_reports + + def _check_imbalanced_dist_in_attr(self, label_name, attr_name, attr_stats): + validation_reports = [] + thr = self.dominance_thr + topk_ratio = self.topk_bins_ratio + + for attr_value, value_stats in attr_stats.items(): + for prop, prop_stats in value_stats.items(): + value_counts = prop_stats['histogram']['counts'] + n_bucket = len(value_counts) + if n_bucket < 2: + continue + topk = max(1, int(np.around(n_bucket * topk_ratio))) + + if topk > 0: + topk_values = np.sort(value_counts)[-topk:] + ratio = np.sum(topk_values) / 
np.sum(value_counts) + if ratio >= thr: + details = (label_name, attr_name, attr_value, + f"{self.str_ann_type} {prop}") + validation_reports += self._generate_validation_report( + ImbalancedDistInAttribute, + Severity.warning, + *details + ) + + return validation_reports + + def _check_invalid_value(self, stats): + validation_reports = [] + + items_w_invalid_val = stats['items_with_invalid_value'] + for item_dets, anns_w_invalid_val in items_w_invalid_val.items(): + item_id, item_subset = item_dets + for ann_id, props in anns_w_invalid_val.items(): + for prop in props: + details = (item_subset, ann_id, + f"{self.str_ann_type} {prop}") + validation_reports += self._generate_validation_report( + InvalidValue, Severity.error, item_id, *details) + + return validation_reports + + def _check_far_from_label_mean(self, label_name, label_stats): + validation_reports = [] + + for prop, prop_stats in label_stats.items(): + items_far_from_mean = prop_stats['items_far_from_mean'] + if prop_stats['mean'] is not None: + mean = round(prop_stats['mean'], 2) + + for item_dets, anns_far in items_far_from_mean.items(): + item_id, item_subset = item_dets + for ann_id, val in anns_far.items(): + val = round(val, 2) + details = (item_subset, label_name, ann_id, + f"{self.str_ann_type} {prop}", mean, val) + validation_reports += self._generate_validation_report( + FarFromLabelMean, Severity.warning, item_id, *details) + + return validation_reports + + def _check_far_from_attr_mean(self, label_name, attr_name, attr_stats): + validation_reports = [] + + for attr_value, value_stats in attr_stats.items(): + for prop, prop_stats in value_stats.items(): + items_far_from_mean = prop_stats['items_far_from_mean'] + if prop_stats['mean'] is not None: + mean = round(prop_stats['mean'], 2) + + for item_dets, anns_far in items_far_from_mean.items(): + item_id, item_subset = item_dets + for ann_id, val in anns_far.items(): + val = round(val, 2) + details = (item_subset, label_name, ann_id, attr_name, + attr_value, f"{self.str_ann_type} {prop}", + mean, val) + validation_reports += self._generate_validation_report( + FarFromAttrMean, + Severity.warning, + item_id, + *details + ) + + return validation_reports + + def generate_reports(self, stats: Dict) -> List[Dict]: + raise NotImplementedError('Should be implemented in a subclass.') + + def _generate_validation_report(self, error, *args, **kwargs): + return [error(*args, **kwargs)] + + +class ClassificationValidator(_TaskValidator, CliPlugin): + """ + A specific validator class for classification task. 
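+
+    A rough usage sketch (the argument values below are just the command
+    line defaults, not a recommendation):
+
+        validator = ClassificationValidator(few_samples_thr=1,
+            imbalance_ratio_thr=50, far_from_mean_thr=5.0,
+            dominance_ratio_thr=0.8, topk_bins=0.1)
+        report = validator.validate(dataset)
+
+    The returned report combines the computed statistics, the generated
+    validation reports and a summary.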
+ """ + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, + help="Threshold for giving a warning for minimum number of" + "samples per class") + parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, + help="Threshold for giving data imbalance warning;" + "IR(imbalance ratio) = majority/minority") + parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, + help="Threshold for giving a warning that data is far from mean;" + "A constant used to define mean +/- k * standard deviation;") + parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, + help="Threshold for giving a warning for bounding box imbalance;" + "Dominace_ratio = ratio of Top-k bin to total in histogram;") + parser.add_argument('-k', '--topk_bins', default=0.1, type=float, + help="Ratio of bins with the highest number of data" + "to total bins in the histogram; [0, 1]; 0.1 = 10%;") + return parser + + def __init__(self, few_samples_thr, imbalance_ratio_thr, + far_from_mean_thr, dominance_ratio_thr, topk_bins): + super().__init__(task_type=TaskType.classification, + few_samples_thr=few_samples_thr, + imbalance_ratio_thr=imbalance_ratio_thr, + far_from_mean_thr=far_from_mean_thr, + dominance_ratio_thr=dominance_ratio_thr, topk_bins=topk_bins) + + def _check_multi_label_annotations(self, stats): + validation_reports = [] + + items_with_multiple_labels = stats['items_with_multiple_labels'] + for item_id, item_subset in items_with_multiple_labels: + validation_reports += self._generate_validation_report( + MultiLabelAnnotations, Severity.error, item_id, item_subset) + + return validation_reports + + def compute_statistics(self, dataset): + """ + Computes statistics of the dataset for the classification task. + + Parameters + ---------- + dataset : IDataset object + + Returns + ------- + stats (dict): A dict object containing statistics of the dataset. + """ + + stats, filtered_anns = self._compute_common_statistics(dataset) + + stats['items_with_multiple_labels'] = [] + + for item_key, anns in filtered_anns: + ann_count = len(anns) + if ann_count > 1: + stats['items_with_multiple_labels'].append(item_key) + + return stats + + def generate_reports(self, stats): + """ + Validates the dataset for classification tasks based on its statistics. + + Parameters + ---------- + dataset : IDataset object + stats: Dict object + + Returns + ------- + reports (list): List of validation reports (DatasetValidationError). 
+ """ + + reports = [] + + reports += self._check_missing_label_categories(stats) + reports += self._check_missing_annotation(stats) + reports += self._check_multi_label_annotations(stats) + reports += self._check_label_defined_but_not_found(stats) + reports += self._check_only_one_label(stats) + reports += self._check_few_samples_in_label(stats) + reports += self._check_imbalanced_labels(stats) + + label_dist = stats['label_distribution'] + attr_dist = stats['attribute_distribution'] + defined_attr_dist = attr_dist['defined_attributes'] + undefined_label_dist = label_dist['undefined_labels'] + undefined_attr_dist = attr_dist['undefined_attributes'] + + defined_labels = defined_attr_dist.keys() + for label_name in defined_labels: + attr_stats = defined_attr_dist[label_name] + + reports += self._check_attribute_defined_but_not_found( + label_name, attr_stats) + + for attr_name, attr_dets in attr_stats.items(): + reports += self._check_few_samples_in_attribute( + label_name, attr_name, attr_dets) + reports += self._check_imbalanced_attribute( + label_name, attr_name, attr_dets) + reports += self._check_only_one_attribute_value( + label_name, attr_name, attr_dets) + reports += self._check_missing_attribute( + label_name, attr_name, attr_dets) + + for label_name, label_stats in undefined_label_dist.items(): + reports += self._check_undefined_label(label_name, label_stats) + + for label_name, attr_stats in undefined_attr_dist.items(): + for attr_name, attr_dets in attr_stats.items(): + reports += self._check_undefined_attribute( + label_name, attr_name, attr_dets) + + return reports + + +class DetectionValidator(_TaskValidator, CliPlugin): + """ + A specific validator class for detection task. + """ + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, + help="Threshold for giving a warning for minimum number of" + "samples per class") + parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, + help="Threshold for giving data imbalance warning;" + "IR(imbalance ratio) = majority/minority") + parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, + help="Threshold for giving a warning that data is far from mean;" + "A constant used to define mean +/- k * standard deviation;") + parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, + help="Threshold for giving a warning for bounding box imbalance;" + "Dominace_ratio = ratio of Top-k bin to total in histogram;") + parser.add_argument('-k', '--topk_bins', default=0.1, type=float, + help="Ratio of bins with the highest number of data" + "to total bins in the histogram; [0, 1]; 0.1 = 10%;") + return parser + + def __init__(self, few_samples_thr, imbalance_ratio_thr, + far_from_mean_thr, dominance_ratio_thr, topk_bins): + super().__init__(task_type=TaskType.detection, + few_samples_thr=few_samples_thr, + imbalance_ratio_thr=imbalance_ratio_thr, + far_from_mean_thr=far_from_mean_thr, + dominance_ratio_thr=dominance_ratio_thr, topk_bins=topk_bins) + + def _check_negative_length(self, stats): + validation_reports = [] + + items_w_neg_len = stats['items_with_negative_length'] + for item_dets, anns_w_neg_len in items_w_neg_len.items(): + item_id, item_subset = item_dets + for ann_id, props in anns_w_neg_len.items(): + for prop, val in props.items(): + val = round(val, 2) + details = (item_subset, ann_id, + f"{self.str_ann_type} {prop}", val) + validation_reports += 
self._generate_validation_report( + NegativeLength, Severity.error, item_id, *details) + + return validation_reports + + def compute_statistics(self, dataset): + """ + Computes statistics of the dataset for the detection task. + + Parameters + ---------- + dataset : IDataset object + + Returns + ------- + stats (dict): A dict object containing statistics of the dataset. + """ + + stats, filtered_anns = self._compute_common_statistics(dataset) + + # detection-specific + bbox_template = { + 'width': deepcopy(self.numerical_stat_template), + 'height': deepcopy(self.numerical_stat_template), + 'area(wxh)': deepcopy(self.numerical_stat_template), + 'ratio(w/h)': deepcopy(self.numerical_stat_template), + 'short': deepcopy(self.numerical_stat_template), + 'long': deepcopy(self.numerical_stat_template) + } + + stats['items_with_negative_length'] = {} + stats['items_with_invalid_value'] = {} + stats['bbox_distribution_in_label'] = {} + stats['bbox_distribution_in_attribute'] = {} + stats['bbox_distribution_in_dataset_item'] = {} + + dist_by_label = stats['bbox_distribution_in_label'] + dist_by_attr = stats['bbox_distribution_in_attribute'] + bbox_dist_in_item = stats['bbox_distribution_in_dataset_item'] + items_w_neg_len = stats['items_with_negative_length'] + items_w_invalid_val = stats['items_with_invalid_value'] + + def _generate_ann_bbox_info(_x, _y, _w, _h, area, + ratio, _short, _long): + return { + 'x': _x, + 'y': _y, + 'width': _w, + 'height': _h, + 'area(wxh)': area, + 'ratio(w/h)': ratio, + 'short': _short, + 'long': _long, + } + + def _update_bbox_stats_by_label(item_key, ann, bbox_label_stats): + bbox_has_error = False + + _x, _y, _w, _h = ann.get_bbox() + area = ann.get_area() + + if _h != 0 and _h != float('inf'): + ratio = _w / _h + else: + ratio = float('nan') + + _short = _w if _w < _h else _h + _long = _w if _w > _h else _h + + ann_bbox_info = _generate_ann_bbox_info( + _x, _y, _w, _h, area, ratio, _short, _long) + + for prop, val in ann_bbox_info.items(): + if val == float('inf') or np.isnan(val): + bbox_has_error = True + anns_w_invalid_val = items_w_invalid_val.setdefault( + item_key, {}) + invalid_props = anns_w_invalid_val.setdefault( + ann.id, []) + invalid_props.append(prop) + + for prop in ['width', 'height']: + val = ann_bbox_info[prop] + if val < 1: + bbox_has_error = True + anns_w_neg_len = items_w_neg_len.setdefault( + item_key, {}) + neg_props = anns_w_neg_len.setdefault(ann.id, {}) + neg_props[prop] = val + + if not bbox_has_error: + ann_bbox_info.pop('x') + ann_bbox_info.pop('y') + self._update_prop_distributions(ann_bbox_info, bbox_label_stats) + + return ann_bbox_info, bbox_has_error + + label_categories = dataset.categories().get(AnnotationType.label, + LabelCategories()) + base_valid_attrs = label_categories.attributes + + for item_key, annotations in filtered_anns: + ann_count = len(annotations) + + bbox_dist_in_item[item_key] = ann_count + + for ann in annotations: + if not 0 <= ann.label < len(label_categories): + label_name = ann.label + valid_attrs = set() + else: + label_name = label_categories[ann.label].name + valid_attrs = base_valid_attrs.union( + label_categories[ann.label].attributes) + + bbox_label_stats = dist_by_label.setdefault( + label_name, deepcopy(bbox_template)) + ann_bbox_info, bbox_has_error = \ + _update_bbox_stats_by_label( + item_key, ann, bbox_label_stats) + + for attr, value in ann.attributes.items(): + if attr in valid_attrs: + bbox_attr_label = dist_by_attr.setdefault( + label_name, {}) + bbox_attr_stats = 
bbox_attr_label.setdefault( + attr, {}) + bbox_val_stats = bbox_attr_stats.setdefault( + str(value), deepcopy(bbox_template)) + + if not bbox_has_error: + self._update_prop_distributions( + ann_bbox_info, bbox_val_stats) + + # Compute prop stats from distribution + self._compute_prop_stats_from_dist(dist_by_label, dist_by_attr) + + def _is_valid_ann(item_key, ann): + has_defined_label = 0 <= ann.label < len(label_categories) + if not has_defined_label: + return False + + bbox_has_neg_len = ann.id in items_w_neg_len.get( + item_key, {}) + bbox_has_invalid_val = ann.id in items_w_invalid_val.get( + item_key, {}) + return not (bbox_has_neg_len or bbox_has_invalid_val) + + def _update_props_far_from_mean(item_key, ann): + valid_attrs = base_valid_attrs.union( + label_categories[ann.label].attributes) + label_name = label_categories[ann.label].name + bbox_label_stats = dist_by_label[label_name] + + _x, _y, _w, _h = ann.get_bbox() + area = ann.get_area() + ratio = _w / _h + _short = _w if _w < _h else _h + _long = _w if _w > _h else _h + + ann_bbox_info = _generate_ann_bbox_info( + _x, _y, _w, _h, area, ratio, _short, _long) + ann_bbox_info.pop('x') + ann_bbox_info.pop('y') + + for prop, val in ann_bbox_info.items(): + prop_stats = bbox_label_stats[prop] + self._compute_far_from_mean(prop_stats, val, item_key, ann) + + for attr, value in ann.attributes.items(): + if attr in valid_attrs: + bbox_attr_stats = dist_by_attr[label_name][attr] + bbox_val_stats = bbox_attr_stats[str(value)] + + for prop, val in ann_bbox_info.items(): + prop_stats = bbox_val_stats[prop] + self._compute_far_from_mean(prop_stats, val, + item_key, ann) + + for item_key, annotations in filtered_anns: + for ann in annotations: + if _is_valid_ann(item_key, ann): + _update_props_far_from_mean(item_key, ann) + + return stats + + def generate_reports(self, stats): + """ + Validates the dataset for detection tasks based on its statistics. + + Parameters + ---------- + dataset : IDataset object + stats : Dict object + + Returns + ------- + reports (list): List of validation reports (DatasetValidationError). 
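+
+        Every entry is an instance of one of the error classes used by the
+        checks above, so a caller can filter the list by type; a purely
+        illustrative example:
+
+            negative_length = [r for r in reports
+                if isinstance(r, NegativeLength)]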
+ """ + + reports = [] + + reports += self._check_missing_label_categories(stats) + reports += self._check_missing_annotation(stats) + reports += self._check_label_defined_but_not_found(stats) + reports += self._check_only_one_label(stats) + reports += self._check_few_samples_in_label(stats) + reports += self._check_imbalanced_labels(stats) + reports += self._check_negative_length(stats) + reports += self._check_invalid_value(stats) + + label_dist = stats['label_distribution'] + attr_dist = stats['attribute_distribution'] + defined_attr_dist = attr_dist['defined_attributes'] + undefined_label_dist = label_dist['undefined_labels'] + undefined_attr_dist = attr_dist['undefined_attributes'] + + dist_by_label = stats['bbox_distribution_in_label'] + dist_by_attr = stats['bbox_distribution_in_attribute'] + + defined_labels = defined_attr_dist.keys() + for label_name in defined_labels: + attr_stats = defined_attr_dist[label_name] + + reports += self._check_attribute_defined_but_not_found( + label_name, attr_stats) + + for attr_name, attr_dets in attr_stats.items(): + reports += self._check_few_samples_in_attribute( + label_name, attr_name, attr_dets) + reports += self._check_imbalanced_attribute( + label_name, attr_name, attr_dets) + reports += self._check_only_one_attribute_value( + label_name, attr_name, attr_dets) + reports += self._check_missing_attribute( + label_name, attr_name, attr_dets) + + bbox_label_stats = dist_by_label[label_name] + bbox_attr_label = dist_by_attr.get(label_name, {}) + + reports += self._check_far_from_label_mean( + label_name, bbox_label_stats) + reports += self._check_imbalanced_dist_in_label( + label_name, bbox_label_stats) + + for attr_name, bbox_attr_stats in bbox_attr_label.items(): + reports += self._check_far_from_attr_mean( + label_name, attr_name, bbox_attr_stats) + reports += self._check_imbalanced_dist_in_attr( + label_name, attr_name, bbox_attr_stats) + + for label_name, label_stats in undefined_label_dist.items(): + reports += self._check_undefined_label(label_name, label_stats) + + for label_name, attr_stats in undefined_attr_dist.items(): + for attr_name, attr_dets in attr_stats.items(): + reports += self._check_undefined_attribute( + label_name, attr_name, attr_dets) + + return reports + + +class SegmentationValidator(_TaskValidator, CliPlugin): + """ + A specific validator class for (instance) segmentation task. 
+ """ + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, + help="Threshold for giving a warning for minimum number of" + "samples per class") + parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, + help="Threshold for giving data imbalance warning;" + "IR(imbalance ratio) = majority/minority") + parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, + help="Threshold for giving a warning that data is far from mean;" + "A constant used to define mean +/- k * standard deviation;") + parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, + help="Threshold for giving a warning for bounding box imbalance;" + "Dominace_ratio = ratio of Top-k bin to total in histogram;") + parser.add_argument('-k', '--topk_bins', default=0.1, type=float, + help="Ratio of bins with the highest number of data" + "to total bins in the histogram; [0, 1]; 0.1 = 10%;") + return parser + + def __init__(self, few_samples_thr, imbalance_ratio_thr, + far_from_mean_thr, dominance_ratio_thr, topk_bins): + super().__init__(task_type=TaskType.segmentation, + few_samples_thr=few_samples_thr, + imbalance_ratio_thr=imbalance_ratio_thr, + far_from_mean_thr=far_from_mean_thr, + dominance_ratio_thr=dominance_ratio_thr, topk_bins=topk_bins) + + def compute_statistics(self, dataset): + """ + Computes statistics of the dataset for the segmentation task. + + Parameters + ---------- + dataset : IDataset object + + Returns + ------- + stats (dict): A dict object containing statistics of the dataset. + """ + + stats, filtered_anns = self._compute_common_statistics(dataset) + + # segmentation-specific + mask_template = { + 'area': deepcopy(self.numerical_stat_template), + 'width': deepcopy(self.numerical_stat_template), + 'height': deepcopy(self.numerical_stat_template) + } + + stats['items_with_invalid_value'] = {} + stats['mask_distribution_in_label'] = {} + stats['mask_distribution_in_attribute'] = {} + stats['mask_distribution_in_dataset_item'] = {} + + dist_by_label = stats['mask_distribution_in_label'] + dist_by_attr = stats['mask_distribution_in_attribute'] + mask_dist_in_item = stats['mask_distribution_in_dataset_item'] + items_w_invalid_val = stats['items_with_invalid_value'] + + def _generate_ann_mask_info(area, _w, _h): + return { + 'area': area, + 'width': _w, + 'height': _h, + } + + def _update_mask_stats_by_label(item_key, ann, mask_label_stats): + mask_has_error = False + + _x, _y, _w, _h = ann.get_bbox() + + # Detete the following block when #226 is resolved + # https://github.com/openvinotoolkit/datumaro/issues/226 + if ann.type == AnnotationType.mask: + _w += 1 + _h += 1 + + area = ann.get_area() + + ann_mask_info = _generate_ann_mask_info(area, _w, _h) + + for prop, val in ann_mask_info.items(): + if val == float('inf') or np.isnan(val): + mask_has_error = True + anns_w_invalid_val = items_w_invalid_val.setdefault( + item_key, {}) + invalid_props = anns_w_invalid_val.setdefault( + ann.id, []) + invalid_props.append(prop) + + if not mask_has_error: + self._update_prop_distributions(ann_mask_info, mask_label_stats) + + return ann_mask_info, mask_has_error + + label_categories = dataset.categories().get(AnnotationType.label, + LabelCategories()) + base_valid_attrs = label_categories.attributes + + for item_key, annotations in filtered_anns: + ann_count = len(annotations) + mask_dist_in_item[item_key] = ann_count + + for ann in 
annotations: + if not 0 <= ann.label < len(label_categories): + label_name = ann.label + valid_attrs = set() + else: + label_name = label_categories[ann.label].name + valid_attrs = base_valid_attrs.union( + label_categories[ann.label].attributes) + + mask_label_stats = dist_by_label.setdefault( + label_name, deepcopy(mask_template)) + ann_mask_info, mask_has_error = \ + _update_mask_stats_by_label( + item_key, ann, mask_label_stats) + + for attr, value in ann.attributes.items(): + if attr in valid_attrs: + mask_attr_label = dist_by_attr.setdefault( + label_name, {}) + mask_attr_stats = mask_attr_label.setdefault( + attr, {}) + mask_val_stats = mask_attr_stats.setdefault( + str(value), deepcopy(mask_template)) + + if not mask_has_error: + self._update_prop_distributions( + ann_mask_info, mask_val_stats) + + # compute prop stats from dist. + self._compute_prop_stats_from_dist(dist_by_label, dist_by_attr) + + def _is_valid_ann(item_key, ann): + has_defined_label = 0 <= ann.label < len(label_categories) + if not has_defined_label: + return False + + mask_has_invalid_val = ann.id in items_w_invalid_val.get( + item_key, {}) + return not mask_has_invalid_val + + def _update_props_far_from_mean(item_key, ann): + valid_attrs = base_valid_attrs.union( + label_categories[ann.label].attributes) + label_name = label_categories[ann.label].name + mask_label_stats = dist_by_label[label_name] + + _x, _y, _w, _h = ann.get_bbox() + + # Detete the following block when #226 is resolved + # https://github.com/openvinotoolkit/datumaro/issues/226 + if ann.type == AnnotationType.mask: + _w += 1 + _h += 1 + area = ann.get_area() + + ann_mask_info = _generate_ann_mask_info(area, _w, _h) + + for prop, val in ann_mask_info.items(): + prop_stats = mask_label_stats[prop] + self._compute_far_from_mean(prop_stats, val, item_key, ann) + + for attr, value in ann.attributes.items(): + if attr in valid_attrs: + mask_attr_stats = dist_by_attr[label_name][attr] + mask_val_stats = mask_attr_stats[str(value)] + + for prop, val in ann_mask_info.items(): + prop_stats = mask_val_stats[prop] + self._compute_far_from_mean(prop_stats, val, + item_key, ann) + + for item_key, annotations in filtered_anns: + for ann in annotations: + if _is_valid_ann(item_key, ann): + _update_props_far_from_mean(item_key, ann) + + return stats + + def generate_reports(self, stats): + """ + Validates the dataset for segmentation tasks based on its statistics. + + Parameters + ---------- + dataset : IDataset object + stats : Dict object + + Returns + ------- + reports (list): List of validation reports (DatasetValidationError). 
+ """ + + reports = [] + + reports += self._check_missing_label_categories(stats) + reports += self._check_missing_annotation(stats) + reports += self._check_label_defined_but_not_found(stats) + reports += self._check_only_one_label(stats) + reports += self._check_few_samples_in_label(stats) + reports += self._check_imbalanced_labels(stats) + reports += self._check_invalid_value(stats) + + label_dist = stats['label_distribution'] + attr_dist = stats['attribute_distribution'] + defined_attr_dist = attr_dist['defined_attributes'] + undefined_label_dist = label_dist['undefined_labels'] + undefined_attr_dist = attr_dist['undefined_attributes'] + + dist_by_label = stats['mask_distribution_in_label'] + dist_by_attr = stats['mask_distribution_in_attribute'] + + defined_labels = defined_attr_dist.keys() + for label_name in defined_labels: + attr_stats = defined_attr_dist[label_name] + + reports += self._check_attribute_defined_but_not_found( + label_name, attr_stats) + + for attr_name, attr_dets in attr_stats.items(): + reports += self._check_few_samples_in_attribute( + label_name, attr_name, attr_dets) + reports += self._check_imbalanced_attribute( + label_name, attr_name, attr_dets) + reports += self._check_only_one_attribute_value( + label_name, attr_name, attr_dets) + reports += self._check_missing_attribute( + label_name, attr_name, attr_dets) + + mask_label_stats = dist_by_label[label_name] + mask_attr_label = dist_by_attr.get(label_name, {}) + + reports += self._check_far_from_label_mean( + label_name, mask_label_stats) + reports += self._check_imbalanced_dist_in_label( + label_name, mask_label_stats) + + for attr_name, mask_attr_stats in mask_attr_label.items(): + reports += self._check_far_from_attr_mean( + label_name, attr_name, mask_attr_stats) + reports += self._check_imbalanced_dist_in_attr( + label_name, attr_name, mask_attr_stats) + + for label_name, label_stats in undefined_label_dist.items(): + reports += self._check_undefined_label(label_name, label_stats) + + for label_name, attr_stats in undefined_attr_dist.items(): + for attr_name, attr_dets in attr_stats.items(): + reports += self._check_undefined_attribute( + label_name, attr_name, attr_dets) + + return reports diff --git a/docs/user_manual.md b/docs/user_manual.md index d3781a10202a..766779661fcb 100644 --- a/docs/user_manual.md +++ b/docs/user_manual.md @@ -907,20 +907,38 @@ and stores the result in JSON file. The task types supported are `classification`, `detection`, and `segmentation`. 
The validation result contains -- annotation statistics based on the task type -- validation reports, such as +- `annotation statistics` based on the task type +- `validation reports`, such as - items not having annotations - items having undefined annotations - imbalanced distribution in class/attributes - too small or large values -- summary +- `summary` Usage: +- There are five configurable parameters for validation + - `few_samples_thr` : threshold for giving a warning for minimum number of samples per class + - `imbalance_ratio_thr` : threshold for giving imbalance data warning + - `far_from_mean_thr` : threshold for giving a warning that data is far from mean + - `dominance_ratio_thr` : threshold for giving a warning bounding box imbalance + - `topk_bins` : ratio of bins with the highest number of data to total bins in the histogram ``` bash datum validate --help -datum validate -p +datum validate -p -t -- \ + -fs \ + -ir \ + -m \ + -dr \ + -k +``` + +Example : give warning when imbalance ratio of data with classification task over 40 + +``` bash +datum validate -p prj-cls -t classification -- \ + -ir 40 ``` Here is the list of validation items(a.k.a. anomaly types). diff --git a/tests/test_validator.py b/tests/test_validator.py index 2d0bd47e7877..7f855e4cb8c3 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -4,6 +4,7 @@ from collections import Counter from unittest import TestCase + import numpy as np from datumaro.components.dataset import Dataset, DatasetItem @@ -16,9 +17,8 @@ NegativeLength, InvalidValue, FarFromLabelMean, FarFromAttrMean, OnlyOneAttributeValue) from datumaro.components.extractor import Bbox, Label, Mask, Polygon -from datumaro.components.validator import (ClassificationValidator, - DetectionValidator, TaskType, validate_annotations, _Validator, - SegmentationValidator) +from datumaro.components.validator import TaskType +from datumaro.plugins.validators import (_TaskValidator, ClassificationValidator, DetectionValidator, SegmentationValidator) from .requirements import Requirements, mark_requirement @@ -114,7 +114,7 @@ def setUpClass(cls): class TestBaseValidator(TestValidatorTemplate): @classmethod def setUpClass(cls): - cls.validator = _Validator(task_type=TaskType.classification, + cls.validator = _TaskValidator(task_type=TaskType.classification, few_samples_thr=1, imbalance_ratio_thr=50, far_from_mean_thr=5.0, dominance_ratio_thr=0.8, topk_bins=0.1) @@ -721,8 +721,8 @@ class TestValidateAnnotations(TestValidatorTemplate): } @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_validate_annotations_classification(self): - actual_results = validate_annotations(self.dataset, 'classification', - **self.extra_args) + validator = ClassificationValidator(**self.extra_args) + actual_results = validator.validate(self.dataset) with self.subTest('Test of statistics', i=0): actual_stats = actual_results['statistics'] @@ -778,8 +778,8 @@ def test_validate_annotations_classification(self): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_validate_annotations_detection(self): - actual_results = validate_annotations(self.dataset, 'detection', - **self.extra_args) + validator = DetectionValidator(**self.extra_args) + actual_results = validator.validate(self.dataset) with self.subTest('Test of statistics', i=0): actual_stats = actual_results['statistics'] @@ -833,8 +833,8 @@ def test_validate_annotations_detection(self): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_validate_annotations_segmentation(self): - actual_results = 
validate_annotations(self.dataset, 'segmentation', - **self.extra_args) + validator = SegmentationValidator(**self.extra_args) + actual_results = validator.validate(self.dataset) with self.subTest('Test of statistics', i=0): actual_stats = actual_results['statistics'] @@ -888,11 +888,7 @@ def test_validate_annotations_segmentation(self): self.assertEqual(actual_summary, expected_summary) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_validate_annotations_invalid_task_type(self): - with self.assertRaises(ValueError): - validate_annotations(self.dataset, 'INVALID', **self.extra_args) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_validate_annotations_invalid_dataset_type(self): + def test_validate_invalid_dataset_type(self): with self.assertRaises(TypeError): - validate_annotations(object(), 'classification', **self.extra_args) + validator = ClassificationValidator(**self.extra_args) + validator.validate(object())