Skip to content

Commit

Permalink
[Datumaro] Mean and std for dataset (cvat-ai#1734)
Browse files Browse the repository at this point in the history
* Add meanstd

* Add stats cli

* Update changelog

Co-authored-by: Nikita Manovich <[email protected]>
  • Loading branch information
2 people authored and Fernando Martínez González committed Aug 3, 2020
1 parent de6fd07 commit 7b99c6b
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Built-in search for labels when creating an object or changing a label (<https://github.com/opencv/cvat/pull/1683>)
- Better validation of labels and attributes in raw viewer (<https://github.com/opencv/cvat/pull/1727>)
- ClamAV antivirus integration (<https://github.com/opencv/cvat/pull/1712>)
- [Datumaro] Added `stats` command, which shows some dataset statistics like image mean and std (<https://github.com/opencv/cvat/pull/1734>)
- Add option to upload annotations upon task creation on CLI
- Polygon and polylines interpolation (<https://github.com/opencv/cvat/pull/1571>)
- Ability to redraw shape from scratch (Shift + N) for an activated shape (<https://github.com/opencv/cvat/pull/1571>)
Expand Down
34 changes: 34 additions & 0 deletions datumaro/datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from datumaro.components.dataset_filter import DatasetItemEncoder
from datumaro.components.extractor import AnnotationType
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.operations import mean_std
from .diff import DiffVisualizer
from ...util import add_subparser, CliException, MultilineFormatter, \
make_file_name
Expand Down Expand Up @@ -623,6 +624,38 @@ def transform_command(args):

return 0

def build_stats_parser(parser_ctor=argparse.ArgumentParser):
    """
    Builds the argument parser of the "stats" subcommand.

    parser_ctor - callable used to create the parser object
        (argparse.ArgumentParser by default)

    Returns the configured parser; its "command" default is stats_command.
    """
    stats_parser = parser_ctor(
        help="Get project statistics",
        description="""
        Outputs project statistics.
        """,
        formatter_class=MultilineFormatter,
    )

    stats_parser.add_argument(
        '-p', '--project',
        dest='project_dir',
        default='.',
        help="Directory of the project to operate on (default: current dir)",
    )
    # The selected subcommand is dispatched through the "command" default
    stats_parser.set_defaults(command=stats_command)

    return stats_parser

def stats_command(args):
    """
    Prints image mean and std. dev. for the project dataset and,
    when there is more than one subset, for each subset separately.

    args - parsed command line arguments, 'project_dir' is read from it

    Returns 0 on success.
    """
    project = load_project(args.project_dir)
    dataset = project.make_dataset()

    def print_extractor_info(extractor, indent=''):
        # Statistics must be computed for the extractor passed in,
        # not for the whole dataset, otherwise every subset reports
        # the same (whole dataset) values
        mean, std = mean_std(extractor)
        print("%sImage mean:" % indent, ', '.join('%.3f' % n for n in mean))
        print("%sImage std:" % indent, ', '.join('%.3f' % n for n in std))

    print("Dataset: ")
    print_extractor_info(dataset)

    subset_names = dataset.subsets()
    if len(subset_names) > 1:
        print("Subsets: ")
        for subset_name in subset_names:
            subset = dataset.get_subset(subset_name)
            print("  %s:" % subset_name)
            print_extractor_info(subset, " " * 4)

    return 0

def build_info_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Get project info",
description="""
Expand Down Expand Up @@ -718,5 +751,6 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
add_subparser(subparsers, 'diff', build_diff_parser)
add_subparser(subparsers, 'transform', build_transform_parser)
add_subparser(subparsers, 'info', build_info_parser)
add_subparser(subparsers, 'stats', build_stats_parser)

return parser
82 changes: 82 additions & 0 deletions datumaro/datumaro/components/operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@

# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

import cv2
import numpy as np


def mean_std(dataset):
    """
    Computes unbiased mean and std. dev. for dataset images, channel-wise.

    dataset - a sized iterable of items, each providing 'image.size' and
        'image.data' (assumes 'image.size' holds the image height and
        width, so that their product is the pixel count - TODO confirm)

    Returns a tuple (mean, std) of arrays with 3 elements each, one per
    channel, in the same value scale as the input images.
    For an empty dataset both arrays are filled with zeros.
    """
    # Use an online algorithm to:
    # - handle different image sizes
    # - avoid cancellation problem

    total = len(dataset)
    if total == 0:
        # Nothing to aggregate, and the recursive merge needs >= 1 element
        return np.zeros(3), np.zeros(3)

    stats = np.empty((total, 2, 3), dtype=np.double)
    # 64 bits: the total pixel count of a dataset easily exceeds 2^32
    counts = np.empty(total, dtype=np.int64)

    mean_of = lambda i, s: s[i][0]
    var_of = lambda i, s: s[i][1]

    for i, item in enumerate(dataset):
        counts[i] = np.prod(item.image.size)

        image = item.image.data
        if len(image.shape) == 2:
            image = image[:, :, np.newaxis]
        else:
            image = image[:, :, :3]
        # opencv is much faster than numpy here
        image_mean, image_std = cv2.meanStdDev(image.astype(np.double) / 255)
        # Take the returned arrays instead of passing output buffers:
        # for a single-channel image the result has 1 element and is
        # broadcast to all 3 channels here, while a preallocated
        # 3-element buffer would be left unfilled
        stats[i, 0] = np.ravel(image_mean)
        stats[i, 1] = np.ravel(image_std)

    # make variance unbiased
    # A single-pixel image has no spread, keep its variance at 0
    # instead of dividing by zero
    correction = np.zeros(total, dtype=np.double)
    np.divide(counts, counts - 1, out=correction, where=counts > 1)
    stats[:, 1] = np.square(stats[:, 1]) * correction[:, np.newaxis]

    _, mean, var = StatsCounter.compute_stats(stats, counts, mean_of, var_of)
    return mean * 255, np.sqrt(var) * 255

class StatsCounter:
    """
    Implements online parallel computation of sample mean and variance
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    Needed to avoid catastrophic cancellation in floating point computations
    """

    @staticmethod
    def pairwise_stats(count_a, mean_a, var_a, count_b, mean_b, var_b):
        """
        Merges statistics of two sample sets.

        count_a, count_b - numbers of samples in the sets
        mean_a, mean_b - sample means of the sets
        var_a, var_b - unbiased sample variances of the sets

        Returns a tuple (count, mean, variance) for the united set.
        """
        # Python integers do not overflow, unlike fixed-width numpy scalars
        count_a = int(count_a)
        count_b = int(count_b)

        # An empty set contributes nothing to the result
        if count_a == 0:
            return count_b, mean_b, var_b
        if count_b == 0:
            return count_a, mean_a, var_a

        total = count_a + count_b
        delta = mean_b - mean_a
        m_a = var_a * (count_a - 1)
        m_b = var_b * (count_b - 1)
        M2 = m_a + m_b + delta ** 2 * count_a * count_b / total
        return (
            total,
            # The mean must be weighted by the set sizes, a plain average
            # is only correct when both sets have the same size
            mean_a + delta * (count_b / total),
            M2 / (total - 1)
        )

    @staticmethod
    def compute_stats(stats, counts, mean_accessor, variance_accessor):
        """
        Recursively computes total count, mean and variance.

        stats - float array of shape N, 2 * d, d = dimensions of values
        counts - integer array of shape N
        mean_accessor - function(idx, stats) to retrieve element mean
        variance_accessor - function(idx, stats) to retrieve element variance

        The input is halved on each step, so the recursion depth
        is O(log(N)).

        Returns a tuple (count, mean, variance).
        Raises ValueError if there are no elements.
        """
        m = mean_accessor
        v = variance_accessor
        n = len(stats)
        if n == 0:
            # An empty slice can never shrink, so recursion would not stop
            raise ValueError("Can't compute statistics of an empty sequence")
        if n == 1:
            return counts[0], m(0, stats), v(0, stats)
        if n == 2:
            return StatsCounter.pairwise_stats(
                counts[0], m(0, stats), v(0, stats),
                counts[1], m(1, stats), v(1, stats)
            )
        h = n // 2
        return StatsCounter.pairwise_stats(
            *StatsCounter.compute_stats(stats[:h], counts[:h], m, v),
            *StatsCounter.compute_stats(stats[h:], counts[h:], m, v)
        )
31 changes: 31 additions & 0 deletions datumaro/tests/test_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np

from datumaro.components.extractor import Extractor, DatasetItem
from datumaro.components.operations import mean_std

from unittest import TestCase


class TestOperations(TestCase):
    """Tests for the dataset-level operations."""

    def test_mean_std(self):
        # Images of different sizes are sampled from the same per-channel
        # normal distribution, so the computed dataset statistics are
        # expected to be close to the distribution parameters
        expected_mean = [100, 50, 150]
        expected_std = [20, 50, 10]

        class TestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    # Each item gets its own id, taken from its position
                    DatasetItem(id=i, image=np.random.normal(
                        expected_mean, expected_std,
                        size=(w, h, 3))
                    )
                    for i, (w, h) in enumerate([
                        (3000, 100), (800, 600), (400, 200), (700, 300)
                    ])
                ])

        actual_mean, actual_std = mean_std(TestExtractor())

        for em, am in zip(expected_mean, actual_mean):
            self.assertAlmostEqual(em, am, places=0)
        for estd, astd in zip(expected_std, actual_std):
            self.assertAlmostEqual(estd, astd, places=0)

0 comments on commit 7b99c6b

Please sign in to comment.