Skip to content

Commit

Permalink
validator threshold adjustment + style correction
Browse files Browse the repository at this point in the history
  • Loading branch information
Yi, Jihyeon committed Mar 29, 2021
1 parent 1fcba0b commit 2d9c954
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 124 deletions.
116 changes: 69 additions & 47 deletions datumaro/components/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@


class _Validator:
DEFAULT_FEW_SAMPLES = 1
DEFAULT_IMBALANCE_RATIO = 50
"""
A base class for task-specific validators.
Expand Down Expand Up @@ -57,7 +59,10 @@ def __init__(self, task_type=None, ann_type=None, far_from_mean_thr=None):

self.task_type = task_type
self.ann_type = ann_type

self.far_from_mean_thr = far_from_mean_thr
self.imbalance_ratio_thr = self.DEFAULT_IMBALANCE_RATIO
self.few_samples_thr = self.DEFAULT_FEW_SAMPLES

def compute_statistics(self, dataset):
"""
Expand Down Expand Up @@ -300,7 +305,7 @@ def _update_props_far_from_mean(item, ann):
defined_label_dist[category.name] = 0

for item in dataset:
ann_count = [ann.type == self.ann_type \
ann_count = [ann.type == self.ann_type
for ann in item.annotations].count(True)

if self.task_type == TaskType.classification:
Expand Down Expand Up @@ -371,7 +376,7 @@ def _update_props_far_from_mean(item, ann):
attr_dets = defined_attr_stats[attr]

if self.task_type == TaskType.detection and \
ann.type == self.ann_type:
ann.type == self.ann_type:
bbox_attr_label = bbox_dist_by_attr.setdefault(
label_name, {})
bbox_attr_stats = bbox_attr_label.setdefault(
Expand Down Expand Up @@ -441,8 +446,8 @@ def _check_undefined_attribute(self, label_name, attr_name, attr_dets):
def _check_label_defined_but_not_found(self, stats):
validation_reports = []
count_by_defined_labels = stats['label_distribution']['defined_labels']
labels_not_found = [label_name \
for label_name, count in count_by_defined_labels.items() \
labels_not_found = [label_name
for label_name, count in count_by_defined_labels.items()
if count == 0]

for label_name in labels_not_found:
Expand All @@ -453,8 +458,8 @@ def _check_label_defined_but_not_found(self, stats):

def _check_attribute_defined_but_not_found(self, label_name, attr_stats):
validation_reports = []
attrs_not_found = [attr_name \
for attr_name, attr_dets in attr_stats.items() \
attrs_not_found = [attr_name
for attr_name, attr_dets in attr_stats.items()
if len(attr_dets['distribution']) == 0]

for attr_name in attrs_not_found:
Expand All @@ -467,8 +472,8 @@ def _check_attribute_defined_but_not_found(self, label_name, attr_stats):
def _check_only_one_label(self, stats):
validation_reports = []
count_by_defined_labels = stats['label_distribution']['defined_labels']
labels_found = [label_name \
for label_name, count in count_by_defined_labels.items() \
labels_found = [label_name
for label_name, count in count_by_defined_labels.items()
if count > 0]

if len(labels_found) == 1:
Expand All @@ -488,12 +493,14 @@ def _check_only_one_attribute_value(self, label_name, attr_name, attr_dets):

return validation_reports

def _check_few_samples_in_label(self, stats, thr):
def _check_few_samples_in_label(self, stats):
validation_reports = []
thr = self.few_samples_thr

defined_label_dist = stats['label_distribution']['defined_labels']
labels_with_few_samples = [(label_name, count) \
for label_name, count in defined_label_dist.items() \
if 0 < count < thr]
labels_with_few_samples = [(label_name, count)
for label_name, count in defined_label_dist.items()
if 0 < count <= thr]

for label_name, count in labels_with_few_samples:
validation_reports += self._generate_validation_report(
Expand All @@ -502,11 +509,13 @@ def _check_few_samples_in_label(self, stats, thr):
return validation_reports

def _check_few_samples_in_attribute(self, label_name,
attr_name, attr_dets, thr):
attr_name, attr_dets):
validation_reports = []
attr_values_with_few_samples = [(attr_value, count) \
for attr_value, count in attr_dets['distribution'].items() \
if count < thr]
thr = self.few_samples_thr

attr_values_with_few_samples = [(attr_value, count)
for attr_value, count in attr_dets['distribution'].items()
if count <= thr]

for attr_value, count in attr_values_with_few_samples:
details = (label_name, attr_name, attr_value, count)
Expand All @@ -515,11 +524,12 @@ def _check_few_samples_in_attribute(self, label_name,

return validation_reports

def _check_imbalanced_labels(self, stats, thr):
def _check_imbalanced_labels(self, stats):
validation_reports = []
thr = self.imbalance_ratio_thr

defined_label_dist = stats['label_distribution']['defined_labels']
count_by_defined_labels = [count \
count_by_defined_labels = [count
for label, count in defined_label_dist.items()]

if len(count_by_defined_labels) == 0:
Expand All @@ -528,15 +538,15 @@ def _check_imbalanced_labels(self, stats, thr):
count_max = np.max(count_by_defined_labels)
count_min = np.min(count_by_defined_labels)
balance = count_max / count_min if count_min > 0 else float('inf')
if balance > thr:
if balance >= thr:
validation_reports += self._generate_validation_report(
ImbalancedLabels, Severity.warning)

return validation_reports

def _check_imbalanced_attribute(self, label_name, attr_name,
attr_dets, thr):
def _check_imbalanced_attribute(self, label_name, attr_name, attr_dets):
validation_reports = []
thr = self.imbalance_ratio_thr

count_by_defined_attr = list(attr_dets['distribution'].values())
if len(count_by_defined_attr) == 0:
Expand All @@ -545,7 +555,7 @@ def _check_imbalanced_attribute(self, label_name, attr_name,
count_max = np.max(count_by_defined_attr)
count_min = np.min(count_by_defined_attr)
balance = count_max / count_min if count_min > 0 else float('inf')
if balance > thr:
if balance >= thr:
validation_reports += self._generate_validation_report(
ImbalancedAttribute, Severity.warning, label_name, attr_name)

Expand Down Expand Up @@ -607,8 +617,8 @@ def generate_reports(self, stats):
reports += self._check_multi_label_annotations(stats)
reports += self._check_label_defined_but_not_found(stats)
reports += self._check_only_one_label(stats)
reports += self._check_few_samples_in_label(stats, 2)
reports += self._check_imbalanced_labels(stats, 5)
reports += self._check_few_samples_in_label(stats)
reports += self._check_imbalanced_labels(stats)

label_dist = stats['label_distribution']
attr_dist = stats['attribute_distribution']
Expand All @@ -625,9 +635,9 @@ def generate_reports(self, stats):

for attr_name, attr_dets in attr_stats.items():
reports += self._check_few_samples_in_attribute(
label_name, attr_name, attr_dets, 2)
label_name, attr_name, attr_dets)
reports += self._check_imbalanced_attribute(
label_name, attr_name, attr_dets, 5)
label_name, attr_name, attr_dets)
reports += self._check_only_one_attribute_value(
label_name, attr_name, attr_dets)
reports += self._check_missing_attribute(
Expand All @@ -649,45 +659,57 @@ class DetectionValidator(_Validator):
A validator class for detection tasks.
"""

DEFAULT_FAR_FROM_MEAN = 2.0
DEFAULT_FAR_FROM_MEAN = 5.0
DEFAULT_BBOX_IMBALANCE = 0.8
DEFAULT_BBOX_TOPK_BINS = 0.1

def __init__(self):
super().__init__(TaskType.detection, AnnotationType.bbox,
far_from_mean_thr=self.DEFAULT_FAR_FROM_MEAN)
self.bbox_imbalance_thr = self.DEFAULT_BBOX_IMBALANCE
self.bbox_topk_bins_ratio = self.DEFAULT_BBOX_TOPK_BINS

def _check_imbalanced_bbox_dist_in_label(self, label_name, bbox_label_stats,
thr, topk_ratio):
def _check_imbalanced_bbox_dist_in_label(self, label_name,
bbox_label_stats):
validation_reports = []
thr = self.bbox_imbalance_thr
topk_ratio = self.bbox_topk_bins_ratio

for prop, prop_stats in bbox_label_stats.items():
value_counts = prop_stats['histogram']['counts']
n_bucket = len(value_counts)
topk = int(np.around(n_bucket * topk_ratio))
if n_bucket < 2:
continue
topk = max(1, int(np.around(n_bucket * topk_ratio)))

if topk > 0:
topk_values = np.sort(value_counts)[-topk:]
ratio = np.sum(topk_values) / np.sum(value_counts)
if ratio > thr:
if ratio >= thr:
details = (label_name, prop)
validation_reports += self._generate_validation_report(
ImbalancedBboxDistInLabel, Severity.warning, *details)

return validation_reports

def _check_imbalanced_bbox_dist_in_attr(self, label_name, attr_name,
bbox_attr_stats, thr, topk_ratio):
bbox_attr_stats):
validation_reports = []
thr = self.bbox_imbalance_thr
topk_ratio = self.bbox_topk_bins_ratio

for attr_value, value_stats in bbox_attr_stats.items():
for prop, prop_stats in value_stats.items():
value_counts = prop_stats['histogram']['counts']
n_bucket = len(value_counts)
topk = int(np.around(n_bucket * topk_ratio))
if n_bucket < 2:
continue
topk = max(1, int(np.around(n_bucket * topk_ratio)))

if topk > 0:
topk_values = np.sort(value_counts)[-topk:]
ratio = np.sum(topk_values) / np.sum(value_counts)
if ratio > thr:
if ratio >= thr:
details = (label_name, attr_name, attr_value, prop)
validation_reports += self._generate_validation_report(
ImbalancedBboxDistInAttribute,
Expand Down Expand Up @@ -744,9 +766,9 @@ def _check_far_from_label_mean(self, label_name, bbox_label_stats):
if prop_stats['mean'] is not None:
mean = round(prop_stats['mean'], 2)

for item_dets, anns_far_from_mean in items_far_from_mean.items():
for item_dets, anns_far in items_far_from_mean.items():
item_id, item_subset = item_dets
for ann_id, val in anns_far_from_mean.items():
for ann_id, val in anns_far.items():
val = round(val, 2)
details = (item_subset, label_name, ann_id, prop, mean, val)
validation_reports += self._generate_validation_report(
Expand All @@ -763,9 +785,9 @@ def _check_far_from_attr_mean(self, label_name, attr_name, bbox_attr_stats):
if prop_stats['mean'] is not None:
mean = round(prop_stats['mean'], 2)

for item_dets, anns_far_from_mean in items_far_from_mean.items():
for item_dets, anns_far in items_far_from_mean.items():
item_id, item_subset = item_dets
for ann_id, val in anns_far_from_mean.items():
for ann_id, val in anns_far.items():
val = round(val, 2)
details = (item_subset, label_name, ann_id, attr_name,
attr_value, prop, mean, val)
Expand Down Expand Up @@ -798,8 +820,8 @@ def generate_reports(self, stats):
reports += self._check_missing_bbox_annotation(stats)
reports += self._check_label_defined_but_not_found(stats)
reports += self._check_only_one_label(stats)
reports += self._check_few_samples_in_label(stats, 2)
reports += self._check_imbalanced_labels(stats, 5)
reports += self._check_few_samples_in_label(stats)
reports += self._check_imbalanced_labels(stats)
reports += self._check_negative_length(stats)
reports += self._check_invalid_value(stats)

Expand All @@ -821,9 +843,9 @@ def generate_reports(self, stats):

for attr_name, attr_dets in attr_stats.items():
reports += self._check_few_samples_in_attribute(
label_name, attr_name, attr_dets, 2)
label_name, attr_name, attr_dets)
reports += self._check_imbalanced_attribute(
label_name, attr_name, attr_dets, 5)
label_name, attr_name, attr_dets)
reports += self._check_only_one_attribute_value(
label_name, attr_name, attr_dets)
reports += self._check_missing_attribute(
Expand All @@ -835,13 +857,13 @@ def generate_reports(self, stats):
reports += self._check_far_from_label_mean(
label_name, bbox_label_stats)
reports += self._check_imbalanced_bbox_dist_in_label(
label_name, bbox_label_stats, 1, 0.25)
label_name, bbox_label_stats)

for attr_name, bbox_attr_stats in bbox_attr_label.items():
reports += self._check_far_from_attr_mean(
label_name, attr_name, bbox_attr_stats)
reports += self._check_imbalanced_bbox_dist_in_attr(
label_name, attr_name, bbox_attr_stats, 1, 0.25)
label_name, attr_name, bbox_attr_stats)

for label_name, label_stats in undefined_label_dist.items():
reports += self._check_undefined_label(label_name, label_stats)
Expand Down Expand Up @@ -889,11 +911,11 @@ def validate_annotations(dataset: IDataset, task_type: Union[str, TaskType]):

# generate validation reports and summary
reports = validator.generate_reports(stats)
reports = list(map(lambda r : r.to_dict(), reports))
reports = list(map(lambda r: r.to_dict(), reports))

summary = {
'errors': sum(map(lambda r : r['severity'] == 'error', reports)),
'warnings': sum(map(lambda r : r['severity'] == 'warning', reports))
'errors': sum(map(lambda r: r['severity'] == 'error', reports)),
'warnings': sum(map(lambda r: r['severity'] == 'warning', reports))
}

validation_results['validation_reports'] = reports
Expand Down
Loading

0 comments on commit 2d9c954

Please sign in to comment.