Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kate/validator imbalance thr #190

Merged
merged 2 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 69 additions & 47 deletions datumaro/components/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@


class _Validator:
DEFAULT_FEW_SAMPLES = 1
DEFAULT_IMBALANCE_RATIO = 50
"""
A base class for task-specific validators.

Expand Down Expand Up @@ -57,7 +59,10 @@ def __init__(self, task_type=None, ann_type=None, far_from_mean_thr=None):

self.task_type = task_type
self.ann_type = ann_type

self.far_from_mean_thr = far_from_mean_thr
self.imbalance_ratio_thr = self.DEFAULT_IMBALANCE_RATIO
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider moving this threshold into ClassificationValidator if it is not supposed to be used anywhere else.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is shared by both ClassificationValidator and DetectionValidator, so I think the base class is the right place for it.

self.few_samples_thr = self.DEFAULT_FEW_SAMPLES

def compute_statistics(self, dataset):
"""
Expand Down Expand Up @@ -300,7 +305,7 @@ def _update_props_far_from_mean(item, ann):
defined_label_dist[category.name] = 0

for item in dataset:
ann_count = [ann.type == self.ann_type \
ann_count = [ann.type == self.ann_type
for ann in item.annotations].count(True)

if self.task_type == TaskType.classification:
Expand Down Expand Up @@ -371,7 +376,7 @@ def _update_props_far_from_mean(item, ann):
attr_dets = defined_attr_stats[attr]

if self.task_type == TaskType.detection and \
ann.type == self.ann_type:
ann.type == self.ann_type:
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
bbox_attr_label = bbox_dist_by_attr.setdefault(
label_name, {})
bbox_attr_stats = bbox_attr_label.setdefault(
Expand Down Expand Up @@ -441,8 +446,8 @@ def _check_undefined_attribute(self, label_name, attr_name, attr_dets):
def _check_label_defined_but_not_found(self, stats):
validation_reports = []
count_by_defined_labels = stats['label_distribution']['defined_labels']
labels_not_found = [label_name \
for label_name, count in count_by_defined_labels.items() \
labels_not_found = [label_name
for label_name, count in count_by_defined_labels.items()
if count == 0]

for label_name in labels_not_found:
Expand All @@ -453,8 +458,8 @@ def _check_label_defined_but_not_found(self, stats):

def _check_attribute_defined_but_not_found(self, label_name, attr_stats):
validation_reports = []
attrs_not_found = [attr_name \
for attr_name, attr_dets in attr_stats.items() \
attrs_not_found = [attr_name
for attr_name, attr_dets in attr_stats.items()
if len(attr_dets['distribution']) == 0]

for attr_name in attrs_not_found:
Expand All @@ -467,8 +472,8 @@ def _check_attribute_defined_but_not_found(self, label_name, attr_stats):
def _check_only_one_label(self, stats):
validation_reports = []
count_by_defined_labels = stats['label_distribution']['defined_labels']
labels_found = [label_name \
for label_name, count in count_by_defined_labels.items() \
labels_found = [label_name
for label_name, count in count_by_defined_labels.items()
if count > 0]

if len(labels_found) == 1:
Expand All @@ -488,12 +493,14 @@ def _check_only_one_attribute_value(self, label_name, attr_name, attr_dets):

return validation_reports

def _check_few_samples_in_label(self, stats, thr):
def _check_few_samples_in_label(self, stats):
validation_reports = []
thr = self.few_samples_thr

defined_label_dist = stats['label_distribution']['defined_labels']
labels_with_few_samples = [(label_name, count) \
for label_name, count in defined_label_dist.items() \
if 0 < count < thr]
labels_with_few_samples = [(label_name, count)
for label_name, count in defined_label_dist.items()
if 0 < count <= thr]

for label_name, count in labels_with_few_samples:
validation_reports += self._generate_validation_report(
Expand All @@ -502,11 +509,13 @@ def _check_few_samples_in_label(self, stats, thr):
return validation_reports

def _check_few_samples_in_attribute(self, label_name,
attr_name, attr_dets, thr):
attr_name, attr_dets):
validation_reports = []
attr_values_with_few_samples = [(attr_value, count) \
for attr_value, count in attr_dets['distribution'].items() \
if count < thr]
thr = self.few_samples_thr

attr_values_with_few_samples = [(attr_value, count)
for attr_value, count in attr_dets['distribution'].items()
if count <= thr]

for attr_value, count in attr_values_with_few_samples:
details = (label_name, attr_name, attr_value, count)
Expand All @@ -515,11 +524,12 @@ def _check_few_samples_in_attribute(self, label_name,

return validation_reports

def _check_imbalanced_labels(self, stats, thr):
def _check_imbalanced_labels(self, stats):
validation_reports = []
thr = self.imbalance_ratio_thr

defined_label_dist = stats['label_distribution']['defined_labels']
count_by_defined_labels = [count \
count_by_defined_labels = [count
for label, count in defined_label_dist.items()]

if len(count_by_defined_labels) == 0:
Expand All @@ -528,15 +538,15 @@ def _check_imbalanced_labels(self, stats, thr):
count_max = np.max(count_by_defined_labels)
count_min = np.min(count_by_defined_labels)
balance = count_max / count_min if count_min > 0 else float('inf')
if balance > thr:
if balance >= thr:
validation_reports += self._generate_validation_report(
ImbalancedLabels, Severity.warning)

return validation_reports

def _check_imbalanced_attribute(self, label_name, attr_name,
attr_dets, thr):
def _check_imbalanced_attribute(self, label_name, attr_name, attr_dets):
validation_reports = []
thr = self.imbalance_ratio_thr

count_by_defined_attr = list(attr_dets['distribution'].values())
if len(count_by_defined_attr) == 0:
Expand All @@ -545,7 +555,7 @@ def _check_imbalanced_attribute(self, label_name, attr_name,
count_max = np.max(count_by_defined_attr)
count_min = np.min(count_by_defined_attr)
balance = count_max / count_min if count_min > 0 else float('inf')
if balance > thr:
if balance >= thr:
validation_reports += self._generate_validation_report(
ImbalancedAttribute, Severity.warning, label_name, attr_name)

Expand Down Expand Up @@ -607,8 +617,8 @@ def generate_reports(self, stats):
reports += self._check_multi_label_annotations(stats)
reports += self._check_label_defined_but_not_found(stats)
reports += self._check_only_one_label(stats)
reports += self._check_few_samples_in_label(stats, 2)
reports += self._check_imbalanced_labels(stats, 5)
reports += self._check_few_samples_in_label(stats)
reports += self._check_imbalanced_labels(stats)

label_dist = stats['label_distribution']
attr_dist = stats['attribute_distribution']
Expand All @@ -625,9 +635,9 @@ def generate_reports(self, stats):

for attr_name, attr_dets in attr_stats.items():
reports += self._check_few_samples_in_attribute(
label_name, attr_name, attr_dets, 2)
label_name, attr_name, attr_dets)
reports += self._check_imbalanced_attribute(
label_name, attr_name, attr_dets, 5)
label_name, attr_name, attr_dets)
reports += self._check_only_one_attribute_value(
label_name, attr_name, attr_dets)
reports += self._check_missing_attribute(
Expand All @@ -649,45 +659,57 @@ class DetectionValidator(_Validator):
A validator class for detection tasks.
"""

DEFAULT_FAR_FROM_MEAN = 2.0
DEFAULT_FAR_FROM_MEAN = 5.0
DEFAULT_BBOX_IMBALANCE = 0.8
DEFAULT_BBOX_TOPK_BINS = 0.1

def __init__(self):
super().__init__(TaskType.detection, AnnotationType.bbox,
far_from_mean_thr=self.DEFAULT_FAR_FROM_MEAN)
self.bbox_imbalance_thr = self.DEFAULT_BBOX_IMBALANCE
self.bbox_topk_bins_ratio = self.DEFAULT_BBOX_TOPK_BINS

def _check_imbalanced_bbox_dist_in_label(self, label_name, bbox_label_stats,
thr, topk_ratio):
def _check_imbalanced_bbox_dist_in_label(self, label_name,
bbox_label_stats):
validation_reports = []
thr = self.bbox_imbalance_thr
topk_ratio = self.bbox_topk_bins_ratio

for prop, prop_stats in bbox_label_stats.items():
value_counts = prop_stats['histogram']['counts']
n_bucket = len(value_counts)
topk = int(np.around(n_bucket * topk_ratio))
if n_bucket < 2:
continue
topk = max(1, int(np.around(n_bucket * topk_ratio)))

if topk > 0:
topk_values = np.sort(value_counts)[-topk:]
ratio = np.sum(topk_values) / np.sum(value_counts)
if ratio > thr:
if ratio >= thr:
details = (label_name, prop)
validation_reports += self._generate_validation_report(
ImbalancedBboxDistInLabel, Severity.warning, *details)

return validation_reports

def _check_imbalanced_bbox_dist_in_attr(self, label_name, attr_name,
bbox_attr_stats, thr, topk_ratio):
bbox_attr_stats):
validation_reports = []
thr = self.bbox_imbalance_thr
topk_ratio = self.bbox_topk_bins_ratio

for attr_value, value_stats in bbox_attr_stats.items():
for prop, prop_stats in value_stats.items():
value_counts = prop_stats['histogram']['counts']
n_bucket = len(value_counts)
topk = int(np.around(n_bucket * topk_ratio))
if n_bucket < 2:
continue
topk = max(1, int(np.around(n_bucket * topk_ratio)))

if topk > 0:
topk_values = np.sort(value_counts)[-topk:]
ratio = np.sum(topk_values) / np.sum(value_counts)
if ratio > thr:
if ratio >= thr:
details = (label_name, attr_name, attr_value, prop)
validation_reports += self._generate_validation_report(
ImbalancedBboxDistInAttribute,
Expand Down Expand Up @@ -744,9 +766,9 @@ def _check_far_from_label_mean(self, label_name, bbox_label_stats):
if prop_stats['mean'] is not None:
mean = round(prop_stats['mean'], 2)

for item_dets, anns_far_from_mean in items_far_from_mean.items():
for item_dets, anns_far in items_far_from_mean.items():
item_id, item_subset = item_dets
for ann_id, val in anns_far_from_mean.items():
for ann_id, val in anns_far.items():
val = round(val, 2)
details = (item_subset, label_name, ann_id, prop, mean, val)
validation_reports += self._generate_validation_report(
Expand All @@ -763,9 +785,9 @@ def _check_far_from_attr_mean(self, label_name, attr_name, bbox_attr_stats):
if prop_stats['mean'] is not None:
mean = round(prop_stats['mean'], 2)

for item_dets, anns_far_from_mean in items_far_from_mean.items():
for item_dets, anns_far in items_far_from_mean.items():
item_id, item_subset = item_dets
for ann_id, val in anns_far_from_mean.items():
for ann_id, val in anns_far.items():
val = round(val, 2)
details = (item_subset, label_name, ann_id, attr_name,
attr_value, prop, mean, val)
Expand Down Expand Up @@ -798,8 +820,8 @@ def generate_reports(self, stats):
reports += self._check_missing_bbox_annotation(stats)
reports += self._check_label_defined_but_not_found(stats)
reports += self._check_only_one_label(stats)
reports += self._check_few_samples_in_label(stats, 2)
reports += self._check_imbalanced_labels(stats, 5)
reports += self._check_few_samples_in_label(stats)
reports += self._check_imbalanced_labels(stats)
reports += self._check_negative_length(stats)
reports += self._check_invalid_value(stats)

Expand All @@ -821,9 +843,9 @@ def generate_reports(self, stats):

for attr_name, attr_dets in attr_stats.items():
reports += self._check_few_samples_in_attribute(
label_name, attr_name, attr_dets, 2)
label_name, attr_name, attr_dets)
reports += self._check_imbalanced_attribute(
label_name, attr_name, attr_dets, 5)
label_name, attr_name, attr_dets)
reports += self._check_only_one_attribute_value(
label_name, attr_name, attr_dets)
reports += self._check_missing_attribute(
Expand All @@ -835,13 +857,13 @@ def generate_reports(self, stats):
reports += self._check_far_from_label_mean(
label_name, bbox_label_stats)
reports += self._check_imbalanced_bbox_dist_in_label(
label_name, bbox_label_stats, 1, 0.25)
label_name, bbox_label_stats)

for attr_name, bbox_attr_stats in bbox_attr_label.items():
reports += self._check_far_from_attr_mean(
label_name, attr_name, bbox_attr_stats)
reports += self._check_imbalanced_bbox_dist_in_attr(
label_name, attr_name, bbox_attr_stats, 1, 0.25)
label_name, attr_name, bbox_attr_stats)

for label_name, label_stats in undefined_label_dist.items():
reports += self._check_undefined_label(label_name, label_stats)
Expand Down Expand Up @@ -889,11 +911,11 @@ def validate_annotations(dataset: IDataset, task_type: Union[str, TaskType]):

# generate validation reports and summary
reports = validator.generate_reports(stats)
reports = list(map(lambda r : r.to_dict(), reports))
reports = list(map(lambda r: r.to_dict(), reports))

summary = {
'errors': sum(map(lambda r : r['severity'] == 'error', reports)),
'warnings': sum(map(lambda r : r['severity'] == 'warning', reports))
'errors': sum(map(lambda r: r['severity'] == 'error', reports)),
'warnings': sum(map(lambda r: r['severity'] == 'warning', reports))
}

validation_results['validation_reports'] = reports
Expand Down
Loading