Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New functionality for counting TFRecord examples #144

Merged
merged 5 commits into from
Feb 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,30 @@ $ cvdata_mask --images /data/images --masks /data/masks \
> --tfrecords /data/tfrecords \
> --shards 4 -- train_pct 0.8
```
## Dataset statistics
Basic statistics about a dataset are available via the script `cvdata/analyze.py`
or the corresponding script entry point `cvdata_analyze`.

For example, we can count the number of examples in a collection of TFRecord files
(specify a directory containing only TFRecod files):
```bash
$ cvdata_analyze --format tfrecord --annotations /data/animals/tfrecord
Total number of examples: 100
```
The above functionality can be utilized within Python code like so:
```python
from cvdata.analyze import count_tfrecord_examples
tfrecords_dir = "/data/animals/tfrecord"
number_of_examples = count_tfrecord_examples(tfrecords_dir)
print(f"Number of examples: {number_of_examples}")
```
For datasets containing annotation files in COCO, Darknet (YOLO), KITTI, or PASCAL
formats we can get the number of images per class label. For example:
```bash
$ cvdata_analyze --format kitti --annotations /data/scissors/kitti --images /data/scissors/images
Label: scissors Count: 100
```

## Visualize annotations
In order to visualize images and corresponding annotations use the script
`cvdata/visualize.py` or the corresponding script entry point `cvdata_visualize`.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name="cvdata",
version="0.0.6",
version="0.0.7",
author="James Adams",
author_email="[email protected]",
description="Tools for creating and manipulating computer vision datasets",
Expand Down
2 changes: 2 additions & 0 deletions src/cvdata/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# version of the cvdata package, should match with setup.py
__version__ = "0.0.7"
122 changes: 75 additions & 47 deletions src/cvdata/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xml.etree import ElementTree

import pandas as pd
import tensorflow as tf

import cvdata.common
from cvdata.utils import matching_ids
Expand Down Expand Up @@ -138,6 +139,26 @@ def count_labels(
raise ValueError(f"Unsupported annotation format: \"{annotation_format}\"")


# ------------------------------------------------------------------------------
def count_tfrecord_examples(
tfrecords_dir: str,
) -> int:
"""
Counts the total number of examples in a collection of TFRecord files.

:param tfrecords_dir: directory that is assumed to contain only TFRecord files
:return: the total number of examples in the collection of TFRecord files
found in the specified directory
"""

count = 0
for file_name in os.listdir(tfrecords_dir):
tfrecord_path = os.path.join(tfrecords_dir, file_name)
count += sum(1 for _ in tf.data.TFRecordDataset(tfrecord_path))

return count


# ------------------------------------------------------------------------------
def main():

Expand All @@ -151,7 +172,7 @@ def main():
)
args_parser.add_argument(
"--images",
required=True,
required=False,
type=str,
help="images directory path",
)
Expand All @@ -172,64 +193,71 @@ def main():
)
args = vars(args_parser.parse_args())

# the two dictionaries we'll build for final reporting
label_counts = {}
label_file_ids = {}
if args["format"] == "tfrecord":

if args["format"] == "openimages":
# count and report the examples in the collection of TFRecord files
examples_count = count_tfrecord_examples(args["annotations"])
print(f"Total number of examples: {examples_count}")

# read the OpenImages CSV into a pandas DataFrame
df_annotations = pd.read_csv(args["annotations"])
df_annotations = df_annotations[['ImageID', 'LabelName']]
else:
# the two dictionaries we'll build for final reporting
label_counts = {}
label_file_ids = {}

# TODO get another dataframe from the class descriptions and get the
# readable label names from there to map to the LabelName column
if args["format"] == "openimages":

# whittle it down to only the rows that match to image IDs
file_ids = [os.path.splitext(file_name)[0] for file_name in os.listdir(args["images"])]
df_annotations = df_annotations[df_annotations["ImageID"].isin(file_ids)]
# read the OpenImages CSV into a pandas DataFrame
df_annotations = pd.read_csv(args["annotations"])
df_annotations = df_annotations[['ImageID', 'LabelName']]

# TODO populate the label counts and label file IDs dictionaries
# TODO get another dataframe from the class descriptions and get the
# readable label names from there to map to the LabelName column

else:
# whittle it down to only the rows that match to image IDs
file_ids = [os.path.splitext(file_name)[0] for file_name in os.listdir(args["images"])]
df_annotations = df_annotations[df_annotations["ImageID"].isin(file_ids)]

# TODO populate the label counts and label file IDs dictionaries

annotation_ext = cvdata.common.FORMAT_EXTENSIONS[args["format"]]
else:

# only annotations matching to the images are considered to be valid
file_ids = matching_ids(args["annotations"], args["images"], annotation_ext, ".jpg")
annotation_ext = cvdata.common.FORMAT_EXTENSIONS[args["format"]]

for file_id in file_ids:
annotation_file_path = \
os.path.join(args["annotations"], file_id + annotation_ext)
# only annotations matching to the images are considered to be valid
file_ids = matching_ids(args["annotations"], args["images"], annotation_ext, ".jpg")

# get the images per label count
for label, count in count_labels(annotation_file_path, args["format"]).items():
if label in label_counts:
label_counts[label] += 1
else:
label_counts[label] = 1
for file_id in file_ids:
annotation_file_path = \
os.path.join(args["annotations"], file_id + annotation_ext)

# for each label found in the annotation file add this file ID
# to the set of file IDs corresponding to the label
if args["file_ids"]:
if label in label_file_ids:
# add this file ID to the existing set for the label
label_file_ids[label].add(file_id)
# get the images per label count
for label, count in count_labels(annotation_file_path, args["format"]).items():
if label in label_counts:
label_counts[label] += 1
else:
# first file ID seen for this label so create new set
label_file_ids[label] = {file_id}

# write the images per label counts
for label, count in label_counts.items():
print(f"Label: {label}\t\tCount: {count}")

# write the label ID files, if requested
if args["file_ids"]:
for label, file_ids_for_label in label_file_ids.items():
label_file_ids_path = os.path.join(args["file_ids"], label + ".txt")
with open(label_file_ids_path, "w") as label_file_ids_file:
for file_id in file_ids_for_label:
label_file_ids_file.write(f"{file_id}\n")
label_counts[label] = 1

# for each label found in the annotation file add this file ID
# to the set of file IDs corresponding to the label
if args["file_ids"]:
if label in label_file_ids:
# add this file ID to the existing set for the label
label_file_ids[label].add(file_id)
else:
# first file ID seen for this label so create new set
label_file_ids[label] = {file_id}

# write the images per label counts
for label, count in label_counts.items():
print(f"Label: {label}\t\tCount: {count}")

# write the label ID files, if requested
if args["file_ids"]:
for label, file_ids_for_label in label_file_ids.items():
label_file_ids_path = os.path.join(args["file_ids"], label + ".txt")
with open(label_file_ids_path, "w") as label_file_ids_file:
for file_id in file_ids_for_label:
label_file_ids_file.write(f"{file_id}\n")


# ------------------------------------------------------------------------------
Expand Down
17 changes: 17 additions & 0 deletions tests/test_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,20 @@ def test_count_labels(
assert label_counts["1"] == 2
assert label_counts["2"] == 1
assert label_counts["3"] == 1


# ------------------------------------------------------------------------------
@pytest.mark.usefixtures(
"data_dir",
)
def test_count_tfrecord_examples(
data_dir,
):
"""
Test for the cvdata.analyze.count_tfrecord_examples() function

:param data_dir: temporary directory into which test files will be loaded
"""
tfrecord_dir = os.path.join(str(data_dir), "tfrecord")
example_count = analyze.count_tfrecord_examples(tfrecord_dir)
assert example_count == 100
Binary file not shown.