Skip to content

Commit

Permalink
Merge pull request #229 from Manojkumarmuru/efficientdet
Browse files Browse the repository at this point in the history
Efficientdet
  • Loading branch information
oarriaga authored Oct 24, 2022
2 parents 291ec65 + c159e69 commit feaded9
Show file tree
Hide file tree
Showing 13 changed files with 586 additions and 355 deletions.
417 changes: 287 additions & 130 deletions examples/efficientdet/anchors.py

Large diffs are not rendered by default.

40 changes: 0 additions & 40 deletions examples/efficientdet/boxes.py

This file was deleted.

16 changes: 8 additions & 8 deletions examples/efficientdet/debugger.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import tensorflow as tf
# gpus = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(gpus[0], True)


import numpy as np
from efficientdet import EFFICIENTDETD0
from paz.datasets import VOC
from paz.abstract import Processor, SequentialProcessor
import tensorflow as tf
from paz import processors as pr
from paz.abstract import Processor, SequentialProcessor
from paz.datasets import VOC

from detection import AugmentDetection
from efficientdet import EFFICIENTDETD0

# gpus = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(gpus[0], True)


class ShowBoxes(Processor):
Expand Down
129 changes: 77 additions & 52 deletions examples/efficientdet/detection.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
from paz import processors as pr
from paz.abstract import Processor, SequentialProcessor
from paz.pipelines import DetectSingleShot
from paz.pipelines.detection import AugmentBoxes, PreprocessBoxes
from paz.pipelines.image import AugmentImage

import necessary_imports as ni
from draw import (add_box_border, draw_opaque_box, get_text_size,
make_box_transparent, put_text)
from efficientdet_postprocess import process_outputs
from processors import MatchBoxes
from utils import efficientdet_preprocess, get_class_name_efficientdet


class AugmentImage(SequentialProcessor):
"""Augments an RGB image by randomly changing contrast, brightness
saturation and hue.
"""
def __init__(self):
super(AugmentImage, self).__init__()
self.add(pr.RandomContrast())
self.add(pr.RandomBrightness())
self.add(pr.RandomSaturation(0.7))
self.add(pr.RandomHue())


