Skip to content

Commit

Permalink
Merge pull request #144 from monocongo/issue_143_count_tfrecord_examples
Browse files Browse the repository at this point in the history
New functionality for counting TFRecord examples
  • Loading branch information
monocongo authored Feb 7, 2020
2 parents 661ac1f + 6f39dee commit 1be3c0d
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 48 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,30 @@ $ cvdata_mask --images /data/images --masks /data/masks \
> --tfrecords /data/tfrecords \
> --shards 4 --train_pct 0.8
```
## Dataset statistics
Basic statistics about a dataset are available via the script `cvdata/analyze.py`
or the corresponding script entry point `cvdata_analyze`.

For example, we can count the number of examples in a collection of TFRecord files
(specify a directory containing only TFRecord files):
```bash
$ cvdata_analyze --format tfrecord --annotations /data/animals/tfrecord
Total number of examples: 100
```
The above functionality can be utilized within Python code like so:
```python
from cvdata.analyze import count_tfrecord_examples
tfrecords_dir = "/data/animals/tfrecord"
number_of_examples = count_tfrecord_examples(tfrecords_dir)
print(f"Number of examples: {number_of_examples}")
```
For datasets containing annotation files in COCO, Darknet (YOLO), KITTI, or PASCAL
formats we can get the number of images per class label. For example:
```bash
$ cvdata_analyze --format kitti --annotations /data/scissors/kitti --images /data/scissors/images
Label: scissors Count: 100
```

## Visualize annotations
In order to visualize images and corresponding annotations use the script
`cvdata/visualize.py` or the corresponding script entry point `cvdata_visualize`.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name="cvdata",
version="0.0.6",
version="0.0.7",
author="James Adams",
author_email="[email protected]",
description="Tools for creating and manipulating computer vision datasets",
Expand Down
2 changes: 2 additions & 0 deletions src/cvdata/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# version of the cvdata package, should match with setup.py
__version__ = "0.0.7"
122 changes: 75 additions & 47 deletions src/cvdata/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xml.etree import ElementTree

import pandas as pd
import tensorflow as tf

import cvdata.common
from cvdata.utils import matching_ids
Expand Down Expand Up @@ -138,6 +139,26 @@ def count_labels(
raise ValueError(f"Unsupported annotation format: \"{annotation_format}\"")


# ------------------------------------------------------------------------------
def count_tfrecord_examples(
        tfrecords_dir: str,
) -> int:
    """
    Counts the total number of examples in a collection of TFRecord files.

    Every regular file found directly inside the directory is read as a
    TFRecord file. Subdirectories are skipped rather than being handed to
    the TFRecord reader.

    :param tfrecords_dir: directory that is assumed to contain only TFRecord files
    :return: the total number of examples in the collection of TFRecord files
        found in the specified directory
    """

    count = 0
    for file_name in os.listdir(tfrecords_dir):
        tfrecord_path = os.path.join(tfrecords_dir, file_name)

        # a nested directory is not a TFRecord file, so don't try to read it
        if not os.path.isfile(tfrecord_path):
            continue

        # walk the dataset one record at a time so that the records are
        # counted without being collected into a list
        count += sum(1 for _ in tf.data.TFRecordDataset(tfrecord_path))

    return count


# ------------------------------------------------------------------------------
def main():

Expand All @@ -151,7 +172,7 @@ def main():
)
args_parser.add_argument(
"--images",
required=True,
required=False,
type=str,
help="images directory path",
)
Expand All @@ -172,64 +193,71 @@ def main():
)
args = vars(args_parser.parse_args())

# the two dictionaries we'll build for final reporting
label_counts = {}
label_file_ids = {}
if args["format"] == "tfrecord":

if args["format"] == "openimages":
# count and report the examples in the collection of TFRecord files
examples_count = count_tfrecord_examples(args["annotations"])
print(f"Total number of examples: {examples_count}")

# read the OpenImages CSV into a pandas DataFrame
df_annotations = pd.read_csv(args["annotations"])
df_annotations = df_annotations[['ImageID', 'LabelName']]
else:
# the two dictionaries we'll build for final reporting
label_counts = {}
label_file_ids = {}

# TODO get another dataframe from the class descriptions and get the
# readable label names from there to map to the LabelName column
if args["format"] == "openimages":

# whittle it down to only the rows that match to image IDs
file_ids = [os.path.splitext(file_name)[0] for file_name in os.listdir(args["images"])]
df_annotations = df_annotations[df_annotations["ImageID"].isin(file_ids)]
# read the OpenImages CSV into a pandas DataFrame
df_annotations = pd.read_csv(args["annotations"])
df_annotations = df_annotations[['ImageID', 'LabelName']]

# TODO populate the label counts and label file IDs dictionaries
# TODO get another dataframe from the class descriptions and get the
# readable label names from there to map to the LabelName column

else:
# whittle it down to only the rows that match to image IDs
file_ids = [os.path.splitext(file_name)[0] for file_name in os.listdir(args["images"])]
df_annotations = df_annotations[df_annotations["ImageID"].isin(file_ids)]

# TODO populate the label counts and label file IDs dictionaries

annotation_ext = cvdata.common.FORMAT_EXTENSIONS[args["format"]]
else:

# only annotations matching to the images are considered to be valid
file_ids = matching_ids(args["annotations"], args["images"], annotation_ext, ".jpg")
annotation_ext = cvdata.common.FORMAT_EXTENSIONS[args["format"]]

for file_id in file_ids:
annotation_file_path = \
os.path.join(args["annotations"], file_id + annotation_ext)
# only annotations matching to the images are considered to be valid
file_ids = matching_ids(args["annotations"], args["images"], annotation_ext, ".jpg")

# get the images per label count
for label, count in count_labels(annotation_file_path, args["format"]).items():
if label in label_counts:
label_counts[label] += 1
else:
label_counts[label] = 1
for file_id in file_ids:
annotation_file_path = \
os.path.join(args["annotations"], file_id + annotation_ext)

# for each label found in the annotation file add this file ID
# to the set of file IDs corresponding to the label
if args["file_ids"]:
if label in label_file_ids:
# add this file ID to the existing set for the label
label_file_ids[label].add(file_id)
# get the images per label count
for label, count in count_labels(annotation_file_path, args["format"]).items():
if label in label_counts:
label_counts[label] += 1
else:
# first file ID seen for this label so create new set
label_file_ids[label] = {file_id}

# write the images per label counts
for label, count in label_counts.items():
print(f"Label: {label}\t\tCount: {count}")

# write the label ID files, if requested
if args["file_ids"]:
for label, file_ids_for_label in label_file_ids.items():
label_file_ids_path = os.path.join(args["file_ids"], label + ".txt")
with open(label_file_ids_path, "w") as label_file_ids_file:
for file_id in file_ids_for_label:
label_file_ids_file.write(f"{file_id}\n")
label_counts[label] = 1

# for each label found in the annotation file add this file ID
# to the set of file IDs corresponding to the label
if args["file_ids"]:
if label in label_file_ids:
# add this file ID to the existing set for the label
label_file_ids[label].add(file_id)
else:
# first file ID seen for this label so create new set
label_file_ids[label] = {file_id}

# write the images per label counts
for label, count in label_counts.items():
print(f"Label: {label}\t\tCount: {count}")

# write the label ID files, if requested
if args["file_ids"]:
for label, file_ids_for_label in label_file_ids.items():
label_file_ids_path = os.path.join(args["file_ids"], label + ".txt")
with open(label_file_ids_path, "w") as label_file_ids_file:
for file_id in file_ids_for_label:
label_file_ids_file.write(f"{file_id}\n")


# ------------------------------------------------------------------------------
Expand Down
17 changes: 17 additions & 0 deletions tests/test_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,20 @@ def test_count_labels(
assert label_counts["1"] == 2
assert label_counts["2"] == 1
assert label_counts["3"] == 1


# ------------------------------------------------------------------------------
@pytest.mark.usefixtures(
    "data_dir",
)
def test_count_tfrecord_examples(
        data_dir,
):
    """
    Verifies that cvdata.analyze.count_tfrecord_examples() reports the
    expected number of examples for the "tfrecord" test data directory.

    :param data_dir: temporary directory into which test files will be loaded
    """
    expected_count = 100
    tfrecords_path = os.path.join(str(data_dir), "tfrecord")
    assert analyze.count_tfrecord_examples(tfrecords_path) == expected_count
Binary file not shown.

0 comments on commit 1be3c0d

Please sign in to comment.