Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimal hand detection #213

Merged
merged 19 commits into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions docs/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@
'functions': [
models.detection.SSD300,
models.detection.SSD512,
models.detection.HaarCascadeDetector
models.detection.HaarCascadeDetector,
models.detection.SSD512Custom
],
},

Expand All @@ -268,7 +269,7 @@
models.Projector,
models.DetNet,
models.IKNet,

],
},

Expand Down Expand Up @@ -580,7 +581,8 @@
pipelines.PreprocessBoxes,
pipelines.PostprocessBoxes2D,
pipelines.DetectSingleShot,
pipelines.DetectHaarCascade
pipelines.DetectHaarCascade,
pipelines.SSD512HandDetection
]
},

Expand All @@ -604,7 +606,8 @@
pipelines.TransformKeypoints,
pipelines.HigherHRNetHumanPose2D,
pipelines.DetNetHandKeypoints,
pipelines.MinimalHandPoseEstimation
pipelines.MinimalHandPoseEstimation,
pipelines.DetectMinimalHand
]
},

Expand Down Expand Up @@ -659,7 +662,8 @@
pipelines.PIX2POSEPowerDrill,
pipelines.PIX2YCBTools6D,
pipelines.DetNetHandKeypoints,
pipelines.MinimalHandPoseEstimation
pipelines.MinimalHandPoseEstimation,
pipelines.DetectMinimalHand
]
},

Expand Down
11 changes: 11 additions & 0 deletions examples/hand_detection/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Hand detection example

To test the live hand detection from camera, run:
```py
python demo.py
```

To test the hand detection with pose estimation and hand closure classification, run:
```py
python demo_image.py
```
14 changes: 14 additions & 0 deletions examples/hand_detection/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import argparse
from paz.pipelines import SSD512HandDetection
from paz.backend.camera import VideoPlayer, Camera


parser = argparse.ArgumentParser(description='Minimal hand detection')
parser.add_argument('-c', '--camera_id', type=int, default=0,
help='Camera device ID')
args = parser.parse_args()

pipeline = SSD512HandDetection()
camera = Camera(args.camera_id)
player = VideoPlayer((640, 480), pipeline, camera)
player.run()
17 changes: 17 additions & 0 deletions examples/hand_detection/pose_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import argparse
from paz.applications import DetectMinimalHand
from paz.applications import MinimalHandPoseEstimation
from paz.pipelines.detection import SSD512HandDetection
from paz.backend.camera import VideoPlayer, Camera


parser = argparse.ArgumentParser(description='Minimal hand detection')
parser.add_argument('-c', '--camera_id', type=int, default=0,
help='Camera device ID')
args = parser.parse_args()

pipeline = DetectMinimalHand(
SSD512HandDetection(), MinimalHandPoseEstimation(right_hand=False))
camera = Camera(args.camera_id)
player = VideoPlayer((640, 480), pipeline, camera)
player.run()
1 change: 1 addition & 0 deletions paz/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from .pipelines import HigherHRNetHumanPose2D
from .pipelines import DetNetHandKeypoints
from .pipelines import MinimalHandPoseEstimation
from .pipelines import DetectMinimalHand
4 changes: 2 additions & 2 deletions paz/backend/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def mask_classes(boxes, positive_mask, ignoring_mask):
return boxes


def match(boxes, prior_boxes, positive_iou=0.5, negative_iou=0.0):
def match_beta(boxes, prior_boxes, positive_iou=0.5, negative_iou=0.0):
"""Matches each prior box with a ground truth box (box from `boxes`).
It then selects which matched box will be considered positive e.g. iou > .5
and returns for each prior box a ground truth box that is either positive
Expand Down Expand Up @@ -177,7 +177,7 @@ def match(boxes, prior_boxes, positive_iou=0.5, negative_iou=0.0):
return matched_boxes


def match2(boxes, prior_boxes, iou_threshold=0.5):
def match(boxes, prior_boxes, iou_threshold=0.5):
"""Matches each prior box with a ground truth box (box from `boxes`).
It then selects which matched box will be considered positive e.g. iou > .5
and returns for each prior box a ground truth box that is either positive
Expand Down
1 change: 0 additions & 1 deletion paz/datasets/open_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def load_data(self):
# skip header
annotations_file.readline()

# for line in tqdm(annotations_file, total=num_lines):
for line in annotations_file:
row = line.split(",")

