Skip to content

Commit

Permalink
new TFRecord examples counting function, including CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
monocongo committed Feb 7, 2020
1 parent 1d08f96 commit c9a6d69
Showing 1 changed file with 74 additions and 46 deletions.
120 changes: 74 additions & 46 deletions src/cvdata/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xml.etree import ElementTree

import pandas as pd
import tensorflow as tf

import cvdata.common
from cvdata.utils import matching_ids
Expand Down Expand Up @@ -138,6 +139,26 @@ def count_labels(
raise ValueError(f"Unsupported annotation format: \"{annotation_format}\"")


# ------------------------------------------------------------------------------
def count_tfrecord_examples(
tfrecords_dir: str,
) -> int:
"""
Counts the total number of examples in a collection of TFRecord files.
:param tfrecords_dir: directory that is assumed to contain only TFRecord files
:return: the total number of examples in the collection of TFRecord files
found in the specified directory
"""

count = 0
for file_name in os.listdir(tfrecords_dir):
tfrecord_path = os.path.join(tfrecords_dir, file_name)
count += sum(1 for _ in tf.data.TFRecordDataset(tfrecord_path))

return count


# ------------------------------------------------------------------------------
def main():

Expand Down Expand Up @@ -172,64 +193,71 @@ def main():
)
args = vars(args_parser.parse_args())

# the two dictionaries we'll build for final reporting
label_counts = {}
label_file_ids = {}
if args["format"] == "tfrecord":

# count and report the examples in the collection of TFRecord files
examples_count = count_tfrecord_examples(args["annotations"])
print(f"Total number of examples: {examples_count}")

if args["format"] == "openimages":
else:
# the two dictionaries we'll build for final reporting
label_counts = {}
label_file_ids = {}

# read the OpenImages CSV into a pandas DataFrame
df_annotations = pd.read_csv(args["annotations"])
df_annotations = df_annotations[['ImageID', 'LabelName']]
if args["format"] == "openimages":

# TODO get another dataframe from the class descriptions and get the
# readable label names from there to map to the LabelName column
# read the OpenImages CSV into a pandas DataFrame
df_annotations = pd.read_csv(args["annotations"])
df_annotations = df_annotations[['ImageID', 'LabelName']]

# whittle it down to only the rows that match to image IDs
file_ids = [os.path.splitext(file_name)[0] for file_name in os.listdir(args["images"])]
df_annotations = df_annotations[df_annotations["ImageID"].isin(file_ids)]
# TODO get another dataframe from the class descriptions and get the
# readable label names from there to map to the LabelName column

# TODO populate the label counts and label file IDs dictionaries
# whittle it down to only the rows that match to image IDs
file_ids = [os.path.splitext(file_name)[0] for file_name in os.listdir(args["images"])]
df_annotations = df_annotations[df_annotations["ImageID"].isin(file_ids)]

else:
# TODO populate the label counts and label file IDs dictionaries

annotation_ext = cvdata.common.FORMAT_EXTENSIONS[args["format"]]
else:

# only annotations matching to the images are considered to be valid
file_ids = matching_ids(args["annotations"], args["images"], annotation_ext, ".jpg")
annotation_ext = cvdata.common.FORMAT_EXTENSIONS[args["format"]]

for file_id in file_ids:
annotation_file_path = \
os.path.join(args["annotations"], file_id + annotation_ext)
# only annotations matching to the images are considered to be valid
file_ids = matching_ids(args["annotations"], args["images"], annotation_ext, ".jpg")

# get the images per label count
for label, count in count_labels(annotation_file_path, args["format"]).items():
if label in label_counts:
label_counts[label] += 1
else:
label_counts[label] = 1
for file_id in file_ids:
annotation_file_path = \
os.path.join(args["annotations"], file_id + annotation_ext)

# for each label found in the annotation file add this file ID
# to the set of file IDs corresponding to the label
if args["file_ids"]:
if label in label_file_ids:
# add this file ID to the existing set for the label
label_file_ids[label].add(file_id)
# get the images per label count
for label, count in count_labels(annotation_file_path, args["format"]).items():
if label in label_counts:
label_counts[label] += 1
else:
# first file ID seen for this label so create new set
label_file_ids[label] = {file_id}

# write the images per label counts
for label, count in label_counts.items():
print(f"Label: {label}\t\tCount: {count}")

# write the label ID files, if requested
if args["file_ids"]:
for label, file_ids_for_label in label_file_ids.items():
label_file_ids_path = os.path.join(args["file_ids"], label + ".txt")
with open(label_file_ids_path, "w") as label_file_ids_file:
for file_id in file_ids_for_label:
label_file_ids_file.write(f"{file_id}\n")
label_counts[label] = 1

# for each label found in the annotation file add this file ID
# to the set of file IDs corresponding to the label
if args["file_ids"]:
if label in label_file_ids:
# add this file ID to the existing set for the label
label_file_ids[label].add(file_id)
else:
# first file ID seen for this label so create new set
label_file_ids[label] = {file_id}

# write the images per label counts
for label, count in label_counts.items():
print(f"Label: {label}\t\tCount: {count}")

# write the label ID files, if requested
if args["file_ids"]:
for label, file_ids_for_label in label_file_ids.items():
label_file_ids_path = os.path.join(args["file_ids"], label + ".txt")
with open(label_file_ids_path, "w") as label_file_ids_file:
for file_id in file_ids_for_label:
label_file_ids_file.write(f"{file_id}\n")


# ------------------------------------------------------------------------------
Expand Down

0 comments on commit c9a6d69

Please sign in to comment.