diff --git a/cvdata/convert.py b/cvdata/convert.py
index e7cda9a..6b2dcbb 100644
--- a/cvdata/convert.py
+++ b/cvdata/convert.py
@@ -10,11 +10,10 @@
 import contextlib2
 import cv2
-from object_detection.utils import dataset_util
-from object_detection.dataset_tools import tf_record_creation_util
 import pandas as pd
 from PIL import Image
 import tensorflow as tf
+from tensorflow.compat.v1.python_io import TFRecordWriter
 from tqdm import tqdm
 
 from cvdata.common import FORMAT_CHOICES
@@ -119,6 +118,46 @@ def _dataset_bbox_examples(
     return pd.DataFrame(bboxes, columns=column_names)
 
 
+# ------------------------------------------------------------------------------
+def _int64_feature(
+        value: int,
+) -> tf.train.Feature:
+
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+# ------------------------------------------------------------------------------
+def _int64_list_feature(
+        value: List[int],
+) -> tf.train.Feature:
+
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
+
+
+# ------------------------------------------------------------------------------
+def _bytes_feature(
+        value: str,
+) -> tf.train.Feature:
+
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+# ------------------------------------------------------------------------------
+def _bytes_list_feature(
+        value: List[str],
+) -> tf.train.Feature:
+
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
+
+
+# ------------------------------------------------------------------------------
+def _float_list_feature(
+        value: List[float],
+) -> tf.train.Feature:
+
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+
 # ------------------------------------------------------------------------------
 def _create_tf_example(
         label_indices: Dict,
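These five _*_feature() helpers are small, dependency-free replacements for the dataset_util wrappers that the dropped object_detection package used to provide. As a quick illustration of what they produce (a minimal sketch, not part of the patch; the 640 width is a made-up value), a feature built this way survives an Example serialization round trip:

    import tensorflow as tf

    # wrap a scalar exactly as _int64_feature() does
    width = tf.train.Feature(int64_list=tf.train.Int64List(value=[640]))
    example = tf.train.Example(
        features=tf.train.Features(feature={'image/width': width}))

    # serialize and re-parse to confirm the value comes back intact
    parsed = tf.train.Example.FromString(example.SerializeToString())
    assert parsed.features.feature['image/width'].int64_list.value[0] == 640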
@@ -163,18 +202,18 @@ def _create_tf_example(
 
     # build the Example from the lists of coordinates, class labels/indices, etc.
     tf_example = tf.train.Example(features=tf.train.Features(feature={
-        'image/height': dataset_util.int64_feature(height),
-        'image/width': dataset_util.int64_feature(width),
-        'image/filename': dataset_util.bytes_feature(filename),
-        'image/source_id': dataset_util.bytes_feature(filename),
-        'image/encoded': dataset_util.bytes_feature(img_bytes),
-        'image/format': dataset_util.bytes_feature(image_format),
-        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
-        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
-        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
-        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
-        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
-        'image/object/class/label': dataset_util.int64_list_feature(classes),
+        'image/height': _int64_feature(height),
+        'image/width': _int64_feature(width),
+        'image/filename': _bytes_feature(filename),
+        'image/source_id': _bytes_feature(filename),
+        'image/encoded': _bytes_feature(img_bytes),
+        'image/format': _bytes_feature(image_format),
+        'image/object/bbox/xmin': _float_list_feature(xmins),
+        'image/object/bbox/xmax': _float_list_feature(xmaxs),
+        'image/object/bbox/ymin': _float_list_feature(ymins),
+        'image/object/bbox/ymax': _float_list_feature(ymaxs),
+        'image/object/class/text': _bytes_list_feature(classes_text),
+        'image/object/class/label': _int64_list_feature(classes),
     }))
 
     return tf_example
@@ -212,6 +251,38 @@ def _generate_label_map(
     return label_indices
 
 
+# ------------------------------------------------------------------------------
+def _open_sharded_output_tfrecords(
+        exit_stack: contextlib2.ExitStack,
+        base_path: str,
+        num_shards: int,
+) -> List:
+    """
+    Opens all TFRecord shards for writing and adds them to an exit stack.
+
+    Modified from original code in the TensorFlow Object Detection API:
+    https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/tf_record_creation_util.py
+
+    :param exit_stack: a contextlib2.ExitStack used to automatically close the
+        TFRecords opened in this function
+    :param base_path: the base file path for all shards
+    :param num_shards: number of shards
+    :return: a list of opened TFRecord shard files (position k in the list
+        corresponds to shard k)
+    """
+    tf_record_output_filenames = [
+        f'{base_path}-{str(idx).zfill(5)}-of-{str(num_shards).zfill(5)}'
+        for idx in range(num_shards)
+    ]
+
+    tfrecords = [
+        exit_stack.enter_context(TFRecordWriter(file_name))
+        for file_name in tf_record_output_filenames
+    ]
+
+    return tfrecords
+
+
 # ------------------------------------------------------------------------------
 def _to_tfrecord(
         images_dir: str,
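The vendored _open_sharded_output_tfrecords() is used the same way as the Object Detection API original: open every shard inside an ExitStack, then spread the examples across the writers. A minimal usage sketch (the examples iterable and the /tmp path are assumed here for illustration):

    import contextlib2

    num_shards = 4
    with contextlib2.ExitStack() as close_stack:
        writers = _open_sharded_output_tfrecords(close_stack, '/tmp/dataset', num_shards)
        for index, tf_example in enumerate(examples):
            # round-robin the serialized examples across the shard files
            writers[index % num_shards].write(tf_example.SerializeToString())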
@@ -257,7 +328,7 @@ def _to_tfrecord(
     # write the TFRecords into the specified number of "shard" files
     with contextlib2.ExitStack() as tf_record_close_stack:
         output_tfrecords = \
-            tf_record_creation_util.open_sharded_output_tfrecords(
+            _open_sharded_output_tfrecords(
                 tf_record_close_stack,
                 tfrecord_path,
                 total_shards,
diff --git a/cvdata/mask.py b/cvdata/mask.py
index ff5cfe9..c1d24a7 100644
--- a/cvdata/mask.py
+++ b/cvdata/mask.py
@@ -181,17 +181,42 @@ def masked_dataset_to_tfrecords(
     # create a mapping of base file names and subsets of file IDs
     if train_pct < 1.0:
 
+        # get the correct file name prefix for the TFRecord files
+        # based on the presence of a specified file base name
+        tfrecord_file_prefix_train = "train"
+        tfrecord_file_prefix_valid = "valid"
+        if dataset_base_name != "":
+            tfrecord_file_prefix_train = tfrecord_file_prefix_train + "_" + dataset_base_name
+            tfrecord_file_prefix_valid = tfrecord_file_prefix_valid + "_" + dataset_base_name
+
+        # get the split index to use for splitting into train/valid sets
         split_index = int(len(file_ids) * train_pct)
+
+        # map the file prefixes to the sets of file IDs for the split sections
         split_names_to_ids = {
-            "train_" + dataset_base_name: file_ids[:split_index],
-            "valid_" + dataset_base_name: file_ids[split_index:],
+            tfrecord_file_prefix_train: file_ids[:split_index],
+            tfrecord_file_prefix_valid: file_ids[split_index:],
         }
+
+        # report the number of samples in each split section
+        _logger.info(f"TFRecord dataset contains {split_index} training samples")
+        _logger.info(f"TFRecord dataset contains {len(file_ids) - split_index} validation samples")
+
     else:
         # we'll just have one base file name mapped to all file IDs
+        if dataset_base_name == "":
+            tfrecord_file_prefix = "tfrecord"
+        else:
+            tfrecord_file_prefix = dataset_base_name
+
+        # map the file prefix to the full set of file IDs
         split_names_to_ids = {
-            dataset_base_name: file_ids,
+            tfrecord_file_prefix: file_ids,
         }
 
+        # report the number of samples
+        _logger.info(f"TFRecord dataset contains {len(file_ids)} samples (no train/valid split)")
+
     # create an iterable of arguments that will be mapped to concurrent future processes
     args_iterable = []
     for base_name, file_ids in split_names_to_ids.items():
@@ -419,7 +444,7 @@ def main():
         "--base_name",
         required=False,
         type=str,
-        default="tfrecord",
+        default="",
         help="base name of the TFRecord files",
    )
     args = vars(args_parser.parse_args())
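Taken together, the mask.py changes implement the following prefix scheme, which is presumably also why the --base_name default drops from "tfrecord" to "": the old default would have produced prefixes such as train_tfrecord. A small self-contained mirror of the branch logic (illustrative only; "weapons" is a made-up base name):

    def tfrecord_prefixes(dataset_base_name: str, split: bool) -> list:
        # mirrors the prefix branches in masked_dataset_to_tfrecords() above
        if split:
            suffix = f"_{dataset_base_name}" if dataset_base_name else ""
            return [f"train{suffix}", f"valid{suffix}"]
        return [dataset_base_name or "tfrecord"]

    assert tfrecord_prefixes("", True) == ["train", "valid"]
    assert tfrecord_prefixes("weapons", True) == ["train_weapons", "valid_weapons"]
    assert tfrecord_prefixes("", False) == ["tfrecord"]
    assert tfrecord_prefixes("weapons", False) == ["weapons"]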
diff --git a/requirements.txt b/requirements.txt
index d6d945d..ae0fc7c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,14 @@
 absl-py==0.9.0
-astor==0.8.1
+astor==0.7.1
 attrs==19.2.0
 backcall==0.1.0
 bleach==3.1.0
-boto3==1.10.48
-botocore==1.13.48
-cachetools==4.0.0
+boto3==1.11.9
+botocore==1.14.9
 certifi==2019.11.28
 chardet==3.0.4
 contextlib2==0.6.0
+cvdata==0.0.6
 cycler==0.10.0
 Cython==0.29.13
 decorator==4.4.0
@@ -16,10 +16,8 @@ defusedxml==0.6.0
 docutils==0.15.2
 entrypoints==0.3
 gast==0.2.2
-google-auth==1.10.0
-google-auth-oauthlib==0.4.1
 google-pasta==0.1.7
-grpcio==1.26.0
+grpcio==1.23.0
 h5py==2.10.0
 idna==2.8
 ImageHash==4.0
@@ -28,7 +26,6 @@ ipython==7.8.0
 ipython-genutils==0.2.0
 ipywidgets==7.5.1
 jedi==0.15.1
-Jinja2==2.10.3
 jmespath==0.9.4
 jsonschema==3.0.2
 jupyter==1.0.0
@@ -47,11 +44,9 @@ nbconvert==5.6.0
 nbformat==4.4.0
 notebook==6.0.1
 numpy==1.17.2
-oauthlib==3.1.0
--e git+https://github.com/tensorflow/models.git#egg=object_detection&subdirectory=research
 opencv-python==4.1.2.30
 opt-einsum==3.1.0
-pandas==0.25.3
+pandas==1.0.0rc0
 pandocfilters==1.4.2
 parso==0.5.1
 pexpect==4.7.0
@@ -61,8 +56,6 @@ prometheus-client==0.7.1
 prompt-toolkit==2.0.10
 protobuf==3.10.0
 ptyprocess==0.6.0
-pyasn1==0.4.8
-pyasn1-modules==0.2.7
 Pygments==2.4.2
 pyparsing==2.4.2
 pyrsistent==0.15.4
@@ -72,23 +65,21 @@ PyWavelets==1.1.1
 pyzmq==18.1.0
 qtconsole==4.5.5
 requests==2.22.0
-requests-oauthlib==1.3.0
-rsa==4.0
-s3transfer==0.2.1
+s3transfer==0.3.2
 scipy==1.4.1
 Send2Trash==1.5.0
 six==1.12.0
-tensorboard==2.1.0
-tensorflow==2.1.0rc2
+tensorboard==2.0.0
+tensorflow==2.0.0
 tensorflow-cpu==1.15.0rc2
-tensorflow-estimator==2.1.0
+tensorflow-estimator==2.0.0
 termcolor==1.1.0
 terminado==0.8.2
 testpath==0.4.2
 tornado==6.0.3
-tqdm==4.41.1
+tqdm==4.42.0
 traitlets==4.3.3
-urllib3==1.25.7
+urllib3==1.25.8
 wcwidth==0.1.7
 webencodings==0.5.1
 Werkzeug==0.16.0
diff --git a/tests/test_resize.py b/tests/test_resize.py
index cd3d64b..754f8e9 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -64,16 +64,76 @@ def test_resize(
     np.testing.assert_equal(resized_image,
                             expected_resized_image,
                             err_msg="Image not resized as expected")
 
+
+# ------------------------------------------------------------------------------
+@pytest.mark.usefixtures(
+    "data_dir",
+)
+def test_resize_image_label(
+        data_dir,
+):
+    """
+    Test for the cvdata.resize.resize_image_label() function
+
+    :param data_dir: temporary directory into which test files will be loaded
+    """
+    file_id = "image"
+    image_ext = ".jpg"
+    image_file_name = f"{file_id}{image_ext}"
+    kitti_ext = ".txt"
+    kitti_file_name = f"{file_id}{kitti_ext}"
+    pascal_ext = ".xml"
+    pascal_file_name = f"{file_id}{pascal_ext}"
+
+    # make a directory to hold our resized files
+    resized_dir = os.path.join(str(data_dir), "resized")
+    os.makedirs(resized_dir, exist_ok=True)
+
+    new_width = 240
+    new_height = 720
+    expected_resized_file_id = f"{file_id}_w{new_width}_h{new_height}"
+    expected_resized_image_file_name = f"{expected_resized_file_id}{image_ext}"
+    expected_resized_kitti_file_name = f"{expected_resized_file_id}{kitti_ext}"
+    expected_resized_pascal_file_name = f"{expected_resized_file_id}{pascal_ext}"
+    expected_resized_image_file_path = os.path.join(str(data_dir), expected_resized_image_file_name)
+    expected_resized_kitti_file_path = os.path.join(str(data_dir), expected_resized_kitti_file_name)
+    expected_resized_pascal_file_path = os.path.join(str(data_dir), expected_resized_pascal_file_name)
+
+    # confirm that we can resize as expected for a KITTI annotated image
+    resize.resize_image_label(
+        file_id,
+        image_ext,
+        kitti_ext,
+        data_dir,
+        data_dir,
+        resized_dir,
+        resized_dir,
+        new_width,
+        new_height,
+        "kitti",
+    )
+
     resized_image_file_path = os.path.join(resized_dir, image_file_name)
+    resized_image = cv2.imread(resized_image_file_path)
+    expected_resized_image = cv2.imread(expected_resized_image_file_path)
+    np.testing.assert_equal(resized_image,
+                            expected_resized_image,
+                            err_msg="Image not resized as expected")
     resized_kitti_file_path = os.path.join(resized_dir, kitti_file_name)
     assert text_files_equal(resized_kitti_file_path, expected_resized_kitti_file_path)
 
-    # confirm that resizing occurred as expected for a PASCAL annotated image
-    resize.resize_image(
-        file_id + image_ext,
+    # confirm that we can resize as expected for a PASCAL annotated image
+    resize.resize_image_label(
+        file_id,
+        image_ext,
+        pascal_ext,
+        data_dir,
         data_dir,
         resized_dir,
+        resized_dir,
         new_width,
         new_height,
+        "pascal",
     )
     resized_image_file_path = os.path.join(resized_dir, image_file_name)
     resized_image = cv2.imread(resized_image_file_path)
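The new test reuses the suite's reference-fixture naming convention: a pre-resized file named for the target dimensions sits next to the source file, so image.jpg resized to 240x720 is compared against image_w240_h720.jpg (and likewise for the .txt and .xml annotations). A short sketch of the name derivation, using the test's own values:

    file_id, new_width, new_height = "image", 240, 720
    expected_resized_file_id = f"{file_id}_w{new_width}_h{new_height}"
    assert expected_resized_file_id == "image_w240_h720"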
diff --git a/tests/test_resize/image.txt b/tests/test_resize/image.txt
index 9024f17..a90c438 100644
--- a/tests/test_resize/image.txt
+++ b/tests/test_resize/image.txt
@@ -1 +1 @@
-handgun 0.0 0 0.0 505.0 316.0 616.0 441.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
+handgun 0.0 0 0.0 222.0 139.0 271.0 194.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
diff --git a/tests/test_resize/image_w240_h720.txt b/tests/test_resize/image_w240_h720.txt
index acf526b..e1d31d1 100644
--- a/tests/test_resize/image_w240_h720.txt
+++ b/tests/test_resize/image_w240_h720.txt
@@ -1 +1 @@
-handgun 0.0 0 0.0 269 168 239 235 0.0 0.0 0.0 0.0 0.0 0.0 0.0
\ No newline at end of file
+handgun 0.0 0 0.0 118 74 144 103 0.0 0.0 0.0 0.0 0.0 0.0 0.0
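Worth noting: the old image_w240_h720.txt fixture was internally inconsistent (xmax 239 is less than xmin 269), so regenerating both fixtures is a genuine fix rather than churn. The updated values are consistent with a plain proportional rescale of the KITTI box corners; since the source image's dimensions are not part of this diff, the sketch below keeps them symbolic instead of asserting the exact fixture numbers:

    def scale_box(xmin, ymin, xmax, ymax, orig_w, orig_h, new_w, new_h):
        # rescale KITTI pixel coordinates by the per-axis resize ratios
        sx = new_w / orig_w
        sy = new_h / orig_h
        return round(xmin * sx), round(ymin * sy), round(xmax * sx), round(ymax * sy)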