Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Mask to use coco-style segmentation by default #888

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added demo/demo_data/prediction_visual.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
204 changes: 121 additions & 83 deletions demo/inference_for_mmdetection.ipynb

Large diffs are not rendered by default.

418 changes: 314 additions & 104 deletions demo/inference_for_yolov8.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

168 changes: 57 additions & 111 deletions sahi/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,12 @@

from sahi.utils.coco import CocoAnnotation, CocoPrediction
from sahi.utils.cv import (
get_bbox_from_bool_mask,
get_bbox_from_coco_segmentation,
get_bool_mask_from_coco_segmentation,
get_coco_segmentation_from_bool_mask,
)
from sahi.utils.shapely import ShapelyAnnotation

try:
from pycocotools import mask as mask_utils

use_rle = True
except ImportError:
use_rle = False


class BoundingBox:
"""
Expand Down Expand Up @@ -156,14 +149,13 @@ def from_float_mask(
"""
bool_mask = mask > mask_threshold
return cls(
bool_mask=bool_mask,
segmentation=get_coco_segmentation_from_bool_mask(bool_mask),
shift_amount=shift_amount,
full_shape=full_shape,
)

@classmethod
def from_coco_segmentation(
cls,
def __init__(
self,
segmentation,
full_shape=None,
shift_amount: list = [0, 0],
Expand All @@ -187,15 +179,22 @@ def from_coco_segmentation(
# confirm full_shape is given
if full_shape is None:
raise ValueError("full_shape must be provided")
bool_mask = get_bool_mask_from_coco_segmentation(segmentation, height=full_shape[0], width=full_shape[1])
return cls(
bool_mask=bool_mask,
shift_amount=shift_amount,
full_shape=full_shape,
)

def __init__(
self,
self.shift_x = shift_amount[0]
self.shift_y = shift_amount[1]

if full_shape:
self.full_shape_height = full_shape[0]
self.full_shape_width = full_shape[1]
else:
self.full_shape_height = None
self.full_shape_width = None

self.segmentation = segmentation

@classmethod
def from_bool_mask(
cls,
bool_mask=None,
full_shape=None,
shift_amount: list = [0, 0],
Expand All @@ -210,45 +209,17 @@ def __init__(
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""

if len(bool_mask) > 0:
has_bool_mask = True
else:
has_bool_mask = False

if has_bool_mask:
self._mask = self.encode_bool_mask(bool_mask)
else:
self._mask = None

self.shift_x = shift_amount[0]
self.shift_y = shift_amount[1]

if full_shape:
self.full_shape_height = full_shape[0]
self.full_shape_width = full_shape[1]
elif has_bool_mask:
self.full_shape_height = self.bool_mask.shape[0]
self.full_shape_width = self.bool_mask.shape[1]
else:
self.full_shape_height = None
self.full_shape_width = None

def encode_bool_mask(self, bool_mask):
_mask = bool_mask
if use_rle:
_mask = mask_utils.encode(np.asfortranarray(bool_mask.astype(np.uint8)))
return _mask

def decode_bool_mask(self, bool_mask):
_mask = bool_mask
if use_rle:
_mask = mask_utils.decode(bool_mask).astype(bool)
return _mask
return cls(
segmentation=get_coco_segmentation_from_bool_mask(bool_mask),
shift_amount=shift_amount,
full_shape=full_shape,
)

@property
def bool_mask(self):
return self.decode_bool_mask(self._mask)
return get_bool_mask_from_coco_segmentation(
self.segmentation, width=self.full_shape[1], height=self.full_shape[0]
)

@property
def shape(self):
Expand All @@ -275,46 +246,17 @@ def get_shifted_mask(self):
# Confirm full_shape is specified
if (self.full_shape_height is None) or (self.full_shape_width is None):
raise ValueError("full_shape is None")
# init full mask
mask_fullsized = np.full(
(
self.full_shape_height,
self.full_shape_width,
),
0,
dtype="float32",
)

# arrange starting ending indexes
starting_pixel = [self.shift_x, self.shift_y]
ending_pixel = [
min(starting_pixel[0] + self.bool_mask.shape[1], self.full_shape_width),
min(starting_pixel[1] + self.bool_mask.shape[0], self.full_shape_height),
]

# convert sliced mask to full mask
mask_fullsized[starting_pixel[1] : ending_pixel[1], starting_pixel[0] : ending_pixel[0]] = self.bool_mask[
: ending_pixel[1] - starting_pixel[1], : ending_pixel[0] - starting_pixel[0]
]

shifted_segmentation = []
for s in self.segmentation:
xs = [min(self.shift_x + s[i], self.full_shape_width) for i in range(0, len(s) - 1, 2)]
ys = [min(self.shift_y + s[i], self.full_shape_height) for i in range(1, len(s), 2)]
shifted_segmentation.append([j for i in zip(xs, ys) for j in i])
return Mask(
mask_fullsized,
segmentation=shifted_segmentation,
shift_amount=[0, 0],
full_shape=self.full_shape,
)

def to_coco_segmentation(self):
"""
Returns boolean mask as coco segmentation:
[
[x1, y1, x2, y2, x3, y3, ...],
[x1, y1, x2, y2, x3, y3, ...],
...
]
"""
coco_segmentation = get_coco_segmentation_from_bool_mask(self.bool_mask)
return coco_segmentation


class ObjectAnnotation:
"""
Expand Down Expand Up @@ -346,9 +288,10 @@ def from_bool_mask(
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
return cls(
category_id=category_id,
bool_mask=bool_mask,
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
Expand Down Expand Up @@ -388,10 +331,9 @@ def from_coco_segmentation(
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
bool_mask = get_bool_mask_from_coco_segmentation(segmentation, width=full_shape[1], height=full_shape[0])
return cls(
category_id=category_id,
bool_mask=bool_mask,
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
Expand Down Expand Up @@ -539,7 +481,7 @@ def from_imantics_annotation(
def __init__(
self,
bbox: Optional[List[int]] = None,
bool_mask: Optional[np.ndarray] = None,
segmentation: Optional[np.ndarray] = None,
category_id: Optional[int] = None,
category_name: Optional[str] = None,
shift_amount: Optional[List[int]] = [0, 0],
Expand All @@ -549,8 +491,12 @@ def __init__(
Args:
bbox: List
[minx, miny, maxx, maxy]
bool_mask: np.ndarray with bool elements
2D mask of object, should have a shape of height*width
segmentation: List[List]
[
[x1, y1, x2, y2, x3, y3, ...],
[x1, y1, x2, y2, x3, y3, ...],
...
]
category_id: int
ID of the object category
category_name: str
Expand All @@ -564,21 +510,20 @@ def __init__(
"""
if not isinstance(category_id, int):
raise ValueError("category_id must be an integer")
if (bbox is None) and (bool_mask is None):
raise ValueError("you must provide a bbox or bool_mask")

if bool_mask is not None:
if (bbox is None) and (segmentation is None):
raise ValueError("you must provide a bbox or segmentation")
if segmentation is not None:
self.mask = Mask(
bool_mask=bool_mask,
segmentation=segmentation,
shift_amount=shift_amount,
full_shape=full_shape,
)
bbox_from_bool_mask = get_bbox_from_bool_mask(bool_mask)
bbox_from_segmentation = get_bbox_from_coco_segmentation(segmentation)
# https://github.com/obss/sahi/issues/235
if bbox_from_bool_mask is not None:
bbox = bbox_from_bool_mask
if bbox_from_segmentation is not None:
bbox = bbox_from_segmentation
else:
raise ValueError("Invalid boolean mask.")
raise ValueError("Invalid segmentation mask.")
else:
self.mask = None

Expand Down Expand Up @@ -613,7 +558,7 @@ def to_coco_annotation(self):
"""
if self.mask:
coco_annotation = CocoAnnotation.from_coco_segmentation(
segmentation=self.mask.to_coco_segmentation(),
segmentation=self.mask.segmentation(),
category_id=self.category.id,
category_name=self.category.name,
)
Expand All @@ -631,7 +576,7 @@ def to_coco_prediction(self):
"""
if self.mask:
coco_prediction = CocoPrediction.from_coco_segmentation(
segmentation=self.mask.to_coco_segmentation(),
segmentation=self.mask.segmentation(),
category_id=self.category.id,
category_name=self.category.name,
score=1,
Expand All @@ -651,7 +596,7 @@ def to_shapely_annotation(self):
"""
if self.mask:
shapely_annotation = ShapelyAnnotation.from_coco_segmentation(
segmentation=self.mask.to_coco_segmentation(),
segmentation=self.mask.segmentation(),
)
else:
shapely_annotation = ShapelyAnnotation.from_coco_bbox(
Expand Down Expand Up @@ -695,13 +640,14 @@ def get_empty_mask(cls):

def get_shifted_object_annotation(self):
if self.mask:
shifted_mask = self.mask.get_shifted_mask()
return ObjectAnnotation(
bbox=self.bbox.get_shifted_box().to_xyxy(),
category_id=self.category.id,
bool_mask=self.mask.get_shifted_mask().bool_mask,
segmentation=shifted_mask.segmentation,
category_name=self.category.name,
shift_amount=[0, 0],
full_shape=self.mask.get_shifted_mask().full_shape,
full_shape=shifted_mask.full_shape,
)
else:
return ObjectAnnotation(
Expand Down
9 changes: 5 additions & 4 deletions sahi/models/detectron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.cv import get_bbox_from_bool_mask
from sahi.utils.cv import get_bbox_from_bool_mask, get_coco_segmentation_from_bool_mask
from sahi.utils.import_utils import check_requirements

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -145,12 +145,13 @@ def _create_object_prediction_list_from_original_predictions(
category_ids = category_ids[high_confidence_mask]
if masks is not None:
masks = masks[high_confidence_mask]

if masks is not None:
object_prediction_list = [
ObjectPrediction(
bbox=box.tolist() if mask is None else None,
bool_mask=mask.detach().cpu().numpy() if mask is not None else None,
segmentation=get_coco_segmentation_from_bool_mask(mask.detach().cpu().numpy())
if mask is not None
else None,
category_id=category_id.item(),
category_name=self.category_mapping[str(category_id.item())],
shift_amount=shift_amount,
Expand All @@ -164,7 +165,7 @@ def _create_object_prediction_list_from_original_predictions(
object_prediction_list = [
ObjectPrediction(
bbox=box.tolist(),
bool_mask=None,
segmentation=None,
category_id=category_id.item(),
category_name=self.category_mapping[str(category_id.item())],
shift_amount=shift_amount,
Expand Down
3 changes: 2 additions & 1 deletion sahi/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.cv import get_coco_segmentation_from_bool_mask
from sahi.utils.import_utils import check_requirements, ensure_package_minimum_version

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -196,7 +197,7 @@ def _create_object_prediction_list_from_original_predictions(

object_prediction = ObjectPrediction(
bbox=bbox,
bool_mask=None,
segmentation=None,
category_id=category_id,
category_name=self.category_mapping[category_id],
shift_amount=shift_amount,
Expand Down
10 changes: 4 additions & 6 deletions sahi/models/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.cv import get_bbox_from_bool_mask
from sahi.utils.cv import get_bbox_from_bool_mask, get_coco_segmentation_from_bool_mask
from sahi.utils.import_utils import check_requirements

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -220,14 +220,12 @@ def _create_object_prediction_list_from_original_predictions(
Size of the full image after shifting, should be in the form of
List[[height, width],[height, width],...]
"""

try:
from pycocotools import mask as mask_utils

can_decode_rle = True
except ImportError:
can_decode_rle = False

original_predictions = self._original_predictions
category_mapping = self.category_mapping

Expand Down Expand Up @@ -275,13 +273,13 @@ def _create_object_prediction_list_from_original_predictions(
)
else:
bool_mask = mask

# check if mask is valid
# https://github.com/obss/sahi/discussions/696
if get_bbox_from_bool_mask(bool_mask) is None:
continue
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
else:
bool_mask = None
segmentation = None

# fix negative box coords
bbox[0] = max(0, bbox[0])
Expand All @@ -305,7 +303,7 @@ def _create_object_prediction_list_from_original_predictions(
bbox=bbox,
category_id=category_id,
score=score,
bool_mask=bool_mask,
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
Expand Down
Loading
Loading