Expand Down
1 change: 1 addition & 0 deletions paz/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .detection import SSD300
from .detection import SSD512
from .detection import SSD512Custom
from .detection import HaarCascadeDetector
from .keypoint.projector import Projector
from .keypoint.keypointnet import KeypointNet
Expand Down
1 change: 1 addition & 0 deletions paz/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .ssd300 import SSD300
from .ssd512 import SSD512
from .ssd512_custom import SSD512Custom
from .haar_cascade import HaarCascadeDetector
40 changes: 40 additions & 0 deletions paz/models/detection/ssd512_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
from tensorflow.keras import Model
from tensorflow.keras.utils import get_file
from .ssd512 import SSD512
from paz.models.detection.utils import create_multibox_head


def SSD512Custom(num_classes, weight_path, num_priors=[4, 6, 6, 6, 6, 4, 4],
l2_loss=5e-4, trainable_base=False):
"""Custom Single-shot-multibox detector for 512x512x3 BGR input images.
# Arguments
num_classes: Integer. Specifies the number of class labels.
weight_path: String. Weight path trained on custom dataset.
num_priors: List of integers. Number of default box shapes
used in each detection layer.
l2_loss: Float. l2 regularization loss for convolutional layers.
trainable_base: Boolean. If `True` the base model
weights are also trained.

# Reference
- [SSD: Single Shot MultiBox
Detector](https://arxiv.org/abs/1512.02325)
"""
base_model = SSD512(weights='COCO', trainable_base=trainable_base)
branch_names = ['branch_1', 'branch_2', 'branch_3', 'branch_4',
'branch_5', 'branch_6', 'branch_7']
branch_tensors = []
for branch_name in branch_names:
branch_layer = base_model.get_layer(branch_name)
branch_tensors.append(branch_layer.output)

output_tensor = create_multibox_head(
branch_tensors, num_classes, num_priors, l2_loss)
model = Model(base_model.input, output_tensor, name='SSD512Custom')
model.prior_boxes = base_model.prior_boxes
filename = os.path.basename(weight_path)
weights_path = get_file(filename, weight_path, cache_subdir='paz/models')
print('==> Loading %s model weights' % weights_path)
model.load_weights(weights_path)
return model
2 changes: 2 additions & 0 deletions paz/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .detection import DetectMiniXceptionFER
from .detection import DetectKeypoints2D
from .detection import DetectFaceKeypointNet2D32
from .detection import SSD512HandDetection

from .keypoints import KeypointNetSharedAugmentation
from .keypoints import KeypointNetInference
Expand All @@ -29,6 +30,7 @@
from .keypoints import HigherHRNetHumanPose2D
from .keypoints import DetNetHandKeypoints
from .keypoints import MinimalHandPoseEstimation
from .keypoints import DetectMinimalHand

from .renderer import RenderTwoViews
from .renderer import RandomizeRenderedImage
Expand Down
45 changes: 43 additions & 2 deletions paz/pipelines/detection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np

from paz.models.detection.haar_cascade import WEIGHT_PATH

from .. import processors as pr
from ..abstract import SequentialProcessor, Processor
from ..models import SSD512, SSD300, HaarCascadeDetector
from ..models import SSD512, SSD300, HaarCascadeDetector, SSD512Custom
from ..datasets import get_class_names

