Skip to content

Commit

Permalink
Merge pull request #214 from proneetsharma/master
Browse files Browse the repository at this point in the history
Update SSD512
  • Loading branch information
oarriaga authored Aug 4, 2022
2 parents 83113e5 + 744b616 commit 62e4632
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 76 deletions.
3 changes: 1 addition & 2 deletions docs/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@
'functions': [
models.detection.SSD300,
models.detection.SSD512,
models.detection.HaarCascadeDetector,
models.detection.SSD512Custom
models.detection.HaarCascadeDetector
],
},

Expand Down
1 change: 0 additions & 1 deletion paz/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .detection import SSD300
from .detection import SSD512
from .detection import SSD512Custom
from .detection import HaarCascadeDetector
from .keypoint.projector import Projector
from .keypoint.keypointnet import KeypointNet
Expand Down
1 change: 0 additions & 1 deletion paz/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .ssd300 import SSD300
from .ssd512 import SSD512
from .ssd512_custom import SSD512Custom
from .haar_cascade import HaarCascadeDetector
57 changes: 31 additions & 26 deletions paz/models/detection/ssd512.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
from .utils import create_multibox_head
from .utils import create_prior_boxes

WEIGHT_PATH = ('https://github.com/oarriaga/altamira-data/'
'releases/download/v0.1/')

BASE_WEIGHT_PATH = ('https://github.com/oarriaga/altamira-data/'
'releases/download/v0.1/')


def SSD512(num_classes=81, weights='COCO', input_shape=(512, 512, 3),
num_priors=[4, 6, 6, 6, 6, 4, 4], l2_loss=0.0005,
return_base=False, trainable_base=True):
def SSD512(num_classes=81, base_weights='COCO', head_weights='COCO',
input_shape=(512, 512, 3), num_priors=[4, 6, 6, 6, 6, 4, 4],
l2_loss=0.0005, return_base=False, trainable_base=True):
"""Single-shot-multibox detector for 512x512x3 BGR input images.
# Arguments
num_classes: Integer. Specifies the number of class labels.
weights: String or None. If string should be a valid dataset name.
Current valid datasets include `COCO` and `YCBVideo`.
base_weights: String or None. If string should be a valid dataset name.
Current valid datasets include `COCO` and `OIV6Hand`.
head_weights: String or None. If string should be a valid dataset name.
Current valid datasets include `COCO`, `YCBVideo` and `OIV6Hand`.
input_shape: List of integers. Input shape to the model including only
spatial and channel resolution e.g. (512, 512, 3).
num_priors: List of integers. Number of default box shapes
Expand All @@ -39,20 +40,23 @@ def SSD512(num_classes=81, weights='COCO', input_shape=(512, 512, 3),
Detector](https://arxiv.org/abs/1512.02325)
"""

datasets = {'COCO', 'YCBVideo', None}
if not (weights in datasets or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
'`None` (random initialization), `COCO`, '
'YCBVideo or the path to the weights '
'file to be loaded.')
if base_weights not in ['COCO', 'OIV6Hand']:
raise ValueError('Invalid `base_weights`:', base_weights)

if head_weights not in ['COCO', 'YCBVideo', 'OIV6Hand']:
raise ValueError('Invalid `head_weights`:', head_weights)

if ((base_weights == 'OIV6Hand') and (head_weights != 'OIV6Hand')):
raise NotImplementedError('Invalid `base_weights` with head_weights')

if weights == 'COCO' and num_classes != 81:
raise ValueError('If using `weights` as `"COCO"` '
'`num_classes` should be 81')
if ((num_classes != 81) and (head_weights == 'COCO')):
raise ValueError('Invalid `head_weights` with given `num_classes`')

if weights == 'YCBVideo' and num_classes != 22:
raise ValueError('If using `weights` as `"YCBVideo"` '
'`num_classes` should be 22')
if ((num_classes != 22) and (head_weights == 'YCBVideo')):
raise ValueError('Invalid `head_weights` with given `num_classes`')

if ((num_classes != 2) and (head_weights == 'OIV6Hand')):
raise ValueError('Invalid `head_weights` with given `num_classes`')

image = Input(shape=input_shape, name='image')

Expand Down Expand Up @@ -200,13 +204,14 @@ def SSD512(num_classes=81, weights='COCO', input_shape=(512, 512, 3),

model = Model(inputs=image, outputs=output_tensor, name='SSD512')

if weights is not None:
model_name = '_'.join(['SSD512', weights])

if weights is not None:
weights_url = BASE_WEIGHT_PATH + model_name + '_weights.hdf5'
weights_path = get_file(os.path.basename(weights_url), weights_url,
if ((base_weights is not None) or (head_weights is not None)):
model_filename = [str(base_weights), str(head_weights)]
model_filename = '_'.join(['SSD512', '-'.join(model_filename),
'weights.hdf5'])
weights_path = get_file(model_filename, WEIGHT_PATH + model_filename,
cache_subdir='paz/models')
print('Loading %s model weights' % weights_path)

model.load_weights(weights_path)

model.prior_boxes = create_prior_boxes('COCO')
Expand Down
40 changes: 0 additions & 40 deletions paz/models/detection/ssd512_custom.py

This file was deleted.

10 changes: 4 additions & 6 deletions paz/pipelines/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .. import processors as pr
from ..abstract import SequentialProcessor, Processor
from ..models import SSD512, SSD300, HaarCascadeDetector, SSD512Custom
from ..models import SSD512, SSD300, HaarCascadeDetector
from ..datasets import get_class_names

from .image import AugmentImage, PreprocessImage
Expand Down Expand Up @@ -213,7 +213,7 @@ class SSD512YCBVideo(DetectSingleShot):
"""
def __init__(self, score_thresh=0.60, nms_thresh=0.45, draw=True):
names = get_class_names('YCBVideo')
model = SSD512(weights='YCBVideo', num_classes=len(names))
model = SSD512(head_weights='YCBVideo', num_classes=len(names))
super(SSD512YCBVideo, self).__init__(
model, names, score_thresh, nms_thresh, draw=draw)

Expand Down Expand Up @@ -510,11 +510,9 @@ class SSD512HandDetection(DetectSingleShot):
Detector](https://arxiv.org/abs/1512.02325)
"""
def __init__(self, score_thresh=0.40, nms_thresh=0.45, draw=True):
weight_path = (
'https://github.com/oarriaga/altamira-data/releases/'
'download/v0.15/SSD512_OpenImageV6_trainable_weights.hdf5')
class_names = ['background', 'hand']
num_classes = len(class_names)
model = SSD512Custom(num_classes, weight_path)
model = SSD512(num_classes, base_weights='OIV6Hand',
head_weights='OIV6Hand')
super(SSD512HandDetection, self).__init__(
model, class_names, score_thresh, nms_thresh, draw=draw)

0 comments on commit 62e4632

Please sign in to comment.