Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New functionality for building TFRecords from Darknet (YOLO) datasets #140

Merged
merged 9 commits into from
Feb 6, 2020
183 changes: 156 additions & 27 deletions src/cvdata/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
import contextlib2
import cv2
import pandas as pd
from PIL import Image
# from PIL import Image
import six
import tensorflow as tf
from tensorflow.compat.v1.python_io import TFRecordWriter
from tqdm import tqdm

from cvdata.common import FORMAT_CHOICES
from cvdata.utils import image_dimensions, matching_ids
from cvdata.utils import darknet_indices_to_labels, image_dimensions, matching_ids


# ------------------------------------------------------------------------------
Expand All @@ -35,12 +36,15 @@ def _dataset_bbox_examples(
images_dir: str,
annotations_dir: str,
annotation_format: str,
darknet_labels: str = None,
) -> pd.DataFrame:
"""

:param images_dir: directory containing the dataset's *.jpg image files
:param annotations_dir: directory containing the dataset's annotation files
:param annotation_format: currently supported: "kitti" and "pascal"
:param annotation_format: currently supported: "darknet", "kitti", and "pascal"
:param darknet_labels: path to the class labels file corresponding to Darknet
(YOLO) annotation files, only necessary if using "darknet" annotation format
:return: pandas DataFrame with rows corresponding to the dataset's bounding boxes
"""

