forked from openvinotoolkit/datumaro
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SAMAutomaticMaskGeneration transform (openvinotoolkit#1168)
- Ticket no. 120781. Signed-off-by: Kim, Vinnam <[email protected]> Co-authored-by: Wonju Lee <[email protected]>
- Loading branch information
Showing
8 changed files
with
907 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
349 changes: 349 additions & 0 deletions
349
notebooks/19_automatic_instance_mask_gen_using_sam.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
181 changes: 181 additions & 0 deletions
181
src/datumaro/plugins/sam_transforms/automatic_mask_gen.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
"""Automatic mask generation using Segment Anything Model""" | ||
|
||
import os.path as osp | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
|
||
import datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_amg as sam_decoder_for_amg | ||
import datumaro.plugins.sam_transforms.interpreters.sam_encoder as sam_encoder_interp | ||
from datumaro.components.cli_plugin import CliPlugin | ||
from datumaro.components.dataset_base import DatasetItem, IDataset | ||
from datumaro.components.transformer import ModelTransform | ||
from datumaro.plugins.inference_server_plugin import OVMSLauncher, TritonLauncher | ||
from datumaro.plugins.inference_server_plugin.base import ( | ||
InferenceServerType, | ||
ProtocolType, | ||
TLSConfig, | ||
) | ||
from datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_amg import AMGMasks, AMGPoints | ||
|
||
__all__ = ["SAMAutomaticMaskGeneration"] | ||
|
||
|
||
class SAMAutomaticMaskGeneration(ModelTransform, CliPlugin): | ||
"""Produce instance segmentation masks automatically using Segment Anything Model (SAM). | ||
This transform can produce instance segmentation mask annotations for each given image. | ||
It samples single-point input prompts on a uniform 2D grid over the image. | ||
For each prompt, SAM can predict multiple masks. After obtaining the mask candidates, | ||
it post-processes them using the given parameters to improve quality and remove duplicates. | ||
It uses the Segment Anything Model deployed in the OpenVINO™ Model Server | ||
or NVIDIA Triton™ Inference Server instance. To launch the server instance, | ||
please see the guide in this link: | ||
https://github.com/openvinotoolkit/datumaro/tree/develop/docker/segment-anything/README.md | ||
Parameters: | ||
extractor: Dataset to transform | ||
inference_server_type: Inference server type: | ||
`InferenceServerType.ovms` or `InferenceServerType.triton` | ||
host: Host address of the server instance | ||
port: Port number of the server instance | ||
timeout: Timeout limit during communication between the client and the server instance | ||
tls_config: Configuration required if the server instance is in the secure mode | ||
protocol_type: Communication protocol type with the server instance | ||
num_workers: The number of worker threads to use for parallel inference. | ||
Set to 0 for single-process mode. Default is 0. | ||
points_per_side (int): The number of points to be sampled | ||
along one side of the image. The total number of points is | ||
points_per_side**2 on a uniform 2d grid. | ||
points_per_batch (int): Sets the number of points run simultaneously | ||
by the model. Higher numbers may be faster but use more GPU memory. | ||
pred_iou_thresh (float): A filtering threshold in [0,1], using the | ||
model's predicted mask quality. | ||
stability_score_thresh (float): A filtering threshold in [0,1], using | ||
the stability of the mask under changes to the cutoff used to binarize | ||
the model's mask predictions. | ||
stability_score_offset (float): The amount to shift the cutoff when | ||
calculated the stability score. | ||
box_nms_thresh (float): The box IoU cutoff used by non-maximal | ||
suppression to filter duplicate masks. | ||
min_mask_region_area (int): If >0, postprocessing will be applied | ||
to remove the binary mask which has the number of 1s less than min_mask_region_area. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
extractor: IDataset, | ||
inference_server_type: InferenceServerType = InferenceServerType.ovms, | ||
host: str = "localhost", | ||
port: int = 9000, | ||
timeout: float = 10.0, | ||
tls_config: Optional[TLSConfig] = None, | ||
protocol_type: ProtocolType = ProtocolType.grpc, | ||
num_workers: int = 0, | ||
points_per_side: int = 32, | ||
points_per_batch: int = 128, | ||
mask_threshold: float = 0.0, | ||
pred_iou_thresh: float = 0.88, | ||
stability_score_thresh: float = 0.95, | ||
stability_score_offset: float = 1.0, | ||
box_nms_thresh: float = 0.7, | ||
min_mask_region_area: int = 0, | ||
): | ||
if inference_server_type == InferenceServerType.ovms: | ||
launcher_cls = OVMSLauncher | ||
elif inference_server_type == InferenceServerType.triton: | ||
launcher_cls = TritonLauncher | ||
else: | ||
raise ValueError(inference_server_type) | ||
|
||
self._sam_encoder_launcher = launcher_cls( | ||
model_name="sam_encoder", | ||
model_interpreter_path=osp.abspath(sam_encoder_interp.__file__), | ||
model_version=1, | ||
host=host, | ||
port=port, | ||
timeout=timeout, | ||
tls_config=tls_config, | ||
protocol_type=protocol_type, | ||
) | ||
self._sam_decoder_launcher = launcher_cls( | ||
model_name="sam_decoder", | ||
model_interpreter_path=osp.abspath(sam_decoder_for_amg.__file__), | ||
model_version=1, | ||
host=host, | ||
port=port, | ||
timeout=timeout, | ||
tls_config=tls_config, | ||
protocol_type=protocol_type, | ||
) | ||
|
||
self.points_per_side = points_per_side | ||
self.points_per_batch = points_per_batch | ||
self.mask_threshold = mask_threshold | ||
self.pred_iou_thresh = pred_iou_thresh | ||
self.stability_score_offset = stability_score_offset | ||
self.stability_score_thresh = stability_score_thresh | ||
self.box_nms_thresh = box_nms_thresh | ||
self.min_mask_region_area = min_mask_region_area | ||
|
||
super().__init__( | ||
extractor, | ||
launcher=self._sam_encoder_launcher, | ||
batch_size=1, | ||
append_annotation=False, | ||
num_workers=num_workers, | ||
) | ||
|
||
@property | ||
def points_per_side(self) -> int: | ||
return self._points_per_side | ||
|
||
@points_per_side.setter | ||
def points_per_side(self, points_per_side: int) -> None: | ||
points_y = (np.arange(points_per_side) + 0.5) / points_per_side | ||
points_x = (np.arange(points_per_side) + 0.5) / points_per_side | ||
|
||
points_x = np.tile(points_x[None, :], (points_per_side, 1)) | ||
points_y = np.tile(points_y[:, None], (1, points_per_side)) | ||
self._points_grid = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) | ||
self._points_per_side = points_per_side | ||
|
||
def _process_batch( | ||
self, | ||
batch: List[DatasetItem], | ||
) -> List[DatasetItem]: | ||
img_embeds = self._sam_encoder_launcher.launch( | ||
batch=[item for item in batch if self._sam_encoder_launcher.type_check(item)] | ||
) | ||
|
||
items = [] | ||
for item, img_embed in zip(batch, img_embeds): | ||
amg_masks: List[AMGMasks] = [] | ||
|
||
for i in range(0, len(self._points_grid), self.points_per_batch): | ||
amg_points = [AMGPoints(points=self._points_grid[i : i + self.points_per_batch])] | ||
item_to_decode = item.wrap(annotations=amg_points + img_embed) | ||
|
||
# Nested list of mask [[mask_0, ...]] | ||
nested_masks: List[List[AMGMasks]] = self._sam_decoder_launcher.launch( | ||
[item_to_decode], | ||
stack=False, | ||
) | ||
amg_masks += nested_masks[0] | ||
|
||
mask_anns = AMGMasks.cat(amg_masks).postprocess( | ||
mask_threshold=self.mask_threshold, | ||
pred_iou_thresh=self.pred_iou_thresh, | ||
stability_score_offset=self.stability_score_offset, | ||
stability_score_thresh=self.stability_score_thresh, | ||
box_nms_thresh=self.box_nms_thresh, | ||
min_mask_region_area=self.min_mask_region_area, | ||
) | ||
|
||
items.append(item.wrap(annotations=mask_anns)) | ||
|
||
return items |
Oops, something went wrong.