Commit

add basic yolov8 segmentation models

leondgarse committed Apr 20, 2024
1 parent 5bbbc79 commit 88eaa71

Showing 9 changed files with 209 additions and 32 deletions.
17 changes: 15 additions & 2 deletions README.md
@@ -99,7 +99,9 @@
- [GPT2](#gpt2)
- [LLaMA2](#llama2)
- [Stable Diffusion](#stable-diffusion)
- [Segment Anything](#segment-anything)
- [Segmentation Models](#segmentation-models)
- [YOLOV8 Segmentation](#yolov8-segmentation)
- [Segment Anything](#segment-anything)
- [Licenses](#licenses)
- [Citing](#citing)

@@ -1667,7 +1669,18 @@
| Decoder | 49.49M | 1259.5G | [None, 64, 64, 4] | [decoder_v1_5.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/stable_diffusion/decoder_v1_5.h5) |
***

# Segment Anything
# Segmentation Models
## YOLOV8 Segmentation
- [Keras YOLOV8](keras_cv_attention_models/yolov8) includes an implementation of the [Github ultralytics/ultralytics](https://github.com/ultralytics/ultralytics) segmentation models; a usage sketch follows the table below.

| Model | Params | FLOPs | Input | COCO val mask AP | T4 Inference |
| ------------ | ------ | ------- | ----- | ---------------- | ------------ |
| [YOLOV8_N_SEG](https://github.com/leondgarse/keras_cv_attention_models/releases/download/yolov8/yolov8_n_seg_imagenet.h5) | 3.41M | 6.02G | 640 | 30.5 | |
| [YOLOV8_S_SEG](https://github.com/leondgarse/keras_cv_attention_models/releases/download/yolov8/yolov8_s_seg_imagenet.h5) | 11.82M | 20.08G | 640 | 36.8 | |
| [YOLOV8_M_SEG](https://github.com/leondgarse/keras_cv_attention_models/releases/download/yolov8/yolov8_m_seg_imagenet.h5) | 27.29M | 52.33G | 640 | 40.8 | |
| [YOLOV8_L_SEG](https://github.com/leondgarse/keras_cv_attention_models/releases/download/yolov8/yolov8_l_seg_imagenet.h5) | 46.00M | 105.29G | 640 | 42.6 | |
| [YOLOV8_X_SEG](https://github.com/leondgarse/keras_cv_attention_models/releases/download/yolov8/yolov8_x_seg_imagenet.h5) | 71.83M | 164.30G | 640 | 43.4 | |
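
  A minimal usage sketch, assuming the `_SEG` variants follow the same `pretrained` constructor and `decode_predictions` interface as this package's detection models; the fourth `masks` decode output matches the segmentation path added in `eval_func.py` below, but the exact call pattern here is an assumption:

  ```py
  from keras_cv_attention_models import yolov8, test_images, plot_func

  model = yolov8.YOLOV8_N_SEG(pretrained="coco")  # assumed constructor, mirroring the detection models
  imm = test_images.dog_cat()
  preds = model(model.preprocess_input(imm))

  # Segmentation decode is assumed to also return per-box masks as a fourth element
  bboxs, lables, confidences, masks = model.decode_predictions(preds)[0]
  plot_func.show_image_with_bboxes_and_masks(imm, bboxs, lables, confidences, masks)
  ```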
## Segment Anything
- [Keras Segment Anything](keras_cv_attention_models/segment_anything) includes an implementation of [PDF 2304.02643 Segment Anything](https://arxiv.org/abs/2304.02643).

| Model | Params | FLOPs | Input | COCO val mIoU | T4 Inference |
84 changes: 67 additions & 17 deletions keras_cv_attention_models/coco/eval_func.py
@@ -46,13 +46,14 @@ def __init__(
topk=0, # decode parameter, can be set new value in `self.call`
use_static_output=False, # Set to True if using this as an actual layer, especially for converting tflite
use_sigmoid_on_score=False, # whether to apply sigmoid on score outputs. Set True if model is built using `classifier_activation=None`
num_masks=0, # Set > 0 value for segmentation model with masks output
**kwargs,
):
super().__init__(**kwargs)

self.pyramid_levels = list(range(min(pyramid_levels), max(pyramid_levels) + 1))
use_object_scores, num_anchors, anchor_scale = anchors_func.get_anchors_mode_parameters(anchors_mode, use_object_scores, "auto", anchor_scale)
self.regression_len, self.aspect_ratios, self.num_scales = regression_len, aspect_ratios, num_scales
self.regression_len, self.aspect_ratios, self.num_scales, self.num_masks = regression_len, aspect_ratios, num_scales, num_masks
self.anchors_mode, self.use_object_scores, self.anchor_scale = anchors_mode, use_object_scores, anchor_scale # num_anchors not using
if input_shape is not None and (isinstance(input_shape, (list, tuple)) and input_shape[1] is not None):
self.__init_anchor__(input_shape)
@@ -121,35 +122,65 @@ def __topk_class_boxes_single__(self, pred, topk=5000):
# bboxes, labels, scores = rrs[:, :4], rrs[:, 4], rrs[:, -1]
# return bboxes.numpy(), labels.numpy(), scores.numpy()

def __nms_per_class__(self, bbs, ccs, labels, score_threshold=0.3, iou_threshold=0.5, soft_nms_sigma=0.5, max_output_size=100):
@staticmethod
def nms_per_class(bbs, ccs, labels, score_threshold=0.3, iou_threshold=0.5, soft_nms_sigma=0.5, max_output_size=100):
# From torchvision.ops.batched_nms strategy: in order to perform NMS independently per class, we add an offset to all the boxes.
# The offset is dependent only on the class idx, and is large enough so that boxes from different classes do not overlap
# Same result with per_class method: https://github.com/google/automl/tree/master/efficientdet/tf2/postprocess.py#L409
cls_offset = functional.cast(labels, bbs.dtype) * (functional.reduce_max(bbs) + 1)
bbs_per_class = bbs + functional.expand_dims(cls_offset, -1)
rr, nms_scores = functional.non_max_suppression_with_scores(bbs_per_class, ccs, max_output_size, iou_threshold, score_threshold, soft_nms_sigma)
return functional.gather(bbs, rr), functional.gather(labels, rr), nms_scores
indices, nms_scores = functional.non_max_suppression_with_scores(bbs_per_class, ccs, max_output_size, iou_threshold, score_threshold, soft_nms_sigma)
return functional.gather(bbs, indices), functional.gather(labels, indices), nms_scores, indices

def __nms_global__(self, bbs, ccs, labels, score_threshold=0.3, iou_threshold=0.5, soft_nms_sigma=0.5, max_output_size=100):
rr, nms_scores = functional.non_max_suppression_with_scores(bbs, ccs, max_output_size, iou_threshold, score_threshold, soft_nms_sigma)
return functional.gather(bbs, rr), functional.gather(labels, rr), nms_scores
@staticmethod
def nms_global(bbs, ccs, labels, score_threshold=0.3, iou_threshold=0.5, soft_nms_sigma=0.5, max_output_size=100):
indices, nms_scores = functional.non_max_suppression_with_scores(bbs, ccs, max_output_size, iou_threshold, score_threshold, soft_nms_sigma)
return functional.gather(bbs, indices), functional.gather(labels, indices), nms_scores, indices
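
# Standalone sketch of the class-offset trick used by `nms_per_class` above, assuming a
# TensorFlow backend (`functional.non_max_suppression_with_scores` wraps the TF op there):
import tensorflow as tf

boxes = tf.constant([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]])  # two identical boxes ...
scores = tf.constant([0.9, 0.8])
labels = tf.constant([0, 1])  # ... but with different class labels

offset = tf.cast(labels, boxes.dtype) * (tf.reduce_max(boxes) + 1)
shifted = boxes + offset[:, None]  # boxes of different classes no longer overlap
indices, _ = tf.image.non_max_suppression_with_scores(shifted, scores, 10, 0.5, 0.0, 0.0)
print(indices.numpy())  # [0 1] -> both kept; plain NMS on `boxes` would keep only [0]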

def __object_score_split__(self, pred):
return pred[:, :-1], pred[:, -1] # May overwrite

def __to_static__(self, bboxs, lables, confidences, max_output_size=100):
def __to_static__(self, bboxs, lables, confidences, masks=None, max_output_size=100):
indices = functional.expand_dims(functional.range(functional.shape(bboxs)[0]), -1)
lables = functional.cast(lables, bboxs.dtype)
concated = functional.concat([bboxs, functional.expand_dims(lables, -1), functional.expand_dims(confidences, -1)], axis=-1)
if masks is None:
concated = functional.concat([bboxs, functional.expand_dims(lables, -1), functional.expand_dims(confidences, -1)], axis=-1)
else:
masks = functional.reshape(functional.cast(masks, bboxs.dtype), [-1, masks.shape[1] * masks.shape[2]])
concated = functional.concat([bboxs, functional.expand_dims(lables, -1), functional.expand_dims(confidences, -1), masks], axis=-1)
concated = functional.tensor_scatter_nd_update(functional.zeros([max_output_size, concated.shape[-1]], dtype=bboxs.dtype), indices, concated)
return concated
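
# Layout sketch for the static output above: one fixed-size row per detection,
# [4 bbox values, 1 label, 1 confidence, flattened mask], zero-padded to `max_output_size`
# rows. Unpacking, with proto sizes assumed for a 640 input (160 = 640 // 4):
import numpy as np
out = np.zeros([100, 4 + 1 + 1 + 160 * 160], "float32")  # stands in for the decoder output
bboxs, lables, scores = out[:, :4], out[:, 4], out[:, 5]
masks = out[:, 6:].reshape([-1, 160, 160])
valid = scores > 0  # padded rows carry zero confidence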

def __decode_single__(self, pred, score_threshold=0.3, iou_or_sigma=0.5, max_output_size=100, method="hard", mode="global", topk=0, input_shape=None):
@staticmethod
def process_mask_proto_single(mask_proto, masks, bboxs):
# mask_proto: [input_height // 4, input_width // 4, 32], masks: [num, 32], bboxs: [num, 4]
protos_height, protos_width = mask_proto.shape[:2]
mask_proto = functional.transpose(functional.reshape(mask_proto, [-1, mask_proto.shape[-1]]), [1, 0])
masks = functional.sigmoid(masks @ mask_proto) # [num, protos_height * protos_width]
masks = functional.reshape(masks, [-1, protos_height, protos_width]) # [num, protos_height, protos_width]

""" Filter by bbox area """
top, left, bottom, right = functional.split(bboxs[:, :, None], [1, 1, 1, 1], axis=1) # [num, 1_pos, 1]
height_range = functional.range(protos_height, dtype=top.dtype)[None, :, None] / protos_height # [1, protos_height, 1]
width_range = functional.range(protos_width, dtype=top.dtype)[None, None] / protos_width # [1, 1, protos_width]
height_cond = functional.logical_and(height_range >= top, height_range < bottom) # [num, protos_height, 1]
width_cond = functional.logical_and(width_range >= left, width_range < right) # [num, 1, protos_width]
masks *= functional.cast(functional.logical_and(height_cond, width_cond), masks.dtype) # [num, protos_height, protos_width]
return masks
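
# Shape sketch of the proto combination above (NumPy, toy values; 160 = 640 // 4 assumed):
import numpy as np
num, protos_h, protos_w, num_protos = 3, 160, 160, 32
mask_proto = np.random.rand(protos_h, protos_w, num_protos)  # prototype masks from the model
mask_coefs = np.random.rand(num, num_protos)  # per-box mask coefficients
masks = 1 / (1 + np.exp(-mask_coefs @ mask_proto.reshape(-1, num_protos).T))  # sigmoid
masks = masks.reshape(num, protos_h, protos_w)

# Zero everything outside each box; boxes are [top, left, bottom, right] in [0, 1]
bboxs = np.array([[0.1, 0.1, 0.5, 0.5]] * num)
top, left, bottom, right = [bboxs[:, ii : ii + 1, None] for ii in range(4)]  # [num, 1, 1] each
hh = np.arange(protos_h)[None, :, None] / protos_h  # [1, protos_h, 1]
ww = np.arange(protos_w)[None, None, :] / protos_w  # [1, 1, protos_w]
masks = masks * (((hh >= top) & (hh < bottom)) & ((ww >= left) & (ww < right)))
print(masks.shape)  # (3, 160, 160)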

def __decode_single__(
self, pred, mask_proto=None, score_threshold=0.3, iou_or_sigma=0.5, max_output_size=100, method="hard", mode="global", topk=0, input_shape=None
):
# https://github.com/google/automl/tree/master/efficientdet/tf2/postprocess.py#L159
pred = functional.cast(pred.detach() if hasattr(pred, "detach") else pred, "float32")
if input_shape is not None:
self.__init_anchor__(input_shape)

if self.num_masks > 0: # Segmentation masks
pred, masks = pred[:, :-self.num_masks], pred[:, -self.num_masks:]
else:
masks = None

if self.use_object_scores: # YOLO outputs: [bboxes, classes_score, object_score]
pred, object_scores = self.__object_score_split__(pred)

@@ -171,18 +202,28 @@
iou_threshold, soft_nms_sigma = (1.0, iou_or_sigma / 2) if method.lower() == "gaussian" else (iou_or_sigma, 0.0)

if mode == "per_class":
bboxs, lables, confidences = self.__nms_per_class__(bbs_decoded, ccs, labels, score_threshold, iou_threshold, soft_nms_sigma, max_output_size)
bboxs, lables, confidences, indices = self.nms_per_class(bbs_decoded, ccs, labels, score_threshold, iou_threshold, soft_nms_sigma, max_output_size)
elif mode == "global":
bboxs, lables, confidences = self.__nms_global__(bbs_decoded, ccs, labels, score_threshold, iou_threshold, soft_nms_sigma, max_output_size)
bboxs, lables, confidences, indices = self.nms_global(bbs_decoded, ccs, labels, score_threshold, iou_threshold, soft_nms_sigma, max_output_size)
else:
bboxs, lables, confidences = bbs_decoded, labels, ccs # Return raw decoded data for testing
bboxs, lables, confidences, indices = bbs_decoded, labels, ccs, None # Return raw decoded data for testing

return self.__to_static__(bboxs, lables, confidences, max_output_size) if self.use_static_output else (bboxs, lables, confidences)
if self.num_masks > 0 and indices is not None: # Segmentation masks
masks = functional.gather(masks, indices)
masks = self.process_mask_proto_single(mask_proto, masks, bboxs)

def call(self, preds, input_shape=None, training=False, **nms_kwargs):
if self.use_static_output:
return self.__to_static__(bboxs, lables, confidences, masks, max_output_size)
elif self.num_masks > 0:
return bboxs, lables, confidences, masks
else:
return bboxs, lables, confidences

def call(self, preds, mask_protos=None, input_shape=None, training=False, **nms_kwargs):
"""
https://github.com/google/automl/tree/master/efficientdet/tf2/postprocess.py#L159
mask_protos: mask output from segmentation model.
input_shape: actual input shape if the model uses a dynamic input shape `[None, None, 3]`.
nms_kwargs:
score_threshold: float value in (0, 1), min score threshold; outputs with lower scores are excluded. Default 0.3.
Expand All @@ -194,12 +235,19 @@ def call(self, preds, input_shape=None, training=False, **nms_kwargs):
topk: use the topk highest scores; each bbox may have multiple labels. Set `0` to disable, `-1` to use all. Default 0.
"""
self.nms_kwargs.update(nms_kwargs)
if self.use_static_output:
if self.num_masks > 0: # Segmentation model
assert mask_protos is not None, "self.num_masks={} > 0, but mask_protos not provided".format(self.num_masks)

if self.use_static_output and self.num_masks > 0: # Segmentation model
return functional.map_fn(lambda xx: self.__decode_single__(xx[0], xx[1], **nms_kwargs), [preds, mask_protos], fn_output_signature=preds.dtype)
elif self.use_static_output:
return functional.map_fn(lambda xx: self.__decode_single__(xx, **nms_kwargs), preds)
elif len(preds.shape) == 3 and self.num_masks > 0: # Segmentation model
return [self.__decode_single__(pred, mask_proto, **self.nms_kwargs, input_shape=input_shape) for pred, mask_proto in zip(preds, mask_protos)]
elif len(preds.shape) == 3:
return [self.__decode_single__(pred, **self.nms_kwargs, input_shape=input_shape) for pred in preds]
else:
return self.__decode_single__(preds, **self.nms_kwargs, input_shape=input_shape)
return self.__decode_single__(preds, mask_protos, **self.nms_kwargs, input_shape=input_shape)
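
# Call sketch in comments (shapes are assumptions for a 640x640 input with num_masks=32):
#   preds: [batch, num_anchors, regression + num_classes (+ object_score) + 32]
#   mask_protos: [batch, 160, 160, 32]
# decoder = DecodePredictions(input_shape=(640, 640, 3), num_masks=32)
# bboxs, lables, confidences, masks = decoder(preds, mask_protos)[0]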

def get_config(self):
config = super().get_config()
@@ -213,6 +261,8 @@ def get_config(self):
"aspect_ratios": self.aspect_ratios,
"num_scales": self.num_scales,
"use_static_output": self.use_static_output,
"use_sigmoid_on_score": self.use_sigmoid_on_score,
"num_masks": self.num_masks,
}
)
config.update(self.nms_kwargs)
19 changes: 17 additions & 2 deletions keras_cv_attention_models/plot_func.py
@@ -86,18 +86,27 @@ def draw_bboxes(bboxes, ax=None):


def show_image_with_bboxes(
image, bboxes, labels=None, confidences=None, is_bbox_width_first=False, ax=None, label_font_size=8, num_classes=80, indices_2_labels=None
image, bboxes=None, labels=None, confidences=None, masks=None, is_bbox_width_first=False, ax=None, label_font_size=8, num_classes=80, indices_2_labels=None
):
import matplotlib.pyplot as plt
import numpy as np
from keras_cv_attention_models.coco import info
from keras_cv_attention_models.backend import numpy_image_resize

need_plt_show = False
if ax is None:
fig, ax = plt.subplots()
need_plt_show = True

ax.imshow(image)
bboxes = np.array(bboxes)
masks = [] if masks is None else np.array(masks)
for mask in masks: # Show segmentation results
random_color = np.concatenate([np.random.random(3), np.array([0.5])], axis=0)[None, None]
mask = np.greater(mask, 0.5)
colored_mask = numpy_image_resize(mask[:, :, None] * random_color, image.shape[:2])
ax.imshow(colored_mask)

bboxes = np.zeros([0, 4]) if bboxes is None else np.array(bboxes)
if is_bbox_width_first:
bboxes = bboxes[:, [1, 0, 3, 2]]
for id, bb in enumerate(bboxes):
@@ -127,6 +136,12 @@ def show_image_with_bboxes(
return ax


def show_image_with_bboxes_and_masks(
image, bboxes=None, labels=None, confidences=None, masks=None, is_bbox_width_first=False, ax=None, label_font_size=8, num_classes=80, indices_2_labels=None
):
return show_image_with_bboxes(**locals())
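
# Overlay sketch for the mask drawing above: masked pixels get RGBA alpha 0.5, everything
# else gets alpha 0, so stacked `ax.imshow` calls leave the base image visible (toy 4x4 mask):
import numpy as np
import matplotlib.pyplot as plt

mask = np.zeros([4, 4])
mask[1:3, 1:3] = 1
random_color = np.concatenate([np.random.random(3), [0.5]], axis=0)[None, None]  # RGBA color
fig, ax = plt.subplots()
ax.imshow(np.ones([4, 4, 3]))  # stand-in base image
ax.imshow(mask[:, :, None] * random_color)  # transparent outside the mask
plt.show()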


""" Show clip results """


8 changes: 8 additions & 0 deletions keras_cv_attention_models/pytorch_backend/functional.py
@@ -179,6 +179,14 @@ def log(inputs, name=None):
return wrapper(partial(torch.log), inputs, name=name)


def logical_and(xx, yy, name=None):
return wrapper(lambda inputs: torch.logical_and(inputs[0], inputs[1]), [xx, yy], name=name)


def logical_or(xx, yy, name=None):
return wrapper(lambda inputs: torch.logical_or(inputs[0], inputs[1]), [xx, yy], name=name)


def matmul(xx, yy, transpose_a=False, transpose_b=False, name=None):
return wrapper(lambda inputs: torch.matmul(inputs[0].T if transpose_a else inputs[0], inputs[1].T if transpose_b else inputs[1]), [xx, yy], name=name)

2 changes: 1 addition & 1 deletion keras_cv_attention_models/pytorch_backend/models.py
@@ -472,7 +472,7 @@ def forward(self, inputs, **kwargs):
return [intra_nodes[ii][0] for ii in self.output_names] if self.num_outputs != 1 else intra_nodes[self.output_names[0]][0]

@torch.no_grad()
def predict(self, inputs, **kwargs):
def predict(self, inputs, verbose=0, **kwargs):
return self.forward(inputs, **kwargs)

def graphnode_forward(self, inputs):