Expand Down Expand Up @@ -79,7 +83,7 @@ def _dataset_bbox_examples(

elif annotation_format == "kitti":

# get the file IDs for all matching image/PASCAL pairs (i.e. the dataset)
# get the file IDs for all matching image/KITTI pairs (i.e. the dataset)
annotation_ext = ".txt"
for file_id in matching_ids(
annotations_dir,
Expand All @@ -97,16 +101,65 @@ def _dataset_bbox_examples(
kitti_path = os.path.join(annotations_dir, file_id + annotation_ext)
with open(kitti_path, "r") as kitti_file:
for line in kitti_file:
kitti_box = line.split()
darknet_box = line.split()
bbox_values = (
image_file_name,
width,
height,
kitti_box[0],
kitti_box[4],
kitti_box[5],
kitti_box[6],
kitti_box[7],
darknet_box[0],
darknet_box[4],
darknet_box[5],
darknet_box[6],
darknet_box[7],
)
bboxes.append(bbox_values)

elif annotation_format == "darknet":

# read class labels into index/label dictionary
darknet_index_labels = darknet_indices_to_labels(darknet_labels)

# get the file IDs for all matching image/Darknet pairs (i.e. the dataset)
annotation_ext = ".txt"
file_ids = matching_ids(
annotations_dir,
images_dir,
annotation_ext,
image_ext,
)

# get the bounding boxes from the annotation files
_logger.info("Extracting bounding box info from Darknet annotations...")
for file_id in tqdm(file_ids):
# get the image dimensions from the image file since this
# info is not present in the corresponding KITTI annotation
image_file_name = file_id + image_ext
image_path = os.path.join(images_dir, image_file_name)
width, height, _ = image_dimensions(image_path)

# add all bounding boxes from the Darknet file to the list of boxes
darknet_path = os.path.join(annotations_dir, file_id + annotation_ext)
with open(darknet_path, "r") as darknet_file:
for line in darknet_file:
darknet_box = line.split()
label_index = int(darknet_box[0])
# only use annotations corresponding to the specified labels
if label_index not in darknet_index_labels:
# skip this annotation line
continue
center_x = float(darknet_box[1]) * width
center_y = float(darknet_box[2]) * height
box_width = float(darknet_box[3]) * width
box_height = float(darknet_box[4]) * height
bbox_values = (
image_file_name,
width,
height,
darknet_index_labels[label_index],
int(center_x - (box_width / 2)),
int(center_y - (box_height / 2)),
int(center_x + (box_width / 2)),
int(center_y + (box_height / 2)),
)
bboxes.append(bbox_values)

Expand Down Expand Up @@ -144,6 +197,23 @@ def _bytes_feature(

# ------------------------------------------------------------------------------
def _bytes_list_feature(
values: str,
) -> tf.train.Feature:
"""
Returns a TF-Feature of bytes.

:param values a string
:return TF-Feature of bytes
"""

def norm2bytes(value):
return value.encode() if isinstance(value, str) and six.PY3 else value

return tf.train.Feature(bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))


# ------------------------------------------------------------------------------
def _string_bytes_list_feature(
value: List[str],
) -> tf.train.Feature:

Expand Down Expand Up @@ -174,14 +244,14 @@ def _create_tf_example(
:return: TensorFlow Example object corresponding to the group of annotations
"""

# read the image into a bytes object, get the dimensions
image = Image.open(os.path.join(images_dir, group.filename))
img_bytes = image.tobytes()
width, height = image.size
# read the image
image_file_name = group.filename
image_path = os.path.join(images_dir, group.filename)
image_data = tf.io.gfile.GFile(image_path, 'rb').read()
width, height, _ = image_dimensions(image_path)

# lists of bounding box values for the example
filename = group.filename.encode('utf8')
image_format = b'jpg'
xmins = []
xmaxs = []
ymins = []
Expand All @@ -204,15 +274,15 @@ def _create_tf_example(
tf_example = tf.train.Example(features=tf.train.Features(feature={
'image/height': _int64_feature(height),
'image/width': _int64_feature(width),
'image/filename': _bytes_feature(filename),
'image/source_id': _bytes_feature(filename),
'image/encoded': _bytes_feature(img_bytes),
'image/format': _bytes_feature(image_format),
'image/encoded': _bytes_list_feature(image_data),
'image/filename': _bytes_list_feature(image_file_name),
'image/format': _bytes_list_feature('jpeg'),
'image/object/bbox/xmin': _float_list_feature(xmins),
'image/object/bbox/xmax': _float_list_feature(xmaxs),
'image/object/bbox/ymin': _float_list_feature(ymins),
'image/object/bbox/ymax': _float_list_feature(ymaxs),
'image/object/class/text': _bytes_list_feature(classes_text),
'image/object/class/text': _string_bytes_list_feature(classes_text),
'image/object/class/label': _int64_list_feature(classes),
}))

Expand All @@ -228,7 +298,7 @@ def _generate_label_map(

:param annotations_df: pandas DataFrame with rows for annotations, should
contain a column named "class" which contains the label text
:param labels_path: path to label map prototxt file that will be written
:param labels_path: path to class labels prototxt file to be written
:return: dictionary of labels to indices represented by the label map
"""

Expand Down Expand Up @@ -288,35 +358,39 @@ def _to_tfrecord(
images_dir: str,
annotations_dir: str,
annotation_format: str,
labels_path: str,
tf_labels_path: str,
tfrecord_path: str,
total_shards: int,
darknet_labels_path: str = None,
):
"""
Create TFRecord file(s) from an annotated dataset.

:param images_dir: directory containing the dataset's image files
:param annotations_dir: directory containing the dataset's annotation files
:param annotation_format:
:param labels_path: path to the label map prototext file that corresponds to
:param tf_labels_path: path to the label map prototext file that corresponds to
the TFRecord files, and which will be generated by this function (will be
overwritten if already exists)
:param tfrecord_path: base TFRecord file path, files generated will have this
as the base path with shard numbers at the end, for example if using 2 total
shards then the resulting files will be <tfrecord_path>-00000-of-00002
and <tfrecord_path>-00001-of-00002
:param total_shards: number of shards over which to spread the records
:param darknet_labels_path: path to the class labels file corresponding to Darknet
(YOLO) annotation files, only necessary if using "darknet" annotation format
"""

# get the annotation "examples" as a pandas DataFrame
examples_df = _dataset_bbox_examples(
images_dir,
annotations_dir,
annotation_format,
darknet_labels_path,
)

# generate the prototext label map file
label_indices = _generate_label_map(examples_df, labels_path)
label_indices = _generate_label_map(examples_df, tf_labels_path)

# group the annotation examples by corresponding file name
data = namedtuple("data", ["filename", "object"])
Expand Down Expand Up @@ -374,6 +448,43 @@ def kitti_to_tfrecord(
)


# ------------------------------------------------------------------------------
def darknet_to_tfrecord(
images_dir: str,
darknet_dir: str,
darknet_labels_path: str,
tf_labels_path: str,
tfrecord_path: str,
total_shards: int,
):
"""
Create TFRecord file(s) from a Darknet (YOLO) format annotated dataset.

:param images_dir: directory containing the dataset's image files
:param darknet_dir: directory containing the dataset's annotation files
:param labels_path: path to the label map prototext file that corresponds to
the TFRecord files, and which will be generated by this function (will be
overwritten if already exists)
:param tfrecord_path: base TFRecord file path, files generated will have this
as the base path with shard numbers at the end, for example if using 2 total
shards then the resulting files will be <tfrecord_path>-00000-of-00002
and <tfrecord_path>-00001-of-00002
:param total_shards: number of shards over which to spread the records
"""

_logger.info("Converting images and annotations in Darknet (YOLO) format to TFRecord(s)")

return _to_tfrecord(
images_dir,
darknet_dir,
"darknet",
tf_labels_path,
tfrecord_path,
total_shards,
darknet_labels_path,
)


# ------------------------------------------------------------------------------
def pascal_to_tfrecord(
images_dir: str,
Expand Down Expand Up @@ -417,8 +528,8 @@ def kitti_to_darknet(
darknet_labels: str,
):
"""
Creates equivalent Darknet annotation files corresponding to a dataset with
KITTI annotations.
Creates equivalent Darknet (YOLO) annotation files corresponding to a dataset
with KITTI annotations.

:param images_dir: directory containing the dataset's images
:param kitti_dir: directory containing the dataset's KITTI annotation files
Expand Down Expand Up @@ -907,8 +1018,11 @@ def main():
required=False,
type=str,
help="file name of the labels file that will correspond to the label "
"indices used in the Darknet annotation files, to be written "
"in the Darknet directory",
"indices used in the Darknet annotation files, should be a file "
"name to be written in the Darknet annotations directory if "
"converting to Darknet (YOLO) format (--out_format == \"darknet\"), "
"or the path to file to be read from if converting from Darknet "
"(YOLO) format (--in_format == \"darknet\")"
)
args = vars(args_parser.parse_args())

Expand Down Expand Up @@ -990,6 +1104,21 @@ def main():
"Unsupported format conversion: "
f"{args['in_format']} to {args['out_format']}",
)
elif args["in_format"] == "darknet":
if args["out_format"] == "tfrecord":
darknet_to_tfrecord(
args["images_dir"],
args["annotations_dir"],
args["darknet_labels"],
args["tf_label_map"],
args["out_dir"],
args["tf_shards"],
)
else:
raise ValueError(
"Unsupported format conversion: "
f"{args['in_format']} to {args['out_format']}",
)
else:
raise ValueError(
"Unsupported format conversion: "
Expand Down
41 changes: 2 additions & 39 deletions src/cvdata/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tqdm import tqdm

from cvdata.common import FORMAT_CHOICES, FORMAT_EXTENSIONS
from cvdata.utils import matching_ids
from cvdata.utils import darknet_indices_to_labels, matching_ids


# ------------------------------------------------------------------------------
Expand All @@ -20,43 +20,6 @@
_logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
def _darknet_indices_to_labels(
darknet_labels_path: str,
) -> Dict:
"""
Parses a Darknet (YOLO) annotation labels file into a dictionary. The labels
file is expected to contain a single class label per line, and the resulting
dictionary will contain integer keys beginning at 0, so the first class label
will be the value for key 0, the second class label will be the value for key
1, etc. For example, the labels file with the following lines:

dog
cat
panda

will result in the following indices to labels dictionary:

{ 0: "dog", 1: "cat", 2: "panda" }

:param darknet_labels_path: path to the file containing labels used in
the Darknet dataset, should correspond to the labels used in the Darknet
annotation files of the dataset
:return: dictionary mapping index values to corresponding labels text
"""

index_labels = {}
with open(darknet_labels_path, "r") as darknet_labels_file:
index = 0
for line in darknet_labels_file:
if len(line.strip()) > 0:
darknet_label = line.split()[0]
index_labels[index] = darknet_label
index += 1

return index_labels


# ------------------------------------------------------------------------------
def _count_boxes_darknet(
darknet_file_path: str,
Expand Down Expand Up @@ -298,7 +261,7 @@ def filter_class_boxes(
darknet_index_labels = None
if annotation_format == "darknet":
# read the Darknet labels into a dictionary mapping label to label index
darknet_index_labels = _darknet_indices_to_labels(darknet_labels_path)
darknet_index_labels = darknet_indices_to_labels(darknet_labels_path)

# get the set of valid indices, i.e. all Darknet indices
# corresponding to the labels to be included in the filtered dataset
Expand Down
4 changes: 2 additions & 2 deletions src/cvdata/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def _build_write_tfrecord(
# read the image
image_file_name = args["file_ids"][i] + ".jpg"
image_path = os.path.join(args["images_dir"], image_file_name)
image_data = tf.gfile.GFile(image_path, 'rb').read()
image_data = tf.io.gfile.GFile(image_path, 'rb').read()
width, height, _ = image_dimensions(image_path)

# read the semantic segmentation annotation (mask)
mask_path = os.path.join(args["masks_dir"], args["file_ids"][i] + ".png")
seg_data = tf.gfile.GFile(mask_path, 'rb').read()
seg_data = tf.io.gfile.GFile(mask_path, 'rb').read()
seg_width, seg_height, _ = image_dimensions(mask_path)
if height != seg_height or width != seg_width:
raise RuntimeError('Shape mismatched between image and mask.')
Expand Down
Loading