Merge pull request #122 from monocongo/issue_68_broken_tests
Broken tests fixed, removed TFOD API dependency
monocongo authored Jan 27, 2020
2 parents 127662a + a4289c3 commit 2df0cdd
Showing 6 changed files with 192 additions and 45 deletions.
101 changes: 86 additions & 15 deletions cvdata/convert.py
@@ -10,11 +10,10 @@

import contextlib2
import cv2
from object_detection.utils import dataset_util
from object_detection.dataset_tools import tf_record_creation_util
import pandas as pd
from PIL import Image
import tensorflow as tf
from tensorflow.compat.v1.python_io import TFRecordWriter
from tqdm import tqdm

from cvdata.common import FORMAT_CHOICES
@@ -119,6 +118,46 @@ def _dataset_bbox_examples(
    return pd.DataFrame(bboxes, columns=column_names)


# ------------------------------------------------------------------------------
def _int64_feature(
        value: int,
) -> tf.train.Feature:

    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# ------------------------------------------------------------------------------
def _int64_list_feature(
        value: List[int],
) -> tf.train.Feature:

    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


# ------------------------------------------------------------------------------
def _bytes_feature(
        value: str,
) -> tf.train.Feature:

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# ------------------------------------------------------------------------------
def _bytes_list_feature(
        value: List[str],
) -> tf.train.Feature:

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


# ------------------------------------------------------------------------------
def _float_list_feature(
        value: List[float],
) -> tf.train.Feature:

    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


# ------------------------------------------------------------------------------
def _create_tf_example(
        label_indices: Dict,
Expand Down Expand Up @@ -163,18 +202,18 @@ def _create_tf_example(

    # build the Example from the lists of coordinates, class labels/indices, etc.
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(img_bytes),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/filename': _bytes_feature(filename),
        'image/source_id': _bytes_feature(filename),
        'image/encoded': _bytes_feature(img_bytes),
        'image/format': _bytes_feature(image_format),
        'image/object/bbox/xmin': _float_list_feature(xmins),
        'image/object/bbox/xmax': _float_list_feature(xmaxs),
        'image/object/bbox/ymin': _float_list_feature(ymins),
        'image/object/bbox/ymax': _float_list_feature(ymaxs),
        'image/object/class/text': _bytes_list_feature(classes_text),
        'image/object/class/label': _int64_list_feature(classes),
    }))

    return tf_example
@@ -212,6 +251,38 @@ def _generate_label_map(
    return label_indices


# ------------------------------------------------------------------------------
def _open_sharded_output_tfrecords(
        exit_stack: contextlib2.ExitStack,
        base_path: str,
        num_shards: int,
) -> List:
    """
    Opens all TFRecord shards for writing and adds them to an exit stack.
    Modified from original code in the TensorFlow Object Detection API:
    https://github.com/tensorflow/models/object-detection/research/object_detection/dataset_tools/tf_record_creation_util.py
    :param exit_stack: a contextlib2.ExitStack used to automatically close the
        TFRecords opened in this function
    :param base_path: the base file path for all shards
    :param num_shards: number of shards
    :return: a list of opened TFRecord shard files (position k in the list
        corresponds to shard k)
    """
    tf_record_output_filenames = [
        f'{base_path}-{str(idx).zfill(5)}-of-{str(num_shards).zfill(5)}'
        for idx in range(num_shards)
    ]

    tfrecords = [
        exit_stack.enter_context(TFRecordWriter(file_name))
        for file_name in tf_record_output_filenames
    ]

    return tfrecords


# ------------------------------------------------------------------------------
def _to_tfrecord(
        images_dir: str,
Expand Down Expand Up @@ -257,7 +328,7 @@ def _to_tfrecord(
    # write the TFRecords into the specified number of "shard" files
    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = \
            tf_record_creation_util.open_sharded_output_tfrecords(
            _open_sharded_output_tfrecords(
                tf_record_close_stack,
                tfrecord_path,
                total_shards,
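Note on the inlined helpers above: they reproduce what object_detection.utils.dataset_util provided before the TFOD dependency was dropped. A minimal round-trip sketch, assuming TensorFlow 2.x (the image size and label value are hypothetical, not taken from cvdata):

    import tensorflow as tf

    # Build a feature dict the way _create_tf_example() does, with
    # hypothetical values for a single 640x480 image and one class label.
    feature = {
        "image/height": tf.train.Feature(int64_list=tf.train.Int64List(value=[480])),
        "image/width": tf.train.Feature(int64_list=tf.train.Int64List(value=[640])),
        "image/object/class/label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))

    # Serialize and parse back to confirm the Example survives a round trip.
    parsed = tf.train.Example.FromString(example.SerializeToString())
    assert parsed.features.feature["image/width"].int64_list.value[0] == 640

Likewise, _open_sharded_output_tfrecords() keeps the TFOD shard-naming convention; with a hypothetical base path and three shards it would produce:

    base_path, num_shards = "train_weapons.tfrecord", 3  # hypothetical values
    print([
        f"{base_path}-{str(idx).zfill(5)}-of-{str(num_shards).zfill(5)}"
        for idx in range(num_shards)
    ])
    # ['train_weapons.tfrecord-00000-of-00003',
    #  'train_weapons.tfrecord-00001-of-00003',
    #  'train_weapons.tfrecord-00002-of-00003']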
33 changes: 29 additions & 4 deletions cvdata/mask.py
@@ -181,17 +181,42 @@ def masked_dataset_to_tfrecords(

    # create a mapping of base file names and subsets of file IDs
    if train_pct < 1.0:
        # get the correct file name prefix for the TFRecord files
        # based on the presence of a specified file base name
        tfrecord_file_prefix_train = "train"
        tfrecord_file_prefix_valid = "valid"
        if dataset_base_name != "":
            tfrecord_file_prefix_train = tfrecord_file_prefix_train + "_" + dataset_base_name
            tfrecord_file_prefix_valid = tfrecord_file_prefix_valid + "_" + dataset_base_name

        # get the split index to use for splitting into train/valid sets
        split_index = int(len(file_ids) * train_pct)

        # map the file prefixes to the sets of file IDs for the split sections
        split_names_to_ids = {
            "train_" + dataset_base_name: file_ids[:split_index],
            "valid_" + dataset_base_name: file_ids[split_index:],
            tfrecord_file_prefix_train: file_ids[:split_index],
            tfrecord_file_prefix_valid: file_ids[split_index:],
        }

        # report the number of samples in each split section
        _logger.info(f"TFRecord dataset contains {len(file_ids[:split_index])} training samples")
        _logger.info(f"TFRecord dataset contains {len(file_ids[split_index:])} validation samples")

    else:
        # we'll just have one base file name mapped to all file IDs
        if "" == dataset_base_name:
            tfrecord_file_prefix = "tfrecord"
        else:
            tfrecord_file_prefix = dataset_base_name

        # map the file prefixes to the set of file IDs
        split_names_to_ids = {
            dataset_base_name: file_ids,
            tfrecord_file_prefix: file_ids,
        }

        # report the number of samples
        _logger.info(f"TFRecord dataset contains {len(file_ids)} samples (no train/valid split)")

    # create an iterable of arguments that will be mapped to concurrent future processes
    args_iterable = []
    for base_name, file_ids in split_names_to_ids.items():
@@ -419,7 +444,7 @@ def main():
"--base_name",
required=False,
type=str,
default="tfrecord",
default="",
help="base name of the TFRecord files",
)
args = vars(args_parser.parse_args())
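Distilled to its effect, the prefix logic above maps each TFRecord base name to a subset of the file IDs. A standalone sketch with hypothetical inputs (the names and counts are illustrative only):

    # Hypothetical inputs: ten file IDs, an 80/20 split, base name "weapons".
    dataset_base_name = "weapons"
    train_pct = 0.8
    file_ids = [f"img_{index:04d}" for index in range(10)]

    if train_pct < 1.0:
        suffix = f"_{dataset_base_name}" if dataset_base_name else ""
        split_index = int(len(file_ids) * train_pct)
        split_names_to_ids = {
            f"train{suffix}": file_ids[:split_index],
            f"valid{suffix}": file_ids[split_index:],
        }
    else:
        split_names_to_ids = {dataset_base_name or "tfrecord": file_ids}

    print({name: len(ids) for name, ids in split_names_to_ids.items()})
    # {'train_weapons': 8, 'valid_weapons': 2}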
33 changes: 12 additions & 21 deletions requirements.txt
@@ -1,25 +1,23 @@
absl-py==0.9.0
astor==0.8.1
astor==0.7.1
attrs==19.2.0
backcall==0.1.0
bleach==3.1.0
boto3==1.10.48
botocore==1.13.48
cachetools==4.0.0
boto3==1.11.9
botocore==1.14.9
certifi==2019.11.28
chardet==3.0.4
contextlib2==0.6.0
cvdata==0.0.6
cycler==0.10.0
Cython==0.29.13
decorator==4.4.0
defusedxml==0.6.0
docutils==0.15.2
entrypoints==0.3
gast==0.2.2
google-auth==1.10.0
google-auth-oauthlib==0.4.1
google-pasta==0.1.7
grpcio==1.26.0
grpcio==1.23.0
h5py==2.10.0
idna==2.8
ImageHash==4.0
@@ -28,7 +26,6 @@ ipython==7.8.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.15.1
Jinja2==2.10.3
jmespath==0.9.4
jsonschema==3.0.2
jupyter==1.0.0
@@ -47,11 +44,9 @@ nbconvert==5.6.0
nbformat==4.4.0
notebook==6.0.1
numpy==1.17.2
oauthlib==3.1.0
-e git+https://github.com/tensorflow/models.git#egg=object_detection&subdirectory=research
opencv-python==4.1.2.30
opt-einsum==3.1.0
pandas==0.25.3
pandas==1.0.0rc0
pandocfilters==1.4.2
parso==0.5.1
pexpect==4.7.0
@@ -61,8 +56,6 @@ prometheus-client==0.7.1
prompt-toolkit==2.0.10
protobuf==3.10.0
ptyprocess==0.6.0
pyasn1==0.4.8
pyasn1-modules==0.2.7
Pygments==2.4.2
pyparsing==2.4.2
pyrsistent==0.15.4
@@ -72,23 +65,21 @@ PyWavelets==1.1.1
pyzmq==18.1.0
qtconsole==4.5.5
requests==2.22.0
requests-oauthlib==1.3.0
rsa==4.0
s3transfer==0.2.1
s3transfer==0.3.2
scipy==1.4.1
Send2Trash==1.5.0
six==1.12.0
tensorboard==2.1.0
tensorflow==2.1.0rc2
tensorboard==2.0.0
tensorflow==2.0.0
tensorflow-cpu==1.15.0rc2
tensorflow-estimator==2.1.0
tensorflow-estimator==2.0.0
termcolor==1.1.0
terminado==0.8.2
testpath==0.4.2
tornado==6.0.3
tqdm==4.41.1
tqdm==4.42.0
traitlets==4.3.3
urllib3==1.25.7
urllib3==1.25.8
wcwidth==0.1.7
webencodings==0.5.1
Werkzeug==0.16.0
66 changes: 63 additions & 3 deletions tests/test_resize.py
@@ -64,16 +64,76 @@ def test_resize(
    np.testing.assert_equal(resized_image,
                            expected_resized_image,
                            err_msg="Image not resized as expected")


# ------------------------------------------------------------------------------
@pytest.mark.usefixtures(
    "data_dir",
)
def test_resize_image_label(
        data_dir,
):
    """
    Test for the cvdata.resize.resize_image_label() function
    :param data_dir: temporary directory into which test files will be loaded
    """
    file_id = "image"
    image_ext = ".jpg"
    image_file_name = f"{file_id}{image_ext}"
    kitti_ext = ".txt"
    kitti_file_name = f"{file_id}{kitti_ext}"
    pascal_ext = ".xml"
    pascal_file_name = f"{file_id}{pascal_ext}"

    # make a directory to hold our resized files
    resized_dir = os.path.join(str(data_dir), "resized")
    os.makedirs(resized_dir, exist_ok=True)

    new_width = 240
    new_height = 720
    expected_resized_file_id = f"{file_id}_w{new_width}_h{new_height}"
    expected_resized_image_file_name = f"{expected_resized_file_id}{image_ext}"
    expected_resized_kitti_file_name = f"{expected_resized_file_id}{kitti_ext}"
    expected_resized_pascal_file_name = f"{expected_resized_file_id}{pascal_ext}"
    expected_resized_image_file_path = os.path.join(str(data_dir), expected_resized_image_file_name)
    expected_resized_kitti_file_path = os.path.join(str(data_dir), expected_resized_kitti_file_name)
    expected_resized_pascal_file_path = os.path.join(str(data_dir), expected_resized_pascal_file_name)

    # confirm that we can resize as expected for a KITTI annotated image
    resize.resize_image_label(
        file_id,
        image_ext,
        kitti_ext,
        data_dir,
        data_dir,
        resized_dir,
        resized_dir,
        new_width,
        new_height,
        "kitti",
    )
    resized_image_file_path = os.path.join(resized_dir, image_file_name)
    resized_image = cv2.imread(resized_image_file_path)
    expected_resized_image = cv2.imread(expected_resized_image_file_path)
    np.testing.assert_equal(resized_image,
                            expected_resized_image,
                            err_msg="Image not resized as expected")
    resized_kitti_file_path = os.path.join(resized_dir, kitti_file_name)
    assert text_files_equal(resized_kitti_file_path, expected_resized_kitti_file_path)

    # confirm that resizing occurred as expected for a PASCAL annotated image
    resize.resize_image(
        file_id + image_ext,
    # confirm that we can resize as expected for a PASCAL annotated image
    resize.resize_image_label(
        file_id,
        image_ext,
        pascal_ext,
        data_dir,
        data_dir,
        resized_dir,
        resized_dir,
        new_width,
        new_height,
        "pascal",
    )
    resized_image_file_path = os.path.join(resized_dir, image_file_name)
    resized_image = cv2.imread(resized_image_file_path)
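The assertions above lean on a text_files_equal() helper that is not part of this diff; a minimal stand-in (an assumption, not the project's actual implementation) might be:

    def text_files_equal(path_a: str, path_b: str) -> bool:
        # Compare two text files line by line, ignoring trailing whitespace.
        with open(path_a) as file_a, open(path_b) as file_b:
            lines_a = [line.rstrip() for line in file_a]
            lines_b = [line.rstrip() for line in file_b]
        return lines_a == lines_b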
2 changes: 1 addition & 1 deletion tests/test_resize/image.txt
@@ -1 +1 @@
handgun 0.0 0 0.0 505.0 316.0 616.0 441.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
handgun 0.0 0 0.0 222.0 139.0 271.0 194.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 changes: 1 addition & 1 deletion tests/test_resize/image_w240_h720.txt
@@ -1 +1 @@
handgun 0.0 0 0.0 269 168 239 235 0.0 0.0 0.0 0.0 0.0 0.0 0.0
handgun 0.0 0 0.0 118 74 144 103 0.0 0.0 0.0 0.0 0.0 0.0 0.0
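For reference, fields 5 through 8 of a KITTI label line are the pixel-space bounding box (xmin, ymin, xmax, ymax), which is why only those four values change when the image is resized to 240x720. A generic rescaling sketch (the int truncation matches the integer coordinates above, but cvdata's exact rounding is an assumption):

    def scale_kitti_bbox(xmin, ymin, xmax, ymax, old_size, new_size):
        # Scale pixel coordinates by the horizontal and vertical resize ratios.
        scale_x = new_size[0] / old_size[0]
        scale_y = new_size[1] / old_size[1]
        return (int(xmin * scale_x), int(ymin * scale_y),
                int(xmax * scale_x), int(ymax * scale_y))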