class PreprocessImage(SequentialProcessor):
"""Preprocess RGB image by resizing it to the given ``shape``. If a
``mean`` is given it is substracted from image and it not the image gets
Expand All @@ -35,45 +27,6 @@ def __init__(self, shape, mean=pr.BGR_IMAGENET_MEAN):
self.add(pr.CastImage(float))
self.add(pr.SubtractMeanImage(pr.RGB_IMAGENET_MEAN))
self.add(ni.DivideStandardDeviationImage(ni.RGB_IMAGENET_STDEV))
# if mean is None:
# self.add(pr.NormalizeImage())
# else:
# print('Normal')
# self.add(pr.SubtractMeanImage(mean))


class AugmentBoxes(SequentialProcessor):
"""Perform data augmentation with bounding boxes.
# Arguments
mean: List of three elements used to fill empty image spaces.
"""
def __init__(self, mean=pr.BGR_IMAGENET_MEAN):
super(AugmentBoxes, self).__init__()
self.add(pr.ToImageBoxCoordinates())
self.add(pr.Expand(mean=mean))
# RandomSampleCrop was commented out
self.add(pr.RandomSampleCrop())
self.add(pr.RandomFlipBoxesLeftRight())
self.add(pr.ToNormalizedBoxCoordinates())


class PreprocessBoxes(SequentialProcessor):
"""Preprocess bounding boxes
# Arguments
num_classes: Int.
prior_boxes: Numpy array of shape ``[num_boxes, 4]`` containing
prior/default bounding boxes.
IOU: Float. Intersection over union used to match boxes.
variances: List of two floats indicating variances to be encoded
for encoding bounding boxes.
"""
def __init__(self, num_classes, prior_boxes, IOU, variances):
super(PreprocessBoxes, self).__init__()
self.add(MatchBoxes(prior_boxes, IOU),)
self.add(pr.EncodeBoxes(prior_boxes, variances))
self.add(pr.BoxClassToOneHotVector(num_classes))


class AugmentDetection(SequentialProcessor):
Expand Down Expand Up @@ -152,3 +105,75 @@ def call(self, image):
draw_boxes2D = pr.DrawBoxes2D(get_class_name_efficientdet('VOC'))
image = draw_boxes2D(image.astype('uint8'), outputs)
return self.wrap(image, outputs)


class DetectSingleShot(DetectSingleShot):
"""Single-shot object detection prediction.
# Arguments
model: Keras model.
class_names: List of strings indicating the class names.
score_thresh: Float between [0, 1]
nms_thresh: Float between [0, 1].
mean: List of three elements indicating the per channel mean.
variances: List containing the variances of the encoded boxes.
draw: Boolean. If ``True`` prediction are drawn in the returned image.
"""
def __init__(
self, model, class_names, score_thresh, nms_thresh,
mean=pr.BGR_IMAGENET_MEAN, variances=[0.1, 0.1, 0.2, 0.2],
draw=True):
super().__init__(
model, class_names, score_thresh, nms_thresh,
mean, variances, draw)
self.draw_boxes2D = DrawBoxes2D(class_names)


class DrawBoxes2D(pr.DrawBoxes2D):
"""Draws bounding boxes from Boxes2D messages.
# Arguments
class_names: List of strings.
colors: List of lists containing the color values
weighted: Boolean. If ``True`` the colors are weighted with the
score of the bounding box.
scale: Float. Scale of drawn text.
with_score: Boolean. If ``True`` displays the confidence score.
"""
def __init__(
self, class_names=None, colors=None,
weighted=False, scale=0.7, with_score=True):
super().__init__(
class_names, colors, weighted, scale, with_score)

def compute_prediction_parameters(self, box2D):
x_min, y_min, x_max, y_max = box2D.coordinates
class_name = box2D.class_name
color = self.class_to_color[class_name]
if self.weighted:
color = [int(channel * box2D.score) for channel in color]
if self.with_score:
text = '{} :{}%'.format(class_name, round(box2D.score*100))
if not self.with_score:
text = '{}'.format(class_name)
return x_min, y_min, x_max, y_max, color, text

def call(self, image, boxes2D):
raw_image = image.copy()
for box2D in boxes2D:
prediction_parameters = self.compute_prediction_parameters(box2D)
x_min, y_min, x_max, y_max, color, text = prediction_parameters
draw_opaque_box(image, (x_min, y_min), (x_max, y_max), color)
image = make_box_transparent(raw_image, image)
for box2D in boxes2D:
prediction_parameters = self.compute_prediction_parameters(box2D)
x_min, y_min, x_max, y_max, color, text = prediction_parameters
add_box_border(image, (x_min, y_min), (x_max, y_max), color, 2)
text_size = get_text_size(text, self.scale, 1)
(text_W, text_H), _ = text_size
draw_opaque_box(
image, (x_min+2, y_min+2), (x_min+text_W+5, y_min+text_H+5),
(255, 174, 66))
put_text(
image, text, (x_min+2, y_min + 17), self.scale, (0, 0, 0), 1)
return image
87 changes: 87 additions & 0 deletions examples/efficientdet/draw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import cv2

FONT = cv2.FONT_HERSHEY_SIMPLEX
LINE = cv2.LINE_AA


def put_text(image, text, point, scale, color, thickness):
"""Draws text in image.
# Arguments
image: Numpy array.
text: String. Text to be drawn.
point: Tuple of coordinates indicating the top corner of the text.
scale: Float. Scale of text.
color: Tuple of integers. RGB color coordinates.
thickness: Integer. Thickness of the lines used for drawing text.
# Returns
Numpy array with shape ``[H, W, 3]``. Image with text.
"""
return cv2.putText(image, text, point, FONT, scale, color, thickness, LINE)


def get_text_size(text, scale, FONT_THICKNESS, FONT=FONT):
"""Calculates the size of a given text.
# Arguments
text: String. Text whose width and height is to be calculated.
scale: Float. Scale of text.
FONT_THICKNESS: Integer. Thickness of the lines used for drawing text.
FONT: Integer. Style of the text font.
# Returns
Numpy array with shape ``[H, W, 3]``. Image with text.
"""
return cv2.getTextSize(text, FONT, scale, FONT_THICKNESS)


def add_box_border(image, corner_A, corner_B, color, thickness):
""" Draws an open rectangle from ``corner_A`` to ``corner_B``.
# Arguments
image: Numpy array of shape ``[H, W, 3]``.
corner_A: List of length two indicating ``(y, x)`` openCV coordinates.
corner_B: List of length two indicating ``(y, x)`` openCV coordinates.
color: List of length three indicating RGB color of point.
thickness: Integer/openCV Flag. Thickness of rectangle line.
or for filled use cv2.FILLED flag.
# Returns
Numpy array with shape ``[H, W, 3]``. Image with rectangle.
"""
return cv2.rectangle(
image, tuple(corner_A), tuple(corner_B), tuple(color),
thickness)


def draw_opaque_box(image, corner_A, corner_B, color, thickness=-1):
""" Draws a filled rectangle from ``corner_A`` to ``corner_B``.
# Arguments
image: Numpy array of shape ``[H, W, 3]``.
corner_A: List of length two indicating ``(y, x)`` openCV coordinates.
corner_B: List of length two indicating ``(y, x)`` openCV coordinates.
color: List of length three indicating RGB color of point.
thickness: Integer/openCV Flag. Thickness of rectangle line.
or for filled use cv2.FILLED flag.
# Returns
Numpy array with shape ``[H, W, 3]``. Image with rectangle.
"""
return cv2.rectangle(
image, tuple(corner_A), tuple(corner_B), tuple(color),
thickness)


def make_box_transparent(raw_image, image, alpha=0.30):
""" Blends the raw image with bounding box image to add transparency.
# Arguments
raw_image: Numpy array of shape ``[H, W, 3]``.
image: Numpy array of shape ``[H, W, 3]``.
alpha: Float, weightage parameter of weighted sum.
# Returns
Numpy array with shape ``[H, W, 3]``. Image with rectangle.
"""
return cv2.addWeighted(raw_image, 1-alpha, image, alpha, 0.0)
18 changes: 11 additions & 7 deletions examples/efficientdet/efficientdet_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.utils import get_file

from anchors import get_prior_boxes
from anchors import build_prior_boxes
from efficientdet_blocks import BiFPN, BoxNet, ClassNet
from efficientnet_model import EfficientNet
from utils import create_multibox_head

WEIGHT_PATH = (
'/home/manummk95/Desktop/efficientdet_working/required/weights/')
'https://github.com/oarriaga/altamira-data/releases/download/v0.16/')


def EfficientDet(num_classes, base_weights, head_weights, input_shape,
Expand Down Expand Up @@ -82,12 +83,15 @@ def EfficientDet(num_classes, base_weights, head_weights, input_shape,
model = Model(inputs=image, outputs=outputs, name=model_name)

if (((base_weights == 'COCO') and (head_weights == 'COCO')) or
((base_weights == 'COCO') and (head_weights == 'None'))):
weights_path = (WEIGHT_PATH + model_name + '_' +
str(base_weights) + '_' + str(head_weights) + '.h5')
((base_weights == 'COCO') and (head_weights is None))):
model_filename = (model_name + '-' + str(base_weights) + '-' +
str(head_weights) + '_weights.hdf5')
weights_path = get_file(model_filename, WEIGHT_PATH + model_filename,
cache_subdir='paz/models')
print('Loading %s model weights' % weights_path)
model.load_weights(weights_path)

model.prior_boxes = get_prior_boxes(
model.prior_boxes = build_prior_boxes(
min_level, max_level, num_scales, aspect_ratios, anchor_scale,
input_shape[0])
input_shape[0:2])
return model
Loading

0 comments on commit feaded9

Please sign in to comment.