def count_tfrecord_examples(
        tfrecords_dir: str,
) -> int:
    """
    Counts the total number of examples in a collection of TFRecord files.

    Each regular file found directly inside the directory is opened as a
    TFRecord dataset and its records are counted one by one. Entries that are
    not regular files (for example subdirectories) are skipped, and nested
    directories are not searched.

    :param tfrecords_dir: directory whose regular files are all assumed to be
        TFRecord files
    :return: the total number of examples in the collection of TFRecord files
        found in the specified directory (0 if it contains no files)
    """

    count = 0
    for file_name in os.listdir(tfrecords_dir):
        tfrecord_path = os.path.join(tfrecords_dir, file_name)

        # os.listdir() also returns subdirectories, which can't be read as
        # a TFRecord file, so only regular files are counted
        if not os.path.isfile(tfrecord_path):
            continue

        # tally the records by iterating over the dataset
        count += sum(1 for _ in tf.data.TFRecordDataset(tfrecord_path))

    return count
LabelName column + # read the OpenImages CSV into a pandas DataFrame + df_annotations = pd.read_csv(args["annotations"]) + df_annotations = df_annotations[['ImageID', 'LabelName']] - # whittle it down to only the rows that match to image IDs - file_ids = [os.path.splitext(file_name)[0] for file_name in os.listdir(args["images"])] - df_annotations = df_annotations[df_annotations["ImageID"].isin(file_ids)] + # TODO get another dataframe from the class descriptions and get the + # readable label names from there to map to the LabelName column - # TODO populate the label counts and label file IDs dictionaries + # whittle it down to only the rows that match to image IDs + file_ids = [os.path.splitext(file_name)[0] for file_name in os.listdir(args["images"])] + df_annotations = df_annotations[df_annotations["ImageID"].isin(file_ids)] - else: + # TODO populate the label counts and label file IDs dictionaries - annotation_ext = cvdata.common.FORMAT_EXTENSIONS[args["format"]] + else: - # only annotations matching to the images are considered to be valid - file_ids = matching_ids(args["annotations"], args["images"], annotation_ext, ".jpg") + annotation_ext = cvdata.common.FORMAT_EXTENSIONS[args["format"]] - for file_id in file_ids: - annotation_file_path = \ - os.path.join(args["annotations"], file_id + annotation_ext) + # only annotations matching to the images are considered to be valid + file_ids = matching_ids(args["annotations"], args["images"], annotation_ext, ".jpg") - # get the images per label count - for label, count in count_labels(annotation_file_path, args["format"]).items(): - if label in label_counts: - label_counts[label] += 1 - else: - label_counts[label] = 1 + for file_id in file_ids: + annotation_file_path = \ + os.path.join(args["annotations"], file_id + annotation_ext) - # for each label found in the annotation file add this file ID - # to the set of file IDs corresponding to the label - if args["file_ids"]: - if label in label_file_ids: - # add this 
file ID to the existing set for the label - label_file_ids[label].add(file_id) + # get the images per label count + for label, count in count_labels(annotation_file_path, args["format"]).items(): + if label in label_counts: + label_counts[label] += 1 else: - # first file ID seen for this label so create new set - label_file_ids[label] = {file_id} - - # write the images per label counts - for label, count in label_counts.items(): - print(f"Label: {label}\t\tCount: {count}") - - # write the label ID files, if requested - if args["file_ids"]: - for label, file_ids_for_label in label_file_ids.items(): - label_file_ids_path = os.path.join(args["file_ids"], label + ".txt") - with open(label_file_ids_path, "w") as label_file_ids_file: - for file_id in file_ids_for_label: - label_file_ids_file.write(f"{file_id}\n") + label_counts[label] = 1 + + # for each label found in the annotation file add this file ID + # to the set of file IDs corresponding to the label + if args["file_ids"]: + if label in label_file_ids: + # add this file ID to the existing set for the label + label_file_ids[label].add(file_id) + else: + # first file ID seen for this label so create new set + label_file_ids[label] = {file_id} + + # write the images per label counts + for label, count in label_counts.items(): + print(f"Label: {label}\t\tCount: {count}") + + # write the label ID files, if requested + if args["file_ids"]: + for label, file_ids_for_label in label_file_ids.items(): + label_file_ids_path = os.path.join(args["file_ids"], label + ".txt") + with open(label_file_ids_path, "w") as label_file_ids_file: + for file_id in file_ids_for_label: + label_file_ids_file.write(f"{file_id}\n") # ------------------------------------------------------------------------------