from .image import AugmentImage, PreprocessImage
Expand Down Expand Up @@ -65,8 +67,9 @@ def __init__(self, prior_boxes, split=pr.TRAIN, num_classes=21, size=300,
super(AugmentDetection, self).__init__()
# image processors
self.augment_image = AugmentImage()
self.augment_image.add(pr.ConvertColorSpace(pr.RGB2BGR))
# self.augment_image.add(pr.ConvertColorSpace(pr.RGB2BGR))
self.preprocess_image = PreprocessImage((size, size), mean)
self.preprocess_image.insert(0, pr.ConvertColorSpace(pr.RGB2BGR))

# box processors
self.augment_boxes = AugmentBoxes()
Expand Down Expand Up @@ -477,3 +480,41 @@ def __init__(self, offsets=[0, 0], radius=3):
estimate_keypoints = FaceKeypointNet2D32(draw=False)
super(DetectFaceKeypointNet2D32, self).__init__(
detect, estimate_keypoints, offsets, radius)


class SSD512HandDetection(DetectSingleShot):
"""Minimal hand detection with SSD512Custom trained on OPenImageV6.

# Arguments
score_thresh: Float between [0, 1]
nms_thresh: Float between [0, 1].
draw: Boolean. If ``True`` prediction are drawn in the returned image.

# Example
``` python
from paz.pipelines import SSD512HandDetection

detect = SSD512HandDetection()

# apply directly to an image (numpy-array)
inferences = detect(image)
```
# Returns
A function that takes an RGB image and outputs the predictions
as a dictionary with ``keys``: ``image`` and ``boxes2D``.
The corresponding values of these keys contain the image with the drawn
inferences and a list of ``paz.abstract.messages.Boxes2D``.

# Reference
- [SSD: Single Shot MultiBox
Detector](https://arxiv.org/abs/1512.02325)
"""
def __init__(self, score_thresh=0.40, nms_thresh=0.45, draw=True):
weight_path = (
'https://github.com/oarriaga/altamira-data/releases/'
'download/v0.15/SSD512_OpenImageV6_trainable_weights.hdf5')
class_names = ['background', 'hand']
num_classes = len(class_names)
model = SSD512Custom(num_classes, weight_path)
super(SSD512HandDetection, self).__init__(
model, class_names, score_thresh, nms_thresh, draw=draw)
56 changes: 55 additions & 1 deletion paz/pipelines/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .angles import IKNetHandJointAngles


from ..backend.image import get_affine_transform, flip_left_right
from ..backend.image import get_affine_transform, flip_left_right, lincolor
from ..backend.keypoints import flip_keypoints_left_right, uv_to_vu
from ..datasets import JOINT_CONFIG, FLIP_CONFIG

Expand Down Expand Up @@ -328,3 +328,57 @@ def call(self, image):
return self.wrap(keypoints['image'], keypoints['keypoints3D'],
keypoints['keypoints2D'], angles['absolute_angles'],
angles['relative_angles'])


class DetectMinimalHand(pr.Processor):
def __init__(self, detect, estimate_keypoints, offsets=[0, 0], radius=3):
"""Minimal hand detection and keypoint estimator pipeline.

# Arguments
detect: Function for detecting objects. The output should be a
dictionary with key ``Boxes2D`` containing a list
of ``Boxes2D`` messages.
estimate_keypoints: Function for estimating keypoints. The output
should be a dictionary with key ``keypoints`` containing
a numpy array of keypoints.
offsets: List of two elements. Each element must be between [0, 1].
radius: Int indicating the radius of the keypoints to be drawn.
"""
super(DetectMinimalHand, self).__init__()
self.class_names = ['OPEN', 'CLOSE']
self.colors = lincolor(len(self.class_names))
self.detect = detect
self.estimate_keypoints = estimate_keypoints
self.classify_hand_closure = pr.SequentialProcessor(
[pr.IsHandOpen(), pr.BooleanToTextMessage('OPEN', 'CLOSE')])
self.square = pr.SequentialProcessor()
self.square.add(pr.SquareBoxes2D())
self.square.add(pr.OffsetBoxes2D(offsets))
self.clip = pr.ClipBoxes2D()
self.crop = pr.CropBoxes2D()
self.change_coordinates = pr.ChangeKeypointsCoordinateSystem()
self.draw = pr.DrawHandSkeleton(keypoint_radius=radius)
self.draw_boxes = pr.DrawBoxes2D(self.class_names, self.colors,
with_score=False)
self.wrap = pr.WrapOutput(
['image', 'boxes2D', 'keypoints2D', 'keypoints3D'])

def call(self, image):
boxes2D = self.detect(image.copy())['boxes2D']
boxes2D = self.square(boxes2D)
boxes2D = self.clip(image, boxes2D)
cropped_images = self.crop(image, boxes2D)
keypoints2D = []
keypoints3D = []
for cropped_image, box2D in zip(cropped_images, boxes2D):
inference = self.estimate_keypoints(cropped_image)
keypoints = self.change_coordinates(
inference['keypoints2D'], box2D)
hand_closure_status = self.classify_hand_closure(
inference['relative_angles'])
box2D.class_name = hand_closure_status
keypoints2D.append(keypoints)
keypoints3D.append(inference['keypoints3D'])
image = self.draw(image, keypoints)
image = self.draw_boxes(image, boxes2D)
return self.wrap(image, boxes2D, keypoints2D, keypoints3D)
4 changes: 2 additions & 2 deletions tests/paz/backend/boxes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def boxes_with_label():

@pytest.fixture
def target_unique_matches():
# return np.array([[238., 155., 306., 204.]])
return np.array([[47.0, 239.0, 194.0, 370.0]])
return np.array([[47.0, 239.0, 194.0, 370.0],
[238., 155., 306., 204.]])


@pytest.fixture
Expand Down