diff --git a/CHANGELOG.md b/CHANGELOG.md index 68b075dca53..933559f76d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ All notable changes to this project will be documented in this file. ### New features -- Add zero-shot visual prompting (https://github.com/openvinotoolkit/training_extensions/pull/2616) +- Add zero-shot visual prompting (, ) ### Enhancements diff --git a/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/__init__.py b/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/__init__.py index 1c22c536057..d1ed6bc32e2 100644 --- a/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/__init__.py +++ b/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions # and limitations under the License. -from .openvino_models import Decoder, ImageEncoder # noqa: F401 +from .openvino_models import Decoder, ImageEncoder, PromptGetter # noqa: F401 diff --git a/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py b/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py index ee18acd4bd6..9667dd7e45e 100644 --- a/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py +++ b/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py @@ -59,6 +59,20 @@ def preprocess( return dict_inputs, meta +class PromptGetter(ImageModel): + """PromptGetter class for zero-shot visual prompting of openvino model wrapper.""" + + __model__ = "prompt_getter" + + @classmethod + def parameters(cls) -> Dict[str, Any]: # noqa: D102 + parameters = super().parameters() + parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)}) + parameters.update({"sim_threshold": NumericalValue(value_type=float, default_value=0.5, min=0, max=1)}) + parameters.update({"num_bg_points": NumericalValue(value_type=int, default_value=1, min=0, max=1024)}) + return parameters + + class Decoder(SegmentationModel): """Decoder class for visual prompting of openvino model wrapper.""" @@ -76,6 +90,7 @@ def __init__( def parameters(cls): # noqa: D102 parameters = super().parameters() parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)}) + parameters.update({"mask_threshold": NumericalValue(value_type=float, default_value=0.0, min=0, max=1)}) return parameters def _get_outputs(self): @@ -174,7 +189,7 @@ def resize_and_crop(self, soft_prediction: np.ndarray, original_size: np.ndarray ) prepadded_size = self.get_padded_size(original_size, self.image_size).astype(np.int64) - resized_cropped_soft_prediction = resized_soft_prediction[..., : prepadded_size[0], : prepadded_size[1]] + resized_cropped_soft_prediction = resized_soft_prediction[: prepadded_size[0], : prepadded_size[1], ...] 
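# The crop above now slices the leading axes because this wrapper's soft prediction is
# channel-last (H, W, C); the previous `[..., :h, :w]` indexing would have sliced the
# width and channel axes instead. A minimal sketch of the difference, assuming a
# 1024x1024, 3-channel prediction and a pre-padded size of (768, 1024):
#
#     import numpy as np
#     soft_hwc = np.zeros((1024, 1024, 3), dtype=np.float32)      # channel-last layout
#     prepadded = np.array([768, 1024], dtype=np.int64)
#     cropped = soft_hwc[: prepadded[0], : prepadded[1], ...]     # -> (768, 1024, 3)
#     # soft_hwc[..., :768, :1024] would slice the last two axes -> (1024, 768, 3)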
original_size = original_size.astype(np.int64) h, w = original_size diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py index 51b78e56880..476a2c09d69 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py @@ -207,7 +207,7 @@ def get_prompts(dataset_item: DatasetItemEntity, dataset_labels: List[LabelEntit bboxes = np.array(bboxes) return dict( - original_size=(height, width), + original_size=np.array((height, width), dtype=np.int64), gt_masks=gt_masks, bboxes=bboxes, points=points, # TODO (sungchul): update point information @@ -247,6 +247,20 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]: class OTXZeroShotVisualPromptingDataset(OTXVisualPromptingDataset): """Visual Prompting for Zero-shot learning Dataset Adaptor.""" + def __init__( + self, + dataset: DatasetEntity, + image_size: int, + mean: List[float], + std: List[float], + generate_point: bool = False, + generate_bbox: bool = False, + **kwargs, + ) -> None: + super().__init__(dataset, image_size, mean, std, offset_bbox=0) + self.generate_point = generate_point + self.generate_bbox = generate_bbox + def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]: """Get dataset item. @@ -288,7 +302,7 @@ def __init__( self.config = config self.dataset = dataset self.train_type = train_type - # self.kwargs = {} + self.kwargs = {} if self.train_type == TrainType.Zeroshot: # check zero-shot configs if self.config.get("train_batch_size", 1) != 1: @@ -300,12 +314,12 @@ def __init__( ) self.config["train_batch_size"] = 1 - # self.kwargs.update( - # { - # "generate_point": self.config.get("generate_point", False), - # "generate_bbox": self.config.get("generate_bbox", False), - # } - # ) + self.kwargs.update( + { + "generate_point": self.config.get("generate_point", False), + "generate_bbox": self.config.get("generate_bbox", False), + } + ) self.train_otx_dataset: DatasetEntity self.val_otx_dataset: DatasetEntity @@ -331,7 +345,7 @@ def setup(self, stage: Optional[str] = None) -> None: mean=mean, std=std, offset_bbox=self.config.offset_bbox, - # **self.kwargs, + **self.kwargs, ) # self.val_dataset = None @@ -347,11 +361,7 @@ def setup(self, stage: Optional[str] = None) -> None: if stage == "predict": self.predict_dataset = self.DATASETS[self.train_type]( - dataset=self.dataset, - image_size=image_size, - mean=mean, - std=std, - # **self.kwargs + dataset=self.dataset, image_size=image_size, mean=mean, std=std, **self.kwargs ) def summary(self): diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py index fd9b1a3057b..feb4917621d 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py @@ -4,7 +4,6 @@ # All rights reserved. 
# -from copy import deepcopy from typing import Any, Dict, List, Tuple, Union import numpy as np @@ -76,9 +75,9 @@ def apply_coords( old_h, old_w = original_size new_h, new_w = cls.get_preprocess_shape(original_size[0], original_size[1], target_length) if isinstance(coords, np.ndarray): - coords = deepcopy(coords).astype(np.float32) + coords = coords.astype(float) else: - coords = deepcopy(coords).to(torch.float32) + coords = coords.to(torch.float) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) return coords diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py index 4b1c507a7f8..f53fb4b3457 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py @@ -23,25 +23,26 @@ def collate_fn(batch: List[Any]) -> Dict: Dict: Collated batch data. """ - def _convert_empty_to_none(x: str) -> List: + def _convert_empty_to_none(x: str, dtype: torch.dtype = torch.float32) -> List: """Convert empty list to None. Args: x (str): Key of batch data. + dtype (torch.dtype): Dtype to be applied to tensors. Returns: List: List of batch data. """ func = torch.stack if x == "gt_masks" else torch.tensor - items = [func(item[x]) for item in batch if item[x] is not None] + items = [func(item[x]).to(dtype) for item in batch if item[x] is not None] return None if len(items) == 0 else items index = [item["index"] for item in batch] images = torch.stack([item["images"] for item in batch]) bboxes = _convert_empty_to_none("bboxes") points = None # TBD - gt_masks = _convert_empty_to_none("gt_masks") - original_size = [item["original_size"] for item in batch] + gt_masks = _convert_empty_to_none("gt_masks", torch.int32) + original_size = _convert_empty_to_none("original_size") padding = [item["padding"] for item in batch] path = [item["path"] for item in batch] labels = [item["labels"] for item in batch] diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py index 3b84daa72b8..9581d21ab41 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py @@ -10,7 +10,7 @@ import re from collections import OrderedDict -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch from omegaconf import DictConfig @@ -334,24 +334,29 @@ def select_masks(self, masks: Tensor, iou_preds: Tensor, num_points: int) -> Tup return masks, iou_preds - def mask_postprocessing(self, masks: Tensor, orig_size: Tensor) -> Tensor: + @staticmethod + def mask_postprocessing(masks: Tensor, input_size: int, orig_size: Tensor) -> Tensor: """Postprocesses the predicted masks. Args: masks (Tensor): A batch of predicted masks with shape Bx1xHxW. + input_size (int): The size of the image input to the model, in (H, W) format. + Used to remove padding. orig_size (Tensor): The original image size with shape Bx2. Returns: masks (Tensor): The postprocessed masks with shape Bx1xHxW. 
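Example (illustrative values, not taken from the code): with input_size=1024 and
orig_size=(480, 640), resize_longest_image_size yields
floor(1024 / 640 * (480, 640) + 0.5) = (768, 1024), so rows and columns beyond that
pre-padded size are cropped before the final bilinear resize back to (480, 640).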
""" - masks = F.interpolate( - masks, - size=(self.config.model.image_size, self.config.model.image_size), - mode="bilinear", - align_corners=False, - ) - prepadded_size = self.resize_longest_image_size(orig_size, self.config.model.image_size).to(torch.int64) + def resize_longest_image_size(input_image_size: Tensor, longest_side: int) -> Tensor: + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + masks = F.interpolate(masks, size=(input_size, input_size), mode="bilinear", align_corners=False) + + prepadded_size = resize_longest_image_size(orig_size, input_size) masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore orig_size = orig_size.to(torch.int64) @@ -359,22 +364,6 @@ def mask_postprocessing(self, masks: Tensor, orig_size: Tensor) -> Tensor: masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) return masks - def resize_longest_image_size(self, input_image_size: Tensor, longest_side: int) -> Tensor: - """Resizes the longest side of the image to the given size. - - Args: - input_image_size (Tensor): The original image size with shape Bx2. - longest_side (int): The size of the longest side. - - Returns: - transformed_size (Tensor): The transformed image size with shape Bx2. - """ - input_image_size = input_image_size.to(torch.float32) - scale = longest_side / torch.max(input_image_size) - transformed_size = scale * input_image_size - transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) - return transformed_size - ###################################################### # forward for training/validation/prediction # ###################################################### @@ -556,8 +545,8 @@ def predict_step(self, batch, batch_idx) -> Dict[str, Tensor]: def postprocess_masks( masks: Tensor, input_size: Tuple[int, int], - padding: Tuple[int, ...], - original_size: Tuple[int, int], + padding: Union[Tuple[int, ...], Tensor], + original_size: Union[Tuple[int, int], Tensor], ) -> Tensor: """Remove padding and upscale masks to the original image size. @@ -565,17 +554,17 @@ def postprocess_masks( masks (Tensor): Predicted masks from the mask_decoder with (N, 1, H/downsized_ratio, W/downsized_ratio). input_size (tuple(int, int)): The size of the image input to the model, in (H, W) format. Used to remove padding. - padding (tuple(int, int, int, int), optional): The padding applied to the image before input to the model, + padding (tuple(int, int, int, int), Tensor): The padding applied to the image before input to the model, in (left, top, right, bottom) format. - original_size (tuple(int, int)): The original size of the image before resizing for input to the model, - in (H, W) format. + original_size (tuple(int, int), Tensor): The original size of the image before resizing + for input to the model, in (H, W) format. Returns: (Tensor): Postprocessed masks in NxHxW format, where (H, W) is given by original_size. 
""" masks = F.interpolate(masks, input_size, mode="bilinear", align_corners=False) masks = masks[..., : input_size[0] - padding[3], : input_size[1] - padding[2]] - masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + masks = F.interpolate(masks, [int(o) for o in original_size], mode="bilinear", align_corners=False) return masks.squeeze(1) def configure_optimizers(self) -> optim: diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/zero_shot_segment_anything.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/zero_shot_segment_anything.py index a915862523c..545bf8ee32b 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/zero_shot_segment_anything.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/zero_shot_segment_anything.py @@ -5,7 +5,7 @@ from collections import OrderedDict, defaultdict from copy import deepcopy -from typing import Any, DefaultDict, Dict, List, Optional, Tuple +from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union import torch from omegaconf import DictConfig @@ -27,15 +27,26 @@ class PromptGetter(nn.Module): default_threshold_reference = 0.3 default_threshold_target = 0.65 - def __init__(self, image_size: int) -> None: + def __init__( + self, + image_size: int, + reference_feats: Optional[torch.Tensor] = None, + reference_prompts: Optional[torch.Tensor] = None, + downsizing: int = 64, + ) -> None: super().__init__() self.image_size = image_size - self.initialize() + self.downsizing = downsizing + self.initialize(reference_feats, reference_prompts) + + self.zero_tensor = torch.tensor(0) - def initialize(self) -> None: + def initialize( + self, reference_feats: Optional[torch.Tensor] = None, reference_prompts: Optional[torch.Tensor] = None + ) -> None: """Initialize reference features and prompts.""" - self.reference_feats: Dict[int, torch.Tensor] = {} - self.reference_prompts: Dict[int, torch.Tensor] = {} + self.reference_feats = reference_feats + self.reference_prompts = reference_prompts def set_default_thresholds(self, default_threshold_reference: float, default_threshold_target: float) -> None: """Set default thresholds.""" @@ -44,75 +55,134 @@ def set_default_thresholds(self, default_threshold_reference: float, default_thr def set_reference(self, label: ScoredLabel, reference_feats: torch.Tensor, reference_prompts: torch.Tensor) -> None: """Set reference features and prompts.""" - self.reference_feats[int(label.id_)] = reference_feats - self.reference_prompts[int(label.id_)] = reference_prompts + if self.reference_feats is None: + self.reference_feats = torch.zeros_like(reference_feats).unsqueeze(0) + if self.reference_prompts is None: + self.reference_prompts = torch.zeros_like(reference_prompts).unsqueeze(0) + + for idx in range(int(label.id_) + 1): + if idx == int(label.id_): + while self.reference_feats.shape[0] - 1 < idx: + self.reference_feats = torch.cat( + (self.reference_feats, torch.zeros_like(reference_feats).unsqueeze(0)), dim=0 + ) + self.reference_prompts = torch.cat( + (self.reference_prompts, torch.zeros_like(reference_prompts).unsqueeze(0)), dim=0 + ) + self.reference_feats[idx] = reference_feats + self.reference_prompts[idx] = reference_prompts def forward( self, image_embeddings: torch.Tensor, - padding: Tuple[int, ...], - original_size: Tuple[int, int], - ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]: + original_size: 
torch.Tensor, + threshold: torch.Tensor = torch.tensor([[0.0]], dtype=torch.float32), + num_bg_points: torch.Tensor = torch.tensor([[1]], dtype=torch.int64), + ) -> Tuple[torch.Tensor, torch.Tensor]: """Get prompt candidates.""" + total_points_scores: torch.Tensor + total_bg_coords: torch.Tensor + + device = image_embeddings.device + threshold = threshold.to(device) + for label in torch.arange(self.reference_feats.shape[0]): + points_scores, bg_coords = self.get_prompt_candidates( + image_embeddings=image_embeddings, + label=label, + original_size=original_size, + threshold=threshold, + num_bg_points=num_bg_points, + device=device, + ) + if label == 0: + total_points_scores = points_scores.unsqueeze(0) + total_bg_coords = bg_coords.unsqueeze(0) + else: + pad_size = torch.tensor(points_scores.shape[0] - total_points_scores.shape[1]) + pad_tot = torch.max(self.zero_tensor, pad_size) + pad_cur = torch.max(self.zero_tensor, -pad_size) + + total_points_scores = F.pad(total_points_scores, (0, 0, 0, pad_tot, 0, 0), value=-1) + points_scores = F.pad(points_scores, (0, 0, 0, pad_cur), value=-1) + + total_points_scores = torch.cat((total_points_scores, points_scores.unsqueeze(0)), dim=0) + total_bg_coords = torch.cat((total_bg_coords, bg_coords.unsqueeze(0)), dim=0) + + return total_points_scores, total_bg_coords + + def get_prompt_candidates( + self, + image_embeddings: torch.Tensor, + label: torch.Tensor, + original_size: torch.Tensor, + threshold: torch.Tensor = torch.tensor([[0.0]], dtype=torch.float32), + num_bg_points: torch.Tensor = torch.tensor([[1]], dtype=torch.int64), + device: torch.device = torch.device("cpu"), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get prompt candidates from given reference and target features.""" + assert original_size.dim() == 2 and threshold.dim() == 2 and num_bg_points.dim() == 2 + target_feat = image_embeddings.squeeze() c_feat, h_feat, w_feat = target_feat.shape - target_feat = self._preprocess_target_feat(target_feat, c_feat, h_feat, w_feat) - - prompts = {} - for label, reference_feat in self.reference_feats.items(): - sim = reference_feat.to(target_feat.device) @ target_feat - sim = sim.reshape(1, 1, h_feat, w_feat) - sim = ZeroShotSegmentAnything.postprocess_masks( - sim, (self.image_size, self.image_size), padding, original_size - ).squeeze() - - # threshold = 0.85 * sim.max() if num_classes > 1 else self.default_threshold_target - threshold = self.default_threshold_target - points_scores, bg_coords = self._point_selection(sim, original_size, threshold) - if points_scores is None: - # skip if there is no point with score > threshold - continue - prompts[label] = (points_scores, bg_coords) - return prompts - - def _preprocess_target_feat(self, target_feat: torch.Tensor, c_feat: int, h_feat: int, w_feat: int) -> torch.Tensor: target_feat = target_feat / target_feat.norm(dim=0, keepdim=True) target_feat = target_feat.reshape(c_feat, h_feat * w_feat) - return target_feat + + sim = self.reference_feats[label].to(device) @ target_feat + sim = sim.reshape(1, 1, h_feat, w_feat) + sim = ZeroShotSegmentAnything.mask_postprocessing(sim, self.image_size, original_size[0]) + + threshold = (threshold == 0) * self.default_threshold_target + threshold + points_scores, bg_coords = self._point_selection( + mask_sim=sim[0, 0], + original_size=original_size[0], + threshold=threshold, + num_bg_points=num_bg_points, + ) + + return points_scores, bg_coords def _point_selection( self, mask_sim: torch.Tensor, - original_size: Tuple[int, int], - threshold: float, - 
num_bg_points: int = 1, - downsizing: int = 16, + original_size: torch.Tensor, + threshold: torch.Tensor, + num_bg_points: torch.Tensor = torch.tensor([[1]], dtype=torch.int64), ) -> Tuple[torch.Tensor, torch.Tensor]: """Select point used as point prompts.""" _, w_sim = mask_sim.shape # Top-last point selection - bg_indices = mask_sim.flatten().topk(num_bg_points, largest=False)[1] + bg_indices = mask_sim.flatten().topk(num_bg_points[0, 0], largest=False)[1] bg_x = (bg_indices // w_sim).unsqueeze(0) bg_y = bg_indices - bg_x * w_sim bg_coords = torch.cat((bg_y, bg_x), dim=0).permute(1, 0) - bg_coords = bg_coords + bg_coords = bg_coords.to(torch.float32) point_coords = torch.where(mask_sim > threshold) - if len(point_coords[0]) == 0: - return None, None - fg_coords_scores = torch.stack(point_coords[::-1] + (mask_sim[point_coords],), dim=0).T - max_len = max(original_size) - ratio = self.image_size / max_len - _, width = map(lambda x: int(x * ratio), original_size) - n_w = width // downsizing + ratio = self.image_size / original_size.max() + width = (original_size[1] * ratio).to(torch.int64) + n_w = width // self.downsizing - res = (fg_coords_scores[:, 1] * ratio // downsizing * n_w + fg_coords_scores[:, 0] * ratio // downsizing).to( - torch.int32 + # get grid numbers + idx_grid = ( + fg_coords_scores[:, 1] * ratio // self.downsizing * n_w + fg_coords_scores[:, 0] * ratio // self.downsizing ) - points_scores = torch.stack([fg_coords_scores[res == r][0] for r in torch.unique(res)], dim=0) + idx_grid_unique = torch.unique( + idx_grid.to(torch.int64) + ) # unique op only supports INT64, INT8, FLOAT, STRING in ORT + + # get matched indices + matched_matrix = idx_grid.unsqueeze(-1) == idx_grid_unique # (totalN, uniqueN) + + # sample fg_coords_scores matched by matched_matrix + matched_grid = fg_coords_scores.unsqueeze(1) * matched_matrix.unsqueeze(-1) + + # sample the highest score one of the samples that are in the same grid + points_scores = matched_grid[matched_grid[..., -1].argsort(dim=0, descending=True)[0]].diagonal().T + + # sort by the highest score points_scores = points_scores[torch.argsort(points_scores[:, -1], descending=True)] return points_scores, bg_coords @@ -147,16 +217,18 @@ def __init__(self, config: Optional[DictConfig] = None, state_dict: Optional[Ord super().__init__(config, state_dict) - self.prompt_getter = PromptGetter(image_size=config.model.image_size) - self.prompt_getter.initialize() + self.prompt_getter = PromptGetter( + image_size=config.model.image_size, + reference_feats=prompt_getter_reference_feats, + reference_prompts=prompt_getter_reference_prompts, + ) self.prompt_getter.set_default_thresholds( - config.model.default_threshold_reference, config.model.default_threshold_target + default_threshold_reference=config.model.default_threshold_reference, + default_threshold_target=config.model.default_threshold_target, ) - if prompt_getter_reference_feats: - self.prompt_getter.reference_feats = prompt_getter_reference_feats - if prompt_getter_reference_prompts: - self.prompt_getter.reference_prompts = prompt_getter_reference_prompts + self.point_labels_box = torch.tensor([[2, 3]], dtype=torch.float32) + self.has_mask_inputs = [torch.tensor([[0.0]]), torch.tensor([[1.0]])] def set_default_config(self) -> DictConfig: """Set default config when using independently.""" @@ -181,8 +253,8 @@ def learn( self, images: torch.Tensor, processed_prompts: Dict[ScoredLabel, List[Dict[str, torch.Tensor]]], - padding: Tuple[int, ...], - original_size: Tuple[int, int], + padding: 
Union[Tuple[int, ...], torch.Tensor], + original_size: torch.Tensor, ) -> None: """Get reference features. @@ -194,8 +266,8 @@ def learn( images (torch.Tensor): Given images for reference features. processed_prompts (Dict[ScoredLabel, List[Dict[str, torch.Tensor]]]): The whole class-wise prompts processed at _preprocess_prompts. - padding (Tuple[int, ...]): Padding size. - original_size (Tuple[int, int]): Original image size. + padding (Union[Tuple[int, ...], torch.Tensor]): Padding size. + original_size (torch.Tensor): Original image size. """ assert images.shape[0] == 1, "Only single batch is supported." @@ -212,25 +284,43 @@ def learn( # generate reference mask # TODO (sungchul): ensemble multi reference features (current : use merged masks) - reference_prompt = torch.zeros(original_size, dtype=torch.uint8, device=images.device) + reference_prompt = torch.zeros(*map(int, original_size), dtype=torch.uint8, device=self.device) for input_prompt in input_prompts: if "annotation" in input_prompt: # directly use annotation information as a mask reference_prompt[input_prompt.get("annotation") == 1] += 1 else: merged_input_prompts = self._merge_prompts(label, input_prompt, processed_prompts) - masks, scores, logits = self._predict_mask( + # TODO (sungchul): they must be processed in `_merge_prompts` + # and it is required to be expanded to other prompts. + point_coords = [] + point_labels = [] + if "box" in merged_input_prompts: + for box in merged_input_prompts["box"]: + point_coords.append(box[:2]) + point_labels.append(2) + point_coords.append(box[2:]) + point_labels.append(3) + + if "points" in merged_input_prompts: + raise NotImplementedError() + + if "annotations" in merged_input_prompts: + raise NotImplementedError() + + point_coords = torch.stack(point_coords, dim=0).unsqueeze(0) + point_labels = torch.tensor([point_labels], device=self.device) + masks = self._predict_masks( image_embeddings=image_embeddings, - input_prompts=merged_input_prompts, - padding=padding, + point_coords=point_coords, + point_labels=point_labels, original_size=original_size, - multimask_output=True, + is_cascade=False, ) - best_idx = torch.argmax(scores) - reference_prompt[masks[0, best_idx]] += 1 + reference_prompt[masks] += 1 reference_prompt = torch.clip(reference_prompt, 0, 1) - ref_mask = torch.tensor(reference_prompt, dtype=torch.float32) + ref_mask = reference_prompt.to(torch.float32) reference_feat = None default_threshold_reference = deepcopy(self.prompt_getter.default_threshold_reference) while reference_feat is None: @@ -240,11 +330,11 @@ def learn( ) default_threshold_reference -= 0.05 - self.prompt_getter.set_reference(label, reference_feat.detach().cpu(), reference_prompt.detach().cpu()) + self.prompt_getter.set_reference(label, reference_feat, reference_prompt) @torch.no_grad() def infer( - self, images: torch.Tensor, padding: Tuple[int, ...], original_size: Tuple[int, int] + self, images: torch.Tensor, original_size: torch.Tensor ) -> List[List[DefaultDict[int, List[torch.Tensor]]]]: """Zero-shot inference with reference features. @@ -252,8 +342,7 @@ def infer( Args: images (torch.Tensor): Given images for target results. - padding (Tuple[int, ...]): Padding size. - original_size (Tuple[int, int]): Original image size. + original_size (torch.Tensor): Original image size. Returns: (List[List[DefaultDict[int, List[torch.Tensor]]]]): Target results. @@ -264,20 +353,21 @@ def infer( assert images.shape[0] == 1, "Only single batch is supported." 
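        # Shape sketch for the prompt getter outputs used below (assuming L reference
        # labels and at most N candidates per label):
        #   total_points_scores: (L, N, 3) rows of (x, y, similarity score), padded with -1
        #   total_bg_coords:     (L, num_bg_points, 2) background point coordinates
        # The loop below skips the -1 padding rows and any candidate that already lies
        # inside a mask predicted for the same label.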
total_results = [] - # num_classes = len(self.reference_feats.keys()) for image in images: if image.ndim == 3: image = image.unsqueeze(0) image_embeddings = self.image_encoder(images) - prompts = self.prompt_getter( - image_embeddings=image_embeddings, padding=padding, original_size=original_size + total_points_scores, total_bg_coords = self.prompt_getter( + image_embeddings=image_embeddings, original_size=original_size ) predicted_masks: defaultdict = defaultdict(list) used_points: defaultdict = defaultdict(list) - for label, (points_scores, bg_coords) in prompts.items(): + for label, (points_scores, bg_coords) in enumerate(zip(total_points_scores, total_bg_coords)): for points_score in points_scores: + if points_score[-1] == -1: + continue x, y = points_score[:2] is_done = False for pm in predicted_masks.get(label, []): @@ -288,51 +378,80 @@ def infer( if is_done: continue - mask, used_point_score = self( + point_coords = torch.cat((points_score[:2].unsqueeze(0), bg_coords), dim=0).unsqueeze(0) + point_coords = ResizeLongestSide.apply_coords( + point_coords, original_size[0], self.config.model.image_size + ) + point_labels = torch.tensor( + [1] + [0] * len(bg_coords), dtype=torch.float32, device=self.device + ).unsqueeze(0) + mask = self._predict_masks( image_embeddings=image_embeddings, - points_score=points_score, - bg_coords=bg_coords, - padding=padding, - original_size=original_size, + point_coords=point_coords, + point_labels=point_labels, + original_size=original_size[0], ) - predicted_masks[label].append(mask) - used_points[label].append(used_point_score) + predicted_masks[label].append(mask.detach().cpu()) + used_points[label].append(points_score.detach().cpu()) total_results.append([predicted_masks, used_points]) return total_results - @torch.no_grad() - def forward( + def _predict_masks( self, image_embeddings: torch.Tensor, - points_score: torch.Tensor, - bg_coords: torch.Tensor, - padding: Tuple[int, ...], - original_size: Tuple[int, int], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Predict point prompts and predicted masks. - - Args: - image_embeddings (torch.Tensor): The image embedding with a batch index of length 1. - points_score (torch.Tensor): Foreground point prompts from point selection algorithm. - bg_coords (torch.Tensor): Background point prompts from point selection algorithm. - padding (Tuple[int, ...]): Padding size. - original_size (Tuple[int, int]): Original image size. - - Returns: - (Tuple[torch.Tensor, torch.Tensor]): Predicted masks and used points with corresponding score. 
- """ - point_coords = torch.cat((points_score[:2].unsqueeze(0), bg_coords), dim=0).unsqueeze(0) - point_coords = ResizeLongestSide.apply_coords(point_coords, original_size, self.config.model.image_size) - point_labels = torch.tensor([1] + [0] * len(bg_coords), dtype=torch.int32).unsqueeze(0) - mask = self._predict_target_mask( - image_embeddings=image_embeddings, - input_prompts={"points": (point_coords, point_labels)}, - padding=padding, - original_size=original_size, - ) + point_coords: torch.Tensor, + point_labels: torch.Tensor, + original_size: torch.Tensor, + is_cascade: bool = True, + ) -> torch.Tensor: + """Predict target masks.""" + logits: torch.Tensor + scores: torch.Tensor + for i in range(3): + if i == 0: + # First-step prediction + mask_input = torch.zeros(1, 1, *map(lambda x: x * 4, image_embeddings.shape[2:]), device=self.device) + has_mask_input = self.has_mask_inputs[0].to(self.device) + + elif is_cascade and i == 1: + # Cascaded Post-refinement-1 + mask_input, masks = self._postprocess_masks(logits, scores, original_size) # noqa: F821 + if masks.sum() == 0: + return masks + + has_mask_input = self.has_mask_inputs[1].to(self.device) + + elif is_cascade and i == 2: + # Cascaded Post-refinement-2 + mask_input, masks = self._postprocess_masks(logits, scores, original_size) # noqa: F821 + if masks.sum() == 0: + return masks + + has_mask_input = self.has_mask_inputs[1].to(self.device) + coords = torch.nonzero(masks) + y, x = coords[:, 0], coords[:, 1] + point_coords = torch.cat( + ( + point_coords, + torch.tensor( + [[[x.min(), y.min()], [x.max(), y.max()]]], dtype=torch.float32, device=self.device + ), + ), + dim=1, + ) + point_labels = torch.cat((point_labels, self.point_labels_box.to(self.device)), dim=1) + + scores, logits = self( + image_embeddings=image_embeddings, + point_coords=point_coords, + point_labels=point_labels, + mask_input=mask_input, + has_mask_input=has_mask_input, + ) - return mask.detach().cpu().to(torch.uint8), points_score.detach().cpu() + _, masks = self._postprocess_masks(logits, scores, original_size) + return masks def training_step(self, batch, batch_idx) -> None: """Training step for `learn`.""" @@ -355,9 +474,7 @@ def training_step(self, batch, batch_idx) -> None: def predict_step(self, batch, batch_idx): """Predict step for `infer`.""" - results = self.infer( - images=batch["images"], padding=batch.get("padding")[0], original_size=batch.get("original_size")[0] - ) + results = self.infer(images=batch["images"], original_size=batch.get("original_size")[0].unsqueeze(0)) return [result[0] for result in results] # tmp: only mask def _preprocess_prompts( @@ -399,7 +516,11 @@ def _preprocess_prompts( return processed_prompts def _generate_masked_features( - self, feats: torch.Tensor, masks: torch.Tensor, threshold_mask: float, padding: Optional[Tuple[int, ...]] = None + self, + feats: torch.Tensor, + masks: torch.Tensor, + threshold_mask: float, + padding: Optional[Union[Tuple[int, ...], torch.Tensor]] = None, ) -> Tuple[torch.Tensor, ...]: """Generate masked features. @@ -407,7 +528,7 @@ def _generate_masked_features( feats (torch.Tensor): Raw reference features. It will be filtered with masks. masks (torch.Tensor): Reference masks used to filter features. threshold_mask (float): Threshold to control masked region. - padding (Tuple[int, ...], optional): Padding size. + padding (Union[Tuple[int, ...], torch.Tensor], optional): Padding size. Returns: (torch.Tensor): Masked features. 
@@ -422,7 +543,7 @@ def _generate_masked_features( # Post-process masks masks = F.interpolate(masks.unsqueeze(0).unsqueeze(0), size=resized_size, mode="bilinear").squeeze() - masks = self._preprocess_mask(masks) + masks = self._preprocess_masks(masks) masks = F.interpolate(masks.unsqueeze(0).unsqueeze(0), size=feats.shape[0:2], mode="bilinear").squeeze() # Target feature extraction @@ -436,7 +557,7 @@ def _generate_masked_features( return masked_feat - def _preprocess_mask(self, x: torch.Tensor) -> torch.Tensor: + def _preprocess_masks(self, x: torch.Tensor) -> torch.Tensor: """Normalize pixel values and pad to a square input. Args: @@ -452,6 +573,32 @@ def _preprocess_mask(self, x: torch.Tensor) -> torch.Tensor: x = F.pad(x, (0, padw, 0, padh)) return x + def _postprocess_masks( + self, + logits: torch.Tensor, + scores: torch.Tensor, + original_size: torch.Tensor, + ): + """Post-process masks for cascaded post-refinements.""" + high_res_masks = self.mask_postprocessing(logits, self.config.model.image_size, original_size) + masks = high_res_masks > self.config.model.mask_threshold + + # skip the first index components + scores, masks, logits = map(lambda x: x[:, 1:], (scores, masks, logits)) + + # filter zero masks + while len(scores[0]) > 0 and masks[0, (best_idx := torch.argmax(scores[0]))].sum() == 0: + scores, masks, logits = map( + lambda x: torch.cat((x[:, :best_idx], x[:, best_idx + 1 :]), dim=1), (scores, masks, logits) + ) + + if len(scores[0]) == 0: + # all predicted masks were zero masks, ignore them. + return None, torch.zeros((self.config.model.image_size, self.config.model.image_size), device="cpu") + + best_idx = torch.argmax(scores[0]) + return logits[:, best_idx], masks[0, best_idx] + def _update_value(self, target: Dict[str, Any], key: str, value: torch.Tensor) -> None: """Update tensor to target dictionary. @@ -506,98 +653,6 @@ def _merge_prompts( ) return merged_input_prompts - def _predict_target_mask( - self, - image_embeddings: torch.Tensor, - input_prompts: Dict[str, Tuple[torch.Tensor, torch.Tensor]], - padding: Tuple[int, ...], - original_size: Tuple[int, int], - ) -> torch.Tensor: - """Predict target masks. - - Args: - image_embeddings (torch.Tensor): The image embedding with a batch index of length 1. - input_prompts (Dict[str, Tuple[torch.Tensor, torch.Tensor]]): Dictionary including point, box, - and mask prompts. index=1 of tuple is point labels which indicate whether foreground or background. - padding (Tuple[int, ...]): Padding size. - original_size (Tuple[int, int]): Original image size. - - Return: - (torch.Tensor): Predicted mask. 
- """ - # First-step prediction - _, _, logits = self._predict_mask( - image_embeddings, input_prompts, padding, original_size, multimask_output=False - ) - best_idx = 0 - - # Cascaded Post-refinement-1 - input_prompts.update({"masks": logits[:, best_idx : best_idx + 1, :, :]}) - masks, scores, logits = self._predict_mask( - image_embeddings, input_prompts, padding, original_size, multimask_output=True - ) - best_idx = torch.argmax(scores) - - # Cascaded Post-refinement-2 - coords = torch.nonzero(masks[0, best_idx]) - y, x = coords[:, 0], coords[:, 1] - x_min = x.min() - x_max = x.max() - y_min = y.min() - y_max = y.max() - input_prompts.update( - { - "masks": logits[:, best_idx : best_idx + 1, :, :], - "box": torch.tensor([x_min, y_min, x_max, y_max], device=logits.device), - } - ) - masks, scores, _ = self._predict_mask( - image_embeddings, input_prompts, padding, original_size, multimask_output=True - ) - best_idx = torch.argmax(scores) - - return masks[0, best_idx] - - def _predict_mask( - self, - image_embeddings: torch.Tensor, - input_prompts: Dict[str, torch.Tensor], - padding: Tuple[int, ...], - original_size: Tuple[int, int], - multimask_output: bool = True, - ) -> Tuple[torch.Tensor, ...]: - """Predict target masks. - - Args: - image_embeddings (torch.Tensor): The image embedding with a batch index of length 1. - input_prompts (Dict[str, torch.Tensor]): Dictionary including point, box, and mask prompts. - padding (Tuple[int, ...]): Padding size. - original_size (Tuple[int, int]): Original image size. - multimask_output (bool): Whether getting multi mask outputs or not. Defaults to True. - - Return: - (Tuple[torch.Tensor, ...]): Predicted mask, score, and logit. - """ - sparse_embeddings, dense_embeddings = self.prompt_encoder( - points=input_prompts.get("points", None), - boxes=input_prompts.get("box", None), # TODO (sungchul): change key box -> boxes to use **input_prompts - masks=input_prompts.get("masks", None), - ) - - low_res_masks, scores = self.mask_decoder( - image_embeddings=image_embeddings, - image_pe=self.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - ) - high_res_masks = self.postprocess_masks( - low_res_masks, (self.config.model.image_size, self.config.model.image_size), padding, original_size - ) - masks = high_res_masks > self.config.model.mask_threshold - - return masks, scores, low_res_masks - def set_metrics(self) -> None: """Skip set_metrics unused in zero-shot learning.""" pass diff --git a/src/otx/algorithms/visual_prompting/configs/base/configuration.py b/src/otx/algorithms/visual_prompting/configs/base/configuration.py index 44998684aec..d7383c28c69 100644 --- a/src/otx/algorithms/visual_prompting/configs/base/configuration.py +++ b/src/otx/algorithms/visual_prompting/configs/base/configuration.py @@ -102,6 +102,36 @@ class __Postprocessing(ParameterGroup): affects_outcome_of=ModelLifecycle.INFERENCE, ) + mask_threshold = configurable_float( + default_value=0.0, + header="Mask threshold", + description=( + "The threshold to apply to the raw logit output of the model, for each pixel. " + "A higher value means a stricter segmentation prediction." 
+ ), + min_value=0.0, + max_value=1.0, + affects_outcome_of=ModelLifecycle.INFERENCE, + ) + + sim_threshold = configurable_float( + default_value=0.65, + header="Similarity threshold", + description="The threshold to filter point candidates based on similarity scores.", + min_value=0.0, + max_value=1.0, + affects_outcome_of=ModelLifecycle.INFERENCE, + ) + + num_bg_points = configurable_integer( + default_value=1, + header="The number of background points", + description="The number of background points to be used as negative prompts.", + min_value=1, + max_value=1024, + affects_outcome_of=ModelLifecycle.INFERENCE, + ) + @attrs class __POTParameter(BaseConfig.BasePOTParameter): header = string_attribute("POT Parameters") diff --git a/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/config.yaml b/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/config.yaml index bd923e0b6b7..097390fba0f 100644 --- a/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/config.yaml +++ b/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/config.yaml @@ -15,6 +15,8 @@ dataset: - 57.12 - 57.375 offset_bbox: 0 + generate_point: false + generate_bbox: false model: name: SAM diff --git a/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/template_experimental.yaml b/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/template_experimental.yaml index 63ff5d3d9d4..f5d63392993 100644 --- a/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/template_experimental.yaml +++ b/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/template_experimental.yaml @@ -13,7 +13,7 @@ framework: OTXVisualPrompting v0.1.0 # Task implementations. entrypoints: base: otx.algorithms.visual_prompting.tasks.ZeroShotTask - openvino: otx.algorithms.visual_prompting.tasks.openvino.OpenVINOVisualPromptingTask + openvino: otx.algorithms.visual_prompting.tasks.openvino.OpenVINOZeroShotVisualPromptingTask # Hyper Parameters hyper_parameters: diff --git a/src/otx/algorithms/visual_prompting/tasks/inference.py b/src/otx/algorithms/visual_prompting/tasks/inference.py index ea8a1fbf869..1123cd20c87 100644 --- a/src/otx/algorithms/visual_prompting/tasks/inference.py +++ b/src/otx/algorithms/visual_prompting/tasks/inference.py @@ -19,15 +19,16 @@ import json import os import shutil -import subprocess import tempfile import time import warnings from collections import OrderedDict -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union +import openvino as ov import torch from omegaconf import DictConfig, ListConfig +from openvino.tools import mo from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import TQDMProgressBar from pytorch_lightning.loggers import CSVLogger @@ -284,7 +285,7 @@ def _export_to_onnx(self, onnx_path: Dict[str, str]): height = width = self.config.model.image_size for module, path in onnx_path.items(): if module == "visual_prompting_image_encoder": - dummy_inputs = {"images": torch.randn(1, 3, height, width, dtype=torch.float)} + dummy_inputs = {"images": torch.randn(1, 3, height, width, dtype=torch.float32)} output_names = ["image_embeddings"] dynamic_axes = None model_to_export = self.model.image_encoder @@ -299,11 +300,11 @@ def _export_to_onnx(self, onnx_path: Dict[str, str]): "point_labels": {1: "num_points"}, } dummy_inputs = { - "image_embeddings": torch.zeros(1, embed_dim, *embed_size, dtype=torch.float), - "point_coords": 
torch.randint(low=0, high=1024, size=(1, 2, 2), dtype=torch.float), - "point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float), - "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), - "has_mask_input": torch.tensor([[1]], dtype=torch.float), + "image_embeddings": torch.zeros(1, embed_dim, *embed_size, dtype=torch.float32), + "point_coords": torch.randint(low=0, high=1024, size=(1, 2, 2), dtype=torch.float32), + "point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float32), + "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float32), + "has_mask_input": torch.tensor([[1]], dtype=torch.float32), } output_names = ["iou_predictions", "low_res_masks"] model_to_export = self.model @@ -381,25 +382,19 @@ def export( # noqa: D102 output_model.set_data(f"{module}.onnx", file.read()) else: for module, path in onnx_path.items(): - optimize_command = [ - "mo", - "--input_model", - path, - "--output_dir", - self.output_path, - "--model_name", - module, - ] + mo_args: Dict[str, Any] = {"input_model": path} if module == "visual_prompting_image_encoder": - optimize_command += [ - "--mean_values", - str(self.config.dataset.normalize.mean).replace(", ", ","), - "--scale_values", - str(self.config.dataset.normalize.std).replace(", ", ","), - ] + mo_args.update( + { + "mean_values": list(self.config.dataset.normalize.mean), + "scale_values": list(self.config.dataset.normalize.std), + } + ) if precision == ModelPrecision.FP16: - optimize_command.append("--compress_to_fp16") - subprocess.run(optimize_command, check=True) + mo_args.update({"compress_to_fp16": True}) + + ov_model = mo.convert_model(**mo_args) + ov.save_model(ov_model, os.path.join(self.output_path, f"{module}.xml")) with open(path.replace(".onnx", ".bin"), "rb") as file: output_model.set_data(f"{module}.bin", file.read()) with open(path.replace(".onnx", ".xml"), "rb") as file: @@ -547,6 +542,159 @@ def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameter return inference_callback.otx_dataset + def export( # noqa: D102 + self, + export_type: ExportType, + output_model: ModelEntity, + precision: ModelPrecision = ModelPrecision.FP32, + dump_features: bool = False, + ) -> None: + """Export model to OpenVINO IR. + + When SAM gets an image for inference, image encoder runs just once to get image embedding. + After that, prompt encoder + mask decoder runs repeatedly to get mask prediction. + For this case, SAM should be divided into two parts, image encoder and prompt encoder + mask decoder. + + Args: + export_type (ExportType): Export type should be ExportType.OPENVINO + output_model (ModelEntity): The model entity in which to write the OpenVINO IR data + precision (bool): Output model weights and inference precision + dump_features (bool): Flag to return "feature_vector" and "saliency_map". + + Raises: + Exception: If export_type is not ExportType.OPENVINO + """ + if dump_features: + logger.warning( + "Feature dumping is not implemented for the visual prompting task." + "The saliency maps and representation vector outputs will not be dumped in the exported model." 
+ ) + + self.model = self.load_model(otx_model=self.task_environment.model) + if export_type == ExportType.ONNX: + output_model.model_format = ModelFormat.ONNX + output_model.optimization_type = ModelOptimizationType.ONNX + if precision == ModelPrecision.FP16: + raise RuntimeError("Export to FP16 ONNX is not supported") + elif export_type == ExportType.OPENVINO: + output_model.model_format = ModelFormat.OPENVINO + output_model.optimization_type = ModelOptimizationType.MO + else: + raise RuntimeError(f"not supported export type {export_type}") + + self.precision[0] = precision + output_model.has_xai = dump_features + + logger.info("Exporting to the OpenVINO model.") + onnx_path = { + "visual_prompting_image_encoder": os.path.join(self.output_path, "visual_prompting_image_encoder.onnx"), + "visual_prompting_prompt_getter": os.path.join(self.output_path, "visual_prompting_prompt_getter.onnx"), + "visual_prompting_decoder": os.path.join(self.output_path, "visual_prompting_decoder.onnx"), + } + self._export_to_onnx(onnx_path) + + if export_type == ExportType.ONNX: + for module, path in onnx_path.items(): + with open(path, "rb") as file: + output_model.set_data(f"{module}.onnx", file.read()) + else: + for module, path in onnx_path.items(): + mo_args: Dict[str, Any] = {"input_model": path} + if module == "visual_prompting_image_encoder": + mo_args.update( + { + "mean_values": list(self.config.dataset.normalize.mean), + "scale_values": list(self.config.dataset.normalize.std), + } + ) + if precision == ModelPrecision.FP16: + mo_args.update({"compress_to_fp16": True}) + + ov_model = mo.convert_model(**mo_args) + ov.save_model(ov_model, os.path.join(self.output_path, f"{module}.xml")) + with open(path.replace(".onnx", ".bin"), "rb") as file: + output_model.set_data(f"{module}.bin", file.read()) + with open(path.replace(".onnx", ".xml"), "rb") as file: + output_model.set_data(f"{module}.xml", file.read()) + + output_model.precision = self.precision + output_model.optimization_methods = self.optimization_methods + + output_model.set_data("label_schema.json", label_schema_to_bytes(self.task_environment.label_schema)) + self._set_metadata(output_model) + + def _export_to_onnx(self, onnx_path: Dict[str, str]): + """Export model to ONNX. + + Args: + onnx_path (Dict[str, str]): Paths to save ONNX models. 
+ """ + image_size = self.config.model.image_size + embed_dim = self.model.prompt_encoder.embed_dim + embed_size = self.model.prompt_encoder.image_embedding_size + for module, path in onnx_path.items(): + if module == "visual_prompting_image_encoder": + dummy_inputs = {"images": torch.randn(1, 3, image_size, image_size, dtype=torch.float32)} + output_names = ["image_embeddings"] + dynamic_axes = None + model_to_export = self.model.image_encoder + + elif module == "visual_prompting_prompt_getter": + dummy_inputs = { + "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float32), + "original_size": torch.randint(low=0, high=image_size * 2, size=(1, 2), dtype=torch.int64), + "threshold": torch.tensor([[0.1]], dtype=torch.float32), + "num_bg_points": torch.randint(low=1, high=image_size, size=(1, 1), dtype=torch.int64), + } + output_names = ["total_points_scores", "total_bg_coords"] + dynamic_axes = { + "total_points_scores": {0: "num_labels", 1: "num_points"}, + "total_bg_coords": {0: "num_labels", 1: "num_points"}, + } + model_to_export = self.model.prompt_getter + + elif module == "visual_prompting_decoder": + # sam without backbone + mask_input_size = [4 * x for x in embed_size] + dynamic_axes = { + "point_coords": {1: "num_points"}, + "point_labels": {1: "num_points"}, + } + dummy_inputs = { + "image_embeddings": torch.zeros(1, embed_dim, *embed_size, dtype=torch.float32), + "point_coords": torch.randint(low=0, high=1024, size=(1, 2, 2), dtype=torch.float32), + "point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float32), + "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float32), + "has_mask_input": torch.tensor([[1]], dtype=torch.float32), + } + output_names = ["iou_predictions", "low_res_masks"] + model_to_export = self.model + + else: + raise ValueError( + ( + f"{module} is undefined, use visual_prompting_image_encoder, visual_prompting_prompt_getter, " + f"or visual_prompting_decoder." + ) + ) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + warnings.filterwarnings("ignore", category=UserWarning) + with open(path, "wb") as f: + torch.onnx.export( + model_to_export, + tuple(dummy_inputs.values()), + f, + export_params=True, + verbose=False, + opset_version=13, + do_constant_folding=True, + input_names=list(dummy_inputs.keys()), + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + def save_model(self, output_model: ModelEntity) -> None: """Save the model after training is completed. 
diff --git a/src/otx/algorithms/visual_prompting/tasks/openvino.py b/src/otx/algorithms/visual_prompting/tasks/openvino.py index fe499300970..dbd385b9e17 100644 --- a/src/otx/algorithms/visual_prompting/tasks/openvino.py +++ b/src/otx/algorithms/visual_prompting/tasks/openvino.py @@ -20,8 +20,9 @@ import random import tempfile import time +from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union from zipfile import ZipFile import attr @@ -119,7 +120,15 @@ def __init__( **attr.asdict( hparams.postprocessing, filter=lambda attr, value: attr.name - not in ["header", "description", "type", "visible_in_ui", "class_name"], + not in [ + "header", + "description", + "type", + "visible_in_ui", + "class_name", + "sim_threshold", + "num_bg_points", + ], ) }, } @@ -159,7 +168,7 @@ def predict(self, dataset_item: DatasetItemEntity) -> List[Annotation]: # type: """Perform a prediction for a given input image.""" # forward image encoder images, meta, prompts = self.pre_process(dataset_item) - image_embeddings = self.forward(images) + image_embeddings = self.forward_image_encoder(images) annotations: List[Annotation] = [] hard_predictions: List[np.ndarray] = [] @@ -180,7 +189,7 @@ def predict(self, dataset_item: DatasetItemEntity) -> List[Annotation]: # type: soft_predictions.append(soft_prediction) return annotations - def forward(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + def forward_image_encoder(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: """Forward function of OpenVINO Visual Prompting Inferencer.""" return self.model["image_encoder"].infer_sync(inputs) @@ -194,6 +203,228 @@ def await_all(self) -> None: self.model["decoder"].await_all() +class OpenVINOZeroShotVisualPromptingInferencer(OpenVINOVisualPromptingInferencer): + """Inferencer implementation for Zero-shot Visual Prompting using OpenVINO backend. + + This inferencer has two models, image encoder and decoder. + + Args: + hparams (VisualPromptingBaseConfig): Hyper parameters that the model should use. + label_schema (LabelSchemaEntity): LabelSchemaEntity that was used during model training. + model_files (Dict[str, Union[str, Path, bytes]]): Path or bytes to model to load, + `.xml`, `.bin` or `.onnx` file. + weight_files (Dict[str, Union[str, Path, bytes, None]], optional): Path or bytes to weights to load, + `.xml`, `.bin` or `.onnx` file. Defaults to None. + device (str): Device to run inference on, such as CPU, GPU or MYRIAD. Defaults to "CPU". + num_requests (int) : Maximum number of requests that the inferencer can make. + Good value is the number of available cores. Defaults to 1. 
+ """ + + def __init__( + self, + hparams: VisualPromptingBaseConfig, + label_schema: LabelSchemaEntity, + model_files: Dict[str, Union[str, Path, bytes]], + weight_files: Optional[Dict[str, Union[str, Path, bytes, None]]] = {}, + device: str = "CPU", + num_requests: int = 1, + ): + + assert all(module in model_files for module in ["image_encoder", "prompt_getter", "decoder"]) + + self.model = {} + model_parameters = { + "prompt_getter": {"input_layouts": "image_embeddings:NCHW"}, + "decoder": {"input_layouts": "image_embeddings:NCHW"}, + } + self.configuration = { + "image_encoder": { + **attr.asdict(hparams.postprocessing, filter=lambda attr, value: attr.name in ["image_size"]) + }, + "prompt_getter": { + **attr.asdict( + hparams.postprocessing, + filter=lambda attr, value: attr.name + in ["image_size", "sim_threshold", "num_bg_points", "embedded_processing"], + ) + }, + "decoder": { + **attr.asdict( + hparams.postprocessing, + filter=lambda attr, value: attr.name + not in [ + "header", + "description", + "type", + "visible_in_ui", + "class_name", + "sim_threshold", + "num_bg_points", + ], + ) + }, + } + + core = create_core() + for name in ["image_encoder", "prompt_getter", "decoder"]: + model_adapter = OpenvinoAdapter( + core=core, + model=model_files.get(name), + weights_path=weight_files.get(name, None), + model_parameters=model_parameters.get(name, {}), + device=device, + max_num_requests=num_requests, + plugin_config={"PERFORMANCE_HINT": "THROUGHPUT"}, + ) + self.model[name] = Model.create_model(model_adapter, name, self.configuration.get(name, {}), preload=True) + self.converter = VisualPromptingToAnnotationConverter() + self.labels = label_schema.get_labels(include_empty=False) + self.transform = get_transform() # TODO (sungchul): insert args + + self.point_labels_box = np.array([[2, 3]], dtype=np.float32) + self.has_mask_inputs = [np.array([[0.0]]), np.array([[1.0]])] + + def pre_process( # type: ignore + self, dataset_item: DatasetItemEntity, extra_processing: bool = False + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Pre-process function of OpenVINO Zero-shot Visual Prompting Inferencer for image encoder.""" + return self.model["image_encoder"].preprocess(dataset_item.numpy, extra_processing) + + def pre_process_prompt_getter( + self, image_embeddings: Dict[str, np.ndarray], original_size: np.ndarray + ) -> Dict[str, np.ndarray]: + """Pre-process function of OpenVINO Zero-shot VIsual Prompting Inferencer for prompt getter.""" + inputs_prompt_getter = { + "original_size": original_size[None], + "threshold": np.array([[self.model["prompt_getter"].sim_threshold]], dtype=np.float32), + "num_bg_points": np.array([[self.model["prompt_getter"].num_bg_points]], dtype=np.int64), + } + inputs_prompt_getter.update(image_embeddings) + return inputs_prompt_getter + + def predict(self, dataset_item: DatasetItemEntity) -> List[Annotation]: # type: ignore + """Perform a prediction for a given input image.""" + # forward image encoder + images, meta = self.pre_process(dataset_item) + original_size = np.array(meta["original_shape"][:2], dtype=np.int64) + image_embeddings = self.forward_image_encoder(images) + + # get point candidates + inputs_prompt_getter = self.pre_process_prompt_getter(image_embeddings, original_size) + total_prompts = self.forward_prompt_getter(inputs_prompt_getter) + + annotations: List[Annotation] = [] + predicted_masks: DefaultDict = defaultdict(list) + for label, (points_scores, bg_coords) in enumerate( + zip(total_prompts["total_points_scores"], 
total_prompts["total_bg_coords"]) + ): + for points_score in points_scores: + if points_score[-1] == -1: + continue + x, y = points_score[:2] + is_done = False + for pm in predicted_masks.get(label, []): + # check if that point is already assigned + if pm[int(y), int(x)] > 0: + is_done = True + break + if is_done: + continue + + point_coords = np.concatenate((np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32) + point_coords = self.model["decoder"]._apply_coords(point_coords, original_size) + point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32) + inputs_decoder = {"point_coords": point_coords[None], "point_labels": point_labels[None]} + inputs_decoder.update(image_embeddings) + + prediction = self.forward_decoder(inputs_decoder, original_size) + metadata = { + "label": [_label for _label in self.labels if int(_label.id_) == label][0], + "original_size": original_size[None], + } + + # set annotation for eval + annotation, hard_prediction, soft_prediction = self.post_process(prediction, metadata) + annotations.extend(annotation) + predicted_masks[label].append(hard_prediction) + return annotations + + def forward_prompt_getter(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """Forward function of OpenVINO Visual Prompting Inferencer.""" + return self.model["prompt_getter"].infer_sync(inputs) + + def forward_decoder( # type: ignore + self, inputs: Dict[str, np.ndarray], original_size: np.ndarray + ) -> Dict[str, np.ndarray]: + """Forward function of OpenVINO Visual Prompting Inferencer.""" + logits: np.ndarray + scores: np.ndarray + mask_slice = slice(0, 1) + for i in range(3): + if i == 0: + # First-step prediction + mask_input = np.zeros( + (1, 1, *map(lambda x: x * 4, inputs["image_embeddings"].shape[2:])), dtype=np.float32 + ) + has_mask_input = self.has_mask_inputs[0] + + elif i == 1: + # Cascaded Post-refinement-1 + mask_input, masks, iou_predictions = self._postprocess_masks( + logits, scores, original_size # noqa: F821 + ) + if masks.sum() == 0: + return {"iou_predictions": iou_predictions, "low_res_masks": mask_input} + + has_mask_input = self.has_mask_inputs[1] + + elif i == 2: + # Cascaded Post-refinement-2 + mask_input, masks, iou_predictions = self._postprocess_masks( + logits, scores, original_size # noqa: F821 + ) + if masks.sum() == 0: + return {"iou_predictions": iou_predictions, "low_res_masks": mask_input} + + has_mask_input = self.has_mask_inputs[1] + y, x = np.nonzero(masks) + inputs["point_coords"] = np.concatenate( + (inputs["point_coords"], np.array([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=np.float32)), + axis=1, + ) + inputs["point_labels"] = np.concatenate((inputs["point_labels"], self.point_labels_box), axis=1) + + inputs.update({"mask_input": mask_input, "has_mask_input": has_mask_input}) + prediction = self.model["decoder"].infer_sync(inputs) + scores, logits = prediction["iou_predictions"], prediction["low_res_masks"] + + return {"iou_predictions": scores[:, mask_slice], "low_res_masks": logits[:, mask_slice, :, :]} + + def _postprocess_masks( + self, logits: np.ndarray, scores: np.ndarray, original_size: np.ndarray + ) -> Tuple[np.ndarray, ...]: + """Post-process logits for resized masks according to best index based on scores.""" + high_res_masks = self.model["decoder"].resize_and_crop(logits[0].transpose(1, 2, 0), original_size) + masks = high_res_masks > self.model["decoder"].mask_threshold + masks = masks.transpose(2, 0, 1)[None] + + # skip the first index components + scores, masks, logits = map(lambda x: 
x[:, 1:], (scores, masks, logits)) + + # filter zero masks + while len(scores[0]) > 0 and masks[0, (best_idx := np.argmax(scores[0]))].sum() == 0: + scores, masks, logits = map( + lambda x: np.concatenate((x[:, :best_idx], x[:, best_idx + 1 :]), axis=1), (scores, masks, logits) + ) + + if len(scores[0]) == 0: + # all predicted masks were zero masks, ignore them. + return None, np.zeros((self.model["decoder"].image_size, self.model["decoder"].image_size)), 0.0 + + best_idx = np.argmax(scores[0]) + return logits[:, [best_idx]], masks[0, best_idx], scores[0, best_idx] + + class OTXOpenVinoDataLoader: """DataLoader implementation for VisualPromptingOpenVINOTask.""" @@ -484,3 +715,27 @@ def optimize( if optimization_parameters is not None: optimization_parameters.update_progress(100, None) logger.info("PTQ optimization completed") + + +class OpenVINOZeroShotVisualPromptingTask(OpenVINOVisualPromptingTask): + """Task implementation for Zero-shot Visual Prompting using OpenVINO backend.""" + + def load_inferencer(self) -> OpenVINOZeroShotVisualPromptingInferencer: + """Load OpenVINO Zero-shot Visual Prompting Inferencer.""" + if self.model is None: + raise RuntimeError("load_inferencer failed, model is None") + return OpenVINOZeroShotVisualPromptingInferencer( + self.hparams, + self.task_environment.label_schema, + model_files={ + "image_encoder": self.model.get_data("visual_prompting_image_encoder.xml"), + "prompt_getter": self.model.get_data("visual_prompting_prompt_getter.xml"), + "decoder": self.model.get_data("visual_prompting_decoder.xml"), + }, + weight_files={ + "image_encoder": self.model.get_data("visual_prompting_image_encoder.bin"), + "prompt_getter": self.model.get_data("visual_prompting_prompt_getter.bin"), + "decoder": self.model.get_data("visual_prompting_decoder.bin"), + }, + num_requests=get_default_async_reqs_num(), + ) diff --git a/src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py b/src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py index 40d1f4beec2..73f8b95be0e 100644 --- a/src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py +++ b/src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py @@ -466,7 +466,7 @@ def convert_to_annotation(self, hard_prediction: np.ndarray, metadata: Dict[str, annotations = create_annotation_from_segmentation_map( hard_prediction=hard_prediction, soft_prediction=soft_prediction, - label_map={1: metadata["label"].label}, + label_map={1: metadata["label"].label if isinstance(metadata["label"], ScoredLabel) else metadata["label"]}, ) return annotations diff --git a/src/otx/cli/utils/io.py b/src/otx/cli/utils/io.py index 3770fb279bf..e747a93b42b 100644 --- a/src/otx/cli/utils/io.py +++ b/src/otx/cli/utils/io.py @@ -49,6 +49,8 @@ "tile_classifier.bin", "visual_prompting_image_encoder.xml", "visual_prompting_image_encoder.bin", + "visual_prompting_prompt_getter.xml", + "visual_prompting_prompt_getter.bin", "visual_prompting_decoder.xml", "visual_prompting_decoder.bin", "image_threshold", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded. 
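
Editor's illustration: the zero-shot inferencer added above chains three exported OpenVINO models — the image encoder produces embeddings, the prompt getter turns them into per-label point candidates plus background points, and the decoder is then run up to three times per candidate, feeding the previous low-res mask and its bounding-box corners back in as extra prompts (labels 2 and 3, matching `point_labels_box` in the diff). The plain-NumPy sketch below shows only that prompt-update step between refinement passes; `append_box_prompts` and the array names are illustrative and are not part of the OTX API.

import numpy as np

def append_box_prompts(point_coords: np.ndarray, point_labels: np.ndarray, mask: np.ndarray):
    # Append the previous mask's bounding-box corners as extra prompts.
    # SAM-style decoders encode a box as two points labelled 2 (top-left)
    # and 3 (bottom-right), which is what the refinement loop relies on.
    y, x = np.nonzero(mask)
    box_coords = np.array([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=np.float32)
    box_labels = np.array([[2, 3]], dtype=np.float32)
    return (
        np.concatenate((point_coords, box_coords), axis=1),
        np.concatenate((point_labels, box_labels), axis=1),
    )

if __name__ == "__main__":
    coords = np.array([[[4.0, 5.0]]], dtype=np.float32)  # one foreground point
    labels = np.array([[1.0]], dtype=np.float32)
    mask = np.zeros((8, 8), dtype=bool)
    mask[2:6, 3:7] = True
    coords, labels = append_box_prompts(coords, labels, mask)
    print(coords)  # box corners [[3, 2], [6, 5]] appended after the point
    print(labels)  # [[1, 2, 3]]
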
diff --git a/tests/e2e/cli/visual_prompting/test_visual_prompting.py b/tests/e2e/cli/visual_prompting/test_visual_prompting.py index 2749a09f347..b6c1190e1d0 100644 --- a/tests/e2e/cli/visual_prompting/test_visual_prompting.py +++ b/tests/e2e/cli/visual_prompting/test_visual_prompting.py @@ -122,7 +122,15 @@ def test_otx_export_fp16(self, template, tmp_dir_path): @pytest.mark.parametrize("half_precision", [True, False]) def test_otx_eval_openvino(self, template, tmp_dir_path, half_precision): tmp_dir_path = tmp_dir_path / "visual_prompting" - otx_eval_openvino_testing(template, tmp_dir_path, otx_dir, args, threshold=0.2, half_precision=half_precision) + otx_eval_openvino_testing( + template, + tmp_dir_path, + otx_dir, + args, + threshold=0.2, + half_precision=half_precision, + is_visual_prompting=True, + ) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") @@ -143,4 +151,4 @@ def test_ptq_validate_fq(self, template, tmp_dir_path): @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_ptq_eval(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "visual_prompting" - ptq_eval_testing(template, tmp_dir_path, otx_dir, args) + ptq_eval_testing(template, tmp_dir_path, otx_dir, args, is_visual_prompting=True) diff --git a/tests/e2e/cli/visual_prompting/test_zero_shot.py b/tests/e2e/cli/visual_prompting/test_zero_shot.py new file mode 100644 index 00000000000..8a73ae304bd --- /dev/null +++ b/tests/e2e/cli/visual_prompting/test_zero_shot.py @@ -0,0 +1,130 @@ +"""Tests for Visual Prompting with OTX CLI""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import copy +import os + +import pytest + +from otx.api.entities.model_template import parse_model_template +from otx.cli.registry import Registry +from tests.test_suite.e2e_test_system import e2e_pytest_component +from tests.test_suite.run_test_command import ( + otx_eval_openvino_testing, + otx_eval_testing, + otx_export_testing, + otx_train_testing, + ptq_optimize_testing, + ptq_validate_fq_testing, + ptq_eval_testing, +) + +args = { + "--train-data-roots": "tests/assets/car_tree_bug", + "--val-data-roots": "tests/assets/car_tree_bug", + "--test-data-roots": "tests/assets/car_tree_bug", + "--input": "tests/assets/car_tree_bug/images/train", + "train_params": [ + "params", + "--learning_parameters.trainer.max_epochs", + "1", + "--learning_parameters.dataset.train_batch_size", + "1", + "--learning_parameters.dataset.use_mask", + "False", + ], +} + +otx_dir = os.getcwd() + +TT_STABILITY_TESTS = os.environ.get("TT_STABILITY_TESTS", False) +if TT_STABILITY_TESTS: + default_template = parse_model_template( + os.path.join( + "src/otx/algorithms/visual_prompting/configs", "zero_shot_sam_tiny_vit", "template_experimental.yaml" + ) + ) + templates = [default_template] * 100 + templates_ids = [template.model_template_id + f"-{i+1}" for i, template in enumerate(templates)] + +else: + templates = [ + template + for template in Registry("src/otx/algorithms/visual_prompting", experimental=True) + .filter(task_type="VISUAL_PROMPTING") + .templates + if "Zero_Shot" in template.name + ] + templates_ids = [template.model_template_id for template in templates] + + +class TestToolsZeroShotVisualPrompting: + @e2e_pytest_component + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_train(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_train_testing(template, tmp_dir_path, 
otx_dir, args, deterministic=True) + + @e2e_pytest_component + @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_eval(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_eval_testing(template, tmp_dir_path, otx_dir, args) + + @e2e_pytest_component + @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_export(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_export_testing(template, tmp_dir_path, False) + + @e2e_pytest_component + @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_export_fp16(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_export_testing(template, tmp_dir_path, half_precision=True) + + @e2e_pytest_component + @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") + @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("half_precision", [True, False]) + def test_otx_eval_openvino(self, template, tmp_dir_path, half_precision): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_eval_openvino_testing( + template, + tmp_dir_path, + otx_dir, + args, + threshold=0.2, + half_precision=half_precision, + is_visual_prompting=True, + ) + + @e2e_pytest_component + @pytest.mark.skip(reason="optimize for zsl is not supported yet.") + @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_ptq_optimize(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + ptq_optimize_testing(template, tmp_dir_path, otx_dir, args, is_visual_prompting=True) + + @e2e_pytest_component + @pytest.mark.skip(reason="optimize for zsl is not supported yet.") + @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_ptq_validate_fq(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + ptq_validate_fq_testing(template, tmp_dir_path, otx_dir, "zero_shot_visual_prompting", type(self).__name__) + + @e2e_pytest_component + @pytest.mark.skip(reason="optimize for zsl is not supported yet.") + @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_ptq_eval(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + ptq_eval_testing(template, tmp_dir_path, otx_dir, args) diff --git a/tests/integration/cli/visual_prompting/test_visual_prompting.py b/tests/integration/cli/visual_prompting/test_visual_prompting.py index 18d220376a1..92ff4bf356e 100644 --- a/tests/integration/cli/visual_prompting/test_visual_prompting.py +++ b/tests/integration/cli/visual_prompting/test_visual_prompting.py @@ -109,7 +109,15 @@ def test_otx_export_onnx(self, template, tmp_dir_path): @pytest.mark.parametrize("half_precision", [True, False]) def test_otx_eval_openvino(self, template, tmp_dir_path, half_precision): tmp_dir_path = tmp_dir_path / "visual_prompting" - 
otx_eval_openvino_testing(template, tmp_dir_path, otx_dir, args, threshold=1.0, half_precision=half_precision) + otx_eval_openvino_testing( + template, + tmp_dir_path, + otx_dir, + args, + threshold=1.0, + half_precision=half_precision, + is_visual_prompting=True, + ) @e2e_pytest_component @pytest.mark.skip("demo.py is not supported.") diff --git a/tests/integration/cli/visual_prompting/test_zero_shot.py b/tests/integration/cli/visual_prompting/test_zero_shot.py index 8d403f27999..ccedf5c2fa2 100644 --- a/tests/integration/cli/visual_prompting/test_zero_shot.py +++ b/tests/integration/cli/visual_prompting/test_zero_shot.py @@ -12,6 +12,8 @@ from tests.test_suite.run_test_command import ( otx_eval_testing, otx_train_testing, + otx_export_testing, + otx_eval_openvino_testing, ) args = { @@ -39,7 +41,7 @@ templates_ids = [template.model_template_id for template in templates] -class TestVisualPromptingCLI: +class TestZeroShotVisualPromptingCLI: @e2e_pytest_component @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_train(self, template, tmp_dir_path): @@ -51,3 +53,36 @@ def test_otx_train(self, template, tmp_dir_path): def test_otx_eval(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" otx_eval_testing(template, tmp_dir_path, otx_dir, args) + + @e2e_pytest_component + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_export(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_export_testing(template, tmp_dir_path, False, check_ir_meta=False) + + @e2e_pytest_component + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_export_fp16(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_export_testing(template, tmp_dir_path, half_precision=True) + + @e2e_pytest_component + @pytest.mark.parametrize("template", templates, ids=templates_ids) + def test_otx_export_onnx(self, template, tmp_dir_path): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_export_testing(template, tmp_dir_path, half_precision=False, is_onnx=True) + + @e2e_pytest_component + @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("half_precision", [True, False]) + def test_otx_eval_openvino(self, template, tmp_dir_path, half_precision): + tmp_dir_path = tmp_dir_path / "zero_shot_visual_prompting" + otx_eval_openvino_testing( + template, + tmp_dir_path, + otx_dir, + args, + threshold=1.0, + half_precision=half_precision, + is_visual_prompting=True, + ) diff --git a/tests/test_suite/run_test_command.py b/tests/test_suite/run_test_command.py index 2375467399c..597f5d48323 100644 --- a/tests/test_suite/run_test_command.py +++ b/tests/test_suite/run_test_command.py @@ -247,12 +247,15 @@ def otx_export_testing(template, root, dump_features=False, half_precision=False path_to_xml = os.path.join(save_path, "openvino.xml") assert os.path.exists(os.path.join(save_path, "label_schema.json")) if not is_onnx: - if "Visual_Prompting" in template.model_template_id: + if any(map(lambda x: x in template.model_template_id, ("Visual_Prompting", "Zero_Shot"))): path_to_xml = os.path.join(save_path, "visual_prompting_decoder.xml") assert os.path.exists(os.path.join(save_path, "visual_prompting_image_encoder.xml")) assert os.path.exists(os.path.join(save_path, "visual_prompting_image_encoder.bin")) assert os.path.exists(os.path.join(save_path, "visual_prompting_decoder.xml")) 
assert os.path.exists(os.path.join(save_path, "visual_prompting_decoder.bin")) + if "Zero_Shot" in template.model_template_id: + assert os.path.exists(os.path.join(save_path, "visual_prompting_prompt_getter.xml")) + assert os.path.exists(os.path.join(save_path, "visual_prompting_prompt_getter.bin")) else: assert os.path.exists(path_to_xml) assert os.path.exists(os.path.join(save_path, "openvino.bin")) @@ -263,9 +266,11 @@ def otx_export_testing(template, root, dump_features=False, half_precision=False xml_model = xml_stream.read() assert f"{input_size[1]},{input_size[0]}" in xml_model else: - if "Visual_Prompting" in template.model_template_id: + if any(map(lambda x: x in template.model_template_id, ("Visual_Prompting", "Zero_Shot"))): assert os.path.exists(os.path.join(save_path, "visual_prompting_image_encoder.onnx")) assert os.path.exists(os.path.join(save_path, "visual_prompting_decoder.onnx")) + if "Zero_Shot" in template.model_template_id: + assert os.path.exists(os.path.join(save_path, "visual_prompting_prompt_getter.onnx")) else: path_to_onnx = os.path.join(save_path, "model.onnx") assert os.path.exists(path_to_onnx) @@ -334,14 +339,16 @@ def otx_eval_openvino_testing( args, threshold=0.0, half_precision=False, + is_visual_prompting=False, ): template_work_dir = get_template_dir(template, root) - weights_path = f"{template_work_dir}/exported_{template.model_template_id}/openvino.xml" + weights_file = "visual_prompting_decoder" if is_visual_prompting else "openvino" + weights_path = f"{template_work_dir}/exported_{template.model_template_id}/{weights_file}.xml" output_path = f"{template_work_dir}/exported_{template.model_template_id}" perf_path = f"{template_work_dir}/exported_{template.model_template_id}/performance.json" if half_precision: - weights_path = f"{template_work_dir}/exported_{template.model_template_id}_fp16/openvino.xml" + weights_path = f"{template_work_dir}/exported_{template.model_template_id}_fp16/{weights_file}.xml" output_path = f"{template_work_dir}/exported_{template.model_template_id}_fp16" perf_path = f"{template_work_dir}/exported_{template.model_template_id}_fp16/performance.json" diff --git a/tests/unit/algorithms/visual_prompting/adapters/openvino/model_wrappers/test_openvino_models.py b/tests/unit/algorithms/visual_prompting/adapters/openvino/model_wrappers/test_openvino_models.py index 716df9c70b4..7740de14ab9 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/openvino/model_wrappers/test_openvino_models.py +++ b/tests/unit/algorithms/visual_prompting/adapters/openvino/model_wrappers/test_openvino_models.py @@ -15,6 +15,7 @@ from otx.algorithms.visual_prompting.adapters.openvino.model_wrappers import ( Decoder, ImageEncoder, + PromptGetter, ) from otx.api.entities.label import LabelEntity from tests.test_suite.e2e_test_system import e2e_pytest_unit @@ -47,6 +48,16 @@ def test_preproces(self, mocker): assert meta["resize_type"] == "fit_to_window" +class TestPromptGetter: + @e2e_pytest_unit + def test_parameters(self): + """Test parameters.""" + params = PromptGetter.parameters() + + assert params.get("sim_threshold").default_value == 0.5 + assert params.get("num_bg_points").default_value == 1 + + class TestDecoder: @pytest.fixture(autouse=True) def setup(self, mocker): diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_transforms.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_transforms.py index 96c17dd2e35..36225af9d4a 100644 --- 
a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_transforms.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_transforms.py @@ -30,7 +30,7 @@ def test_collate_fn(): "bboxes": np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), "points": [], "gt_masks": [Tensor([1, 2, 3])], - "original_size": [], + "original_size": np.array([1, 3]), "padding": [], "path": [], "labels": [], @@ -41,7 +41,7 @@ def test_collate_fn(): "bboxes": np.array([[9, 10, 11, 12]]), "points": [], "gt_masks": [Tensor([4, 5, 6])], - "original_size": [], + "original_size": np.array([1, 3]), "padding": [], "path": [], "labels": [], @@ -53,7 +53,7 @@ def test_collate_fn(): "bboxes": [Tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), Tensor([[9, 10, 11, 12]])], "points": None, "gt_masks": [Tensor([[1, 2, 3]]), Tensor([[4, 5, 6]])], - "original_size": [[], []], + "original_size": [Tensor([1, 3]), Tensor([1, 3])], "path": [[], []], "labels": [[], []], "padding": [[], []], @@ -69,7 +69,8 @@ def test_collate_fn(): assert len(results["gt_masks"]) == len(expected["gt_masks"]) for r, e in zip(results["gt_masks"], expected["gt_masks"]): assert torch.all(r == e) - assert results["original_size"] == expected["original_size"] + for r, e in zip(results["original_size"], expected["original_size"]): + assert torch.all(r == e) assert results["path"] == expected["path"] assert results["labels"] == expected["labels"] assert results["padding"] == expected["padding"] diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/test_dataset.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/test_dataset.py index c3701eb3f58..99a76c3b17b 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/test_dataset.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/test_dataset.py @@ -19,7 +19,7 @@ generate_bbox, generate_bbox_from_mask, get_transform, - # generate_point_from_mask, + generate_point_from_mask, ) from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.pipelines import ( MultipleInputsCompose, @@ -184,7 +184,7 @@ def test_getitem( # Check specific values in the item assert item["index"] == 0 assert (item["images"] == dataset[0].media.numpy).all() - assert item["original_size"] == dataset[0].media.numpy.shape[:2] + assert np.all(item["original_size"] == dataset[0].media.numpy.shape[:2]) assert item["path"] == dataset[0].media.path assert isinstance(item["gt_masks"], list) assert isinstance(item["gt_masks"][0], np.ndarray) @@ -220,7 +220,7 @@ def test_getitem( # Check specific values in the item assert item["index"] == 0 assert (item["images"] == dataset[0].media.numpy).all() - assert item["original_size"] == dataset[0].media.numpy.shape[:2] + assert np.all(item["original_size"] == dataset[0].media.numpy.shape[:2]) assert item["path"] == dataset[0].media.path assert isinstance(item["gt_masks"], list) assert isinstance(item["gt_masks"][0], np.ndarray) @@ -248,8 +248,8 @@ def test_init_zeroshot(self, set_datamodule): datamodule = set_datamodule(train_type=TrainType.Zeroshot) assert datamodule.config.get("train_batch_size") == 1 - # assert "generate_point" in datamodule.kwargs - # assert "generate_bbox" in datamodule.kwargs + assert "generate_point" in datamodule.kwargs + assert "generate_bbox" in datamodule.kwargs @e2e_pytest_unit def test_setup(self, mocker, set_datamodule) -> None: diff --git 
a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py index 799d06f846b..fed22e060c8 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py @@ -349,29 +349,13 @@ def test_select_masks(self) -> None: @e2e_pytest_unit def test_mask_postprocessing(self, mocker) -> None: """Test mask_postprocessing.""" - sam = SegmentAnything(config=self.base_config) - mocker.patch.object(sam, "resize_longest_image_size", return_value=Tensor((6, 6))) - sam.config.image_size = 6 - masks = torch.empty(1, 1, 2, 2) orig_size = Tensor((8, 8)) - results = sam.mask_postprocessing(masks, orig_size) + results = SegmentAnything.mask_postprocessing(masks, 6, orig_size) assert results[0, 0].shape == tuple(orig_size) - @e2e_pytest_unit - def test_resize_longest_image_size(self) -> None: - """Test resize_longest_image_size.""" - sam = SegmentAnything(config=self.base_config) - - input_image_size = Tensor((2, 4)) - longest_side = 6 - - results = sam.resize_longest_image_size(input_image_size, longest_side) - - assert torch.all(results == Tensor((3, 6))) - @e2e_pytest_unit def test_forward_train(self) -> None: """Test forward.""" diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_zero_shot_segment_anything.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_zero_shot_segment_anything.py index b4ac5343147..aa612a432fe 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_zero_shot_segment_anything.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_zero_shot_segment_anything.py @@ -11,17 +11,25 @@ import torch from omegaconf import DictConfig +from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything import ( + SegmentAnything, +) from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything import ( PromptGetter, ZeroShotSegmentAnything, ) -from tests.unit.algorithms.visual_prompting.test_helpers import MockScoredLabel, MockImageEncoder, MockPromptGetter +from tests.unit.algorithms.visual_prompting.test_helpers import ( + MockScoredLabel, + MockImageEncoder, + MockPromptGetter, + MockMaskDecoder, +) class TestPromptGetter: @pytest.fixture(autouse=True) def setup(self) -> None: - self.prompt_getter = PromptGetter(image_size=3) + self.prompt_getter = PromptGetter(image_size=3, downsizing=1) @e2e_pytest_unit def test_initialize(self) -> None: @@ -46,47 +54,64 @@ def test_set_reference(self) -> None: self.prompt_getter.set_reference( label=MockScoredLabel(label=1), reference_feats=torch.ones((self.prompt_getter.image_size, self.prompt_getter.image_size)), - reference_prompts=torch.zeros((self.prompt_getter.image_size, self.prompt_getter.image_size)), + reference_prompts=torch.ones((self.prompt_getter.image_size, self.prompt_getter.image_size)), ) + assert self.prompt_getter.reference_feats[0].sum() == 0 + assert self.prompt_getter.reference_prompts[0].sum() == 0 assert self.prompt_getter.reference_feats[1].sum() == 9 - assert 
self.prompt_getter.reference_prompts[1].sum() == 0 + assert self.prompt_getter.reference_prompts[1].sum() == 9 + + self.prompt_getter.set_reference( + label=MockScoredLabel(label=3), + reference_feats=torch.ones((self.prompt_getter.image_size, self.prompt_getter.image_size)), + reference_prompts=torch.ones((self.prompt_getter.image_size, self.prompt_getter.image_size)), + ) + + assert self.prompt_getter.reference_feats[2].sum() == 0 + assert self.prompt_getter.reference_prompts[2].sum() == 0 + assert self.prompt_getter.reference_feats[3].sum() == 9 + assert self.prompt_getter.reference_prompts[3].sum() == 9 @e2e_pytest_unit def test_forward(self, mocker) -> None: """Test forward.""" - mocker.patch( - "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.ZeroShotSegmentAnything" + mocker.patch.object( + self.prompt_getter, + "get_prompt_candidates", + return_value=(torch.tensor([[[0, 0, 0.5], [1, 1, 0.7]]]), torch.tensor([[[2, 2]]])), ) - mocker.patch.object(self.prompt_getter, "_point_selection", return_value=("points_scores", "bg_coords")) + image_embeddings = torch.ones(1, 4, 4, 4) + self.prompt_getter.reference_feats = torch.rand(1, 1, 4) + original_size = torch.tensor((self.prompt_getter.image_size, self.prompt_getter.image_size), dtype=torch.int64) - image_embeddings = torch.rand(1, 2, self.prompt_getter.image_size, self.prompt_getter.image_size) - self.prompt_getter.reference_feats = {1: torch.rand(1, 2)} - - prompts = self.prompt_getter( - image_embeddings=image_embeddings, - padding=(0, 0, 0, 0), - original_size=(self.prompt_getter.image_size, self.prompt_getter.image_size), + total_points_scores, total_bg_coords = self.prompt_getter( + image_embeddings=image_embeddings, original_size=original_size ) - assert 1 in prompts - assert prompts[1] == ("points_scores", "bg_coords") + assert total_points_scores.shape[0] == 1 + assert total_bg_coords.shape[0] == 1 @e2e_pytest_unit - def test_preprocess_target_feat(self) -> None: - """Test _preprocess_target_feat.""" - old_target_feat = torch.arange(1, self.prompt_getter.image_size**2 + 1, dtype=torch.float).reshape( - 1, 1, self.prompt_getter.image_size, self.prompt_getter.image_size + def test_get_prompt_candidates(self, mocker) -> None: + """Test get_prompt_candidates.""" + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.ZeroShotSegmentAnything" + ) + mocker.patch.object(self.prompt_getter, "_point_selection", return_value=("points_scores", "bg_coords")) + image_embeddings = torch.ones(1, 4, 4, 4) + self.prompt_getter.reference_feats = torch.rand(1, 1, 4) + label = torch.tensor([[0]], dtype=torch.int64) + original_size = torch.tensor( + [[self.prompt_getter.image_size, self.prompt_getter.image_size]], dtype=torch.int64 ) - new_target_feat = self.prompt_getter._preprocess_target_feat( - target_feat=old_target_feat, - c_feat=1, - h_feat=self.prompt_getter.image_size, - w_feat=self.prompt_getter.image_size, + + points_scores, bg_coords = self.prompt_getter.get_prompt_candidates( + image_embeddings=image_embeddings, label=label, original_size=original_size ) - assert new_target_feat.sum() == 9 - assert new_target_feat.shape == (1, self.prompt_getter.image_size**2) + assert points_scores == "points_scores" + assert bg_coords == "bg_coords" @e2e_pytest_unit def test_point_selection(self) -> None: @@ -95,9 +120,8 @@ def test_point_selection(self) -> None: points_scores, bg_coords = 
self.prompt_getter._point_selection( mask_sim=mask_sim, - original_size=(self.prompt_getter.image_size, self.prompt_getter.image_size), - threshold=0.5, - downsizing=1, + original_size=torch.tensor([self.prompt_getter.image_size, self.prompt_getter.image_size]), + threshold=torch.tensor([[0.5]]), ) assert torch.equal(points_scores, torch.tensor([[2, 2, 0.9], [1, 2, 0.8], [0, 2, 0.7], [2, 1, 0.6]])) @@ -112,6 +136,10 @@ def zero_shot_segment_anything(state_dict: Optional[OrderedDict] = None): "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SAMImageEncoder", MockImageEncoder, ) + monkeypatch.setattr( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SAMMaskDecoder", + MockMaskDecoder, + ) return ZeroShotSegmentAnything(state_dict=state_dict) return zero_shot_segment_anything @@ -164,12 +192,8 @@ def test_learn(self, mocker, set_zero_shot_segment_anything) -> None: zero_shot_segment_anything = set_zero_shot_segment_anything() mocker.patch.object( zero_shot_segment_anything, - "_predict_mask", - return_value=( - torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]), - torch.tensor([1, 0, 0]), - torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]), - ), + "_predict_masks", + return_value=torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]), ) processed_prompts = {MockScoredLabel(label=1, name="label"): [{"box": torch.tensor([[0, 0, 1, 1]])}]} @@ -180,13 +204,11 @@ def test_learn(self, mocker, set_zero_shot_segment_anything) -> None: original_size=(8, 8), ) - assert zero_shot_segment_anything.prompt_getter.reference_feats.get(1).shape == (1, 2) - assert zero_shot_segment_anything.prompt_getter.reference_prompts.get(1).shape == (8, 8) + assert zero_shot_segment_anything.prompt_getter.reference_feats.shape == (2, 1, 2) + assert zero_shot_segment_anything.prompt_getter.reference_prompts.shape == (2, 8, 8) @e2e_pytest_unit - @pytest.mark.parametrize( - "expected", [[torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]]), torch.tensor([0.0, 0.0, 0.5])]] - ) + @pytest.mark.parametrize("expected", [[torch.ones((8, 8)), torch.tensor([0.0, 0.0, 0.5])]]) def test_infer(self, monkeypatch, mocker, set_zero_shot_segment_anything, expected: torch.Tensor) -> None: """Test infer.""" monkeypatch.setattr( @@ -195,26 +217,39 @@ def test_infer(self, monkeypatch, mocker, set_zero_shot_segment_anything, expect ) zero_shot_segment_anything = set_zero_shot_segment_anything() - zero_shot_segment_anything.prompt_getter.reference_feats = {1: torch.rand((1, 2))} - zero_shot_segment_anything.prompt_getter.reference_prompts = {1: torch.zeros((8, 8))} + zero_shot_segment_anything.prompt_getter.reference_feats = torch.rand(1, 1, 4) + zero_shot_segment_anything.prompt_getter.reference_prompts = torch.zeros((8, 8)) mocker.patch.object( - zero_shot_segment_anything, - "_predict_mask", - return_value=( - torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]), - torch.tensor([1, 0, 0]), - torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]), - ), + SegmentAnything, "forward", return_value=(torch.tensor([[0.1, 0.2, 0.5, 0.7]]), torch.ones(1, 4, 4, 4)) ) total_results = zero_shot_segment_anything.infer( - images=torch.ones((1, 3, 8, 8)), padding=(0, 0, 0, 0), original_size=(8, 8) + images=torch.ones((1, 3, 8, 8)), original_size=torch.tensor([[8, 8]], dtype=torch.int64) ) for i, results in enumerate(total_results[0]): for _, result in results.items(): assert torch.equal(result[0], expected[i]) + @e2e_pytest_unit + 
@pytest.mark.parametrize("is_postprocess", [True, False]) + def test_predict_masks(self, mocker, set_zero_shot_segment_anything, is_postprocess: bool) -> None: + """Test _predict_masks.""" + mocker.patch.object( + SegmentAnything, "forward", return_value=(torch.tensor([[0.1, 0.2, 0.5, 0.7]]), torch.ones(1, 4, 4, 4)) + ) + + zero_shot_segment_anything = set_zero_shot_segment_anything() + zero_shot_segment_anything.config.model.image_size = 6 + + mask = zero_shot_segment_anything._predict_masks( + image_embeddings=torch.rand(1), + point_coords=torch.rand(1, 2, 2), + point_labels=torch.randint(low=0, high=2, size=(1, 2)), + original_size=torch.tensor((8, 8), dtype=torch.int64), + ) + assert mask.shape == (8, 8) + @e2e_pytest_unit def test_preprocess_prompts(self, set_zero_shot_segment_anything) -> None: """Test _preprocess_prompts. @@ -248,18 +283,39 @@ def test_generate_masked_features(self, set_zero_shot_segment_anything) -> None: assert masked_feat.shape == (1, 1) @e2e_pytest_unit - def test_preprocess_mask(self, set_zero_shot_segment_anything) -> None: - """Test _preprocess_mask.""" + def test_preprocess_masks(self, set_zero_shot_segment_anything) -> None: + """Test _preprocess_masks.""" zero_shot_segment_anything = set_zero_shot_segment_anything() zero_shot_segment_anything.config.model.image_size = 16 - result = zero_shot_segment_anything._preprocess_mask(x=torch.ones(1, 1, 8, 8)) + result = zero_shot_segment_anything._preprocess_masks(x=torch.ones(1, 1, 8, 8)) assert result[:8, :8].sum() == 8**2 assert result[:8, 8:].sum() == 0 assert result[8:, :8].sum() == 0 assert result[8:, 8:].sum() == 0 + @e2e_pytest_unit + @pytest.mark.parametrize( + "logits,expected", + [ + (torch.ones(1, 4, 4, 4), torch.ones(4, 4, dtype=torch.bool)), + (torch.zeros(1, 4, 4, 4), torch.zeros(4, 4, dtype=torch.bool)), + ], + ) + def test_postprocess_masks( + self, set_zero_shot_segment_anything, logits: torch.Tensor, expected: torch.Tensor + ) -> None: + """Test _postprocess_masks.""" + zero_shot_segment_anything = set_zero_shot_segment_anything() + zero_shot_segment_anything.config.model.image_size = 4 + scores = torch.tensor([[0.0, 0.1, 0.2, 0.3]]) + original_size = torch.tensor([4, 4], dtype=torch.int64) + + _, result = zero_shot_segment_anything._postprocess_masks(logits, scores, original_size) + + assert torch.equal(result, expected) + @e2e_pytest_unit @pytest.mark.parametrize("use_only_background", [True, False]) def test_merge_prompts(self, set_zero_shot_segment_anything, use_only_background: bool) -> None: @@ -285,37 +341,3 @@ def test_merge_prompts(self, set_zero_shot_segment_anything, use_only_background else: assert torch.equal(merged_input_prompts.get("point_coords"), torch.tensor([1, 0, 2])) assert torch.equal(merged_input_prompts.get("point_labels"), torch.tensor([1, 0, 0])) - - @e2e_pytest_unit - def test_predict_target_mask(self, mocker, set_zero_shot_segment_anything) -> None: - """Test _predict_target_mask.""" - zero_shot_segment_anything = set_zero_shot_segment_anything() - mocker.patch.object( - zero_shot_segment_anything, - "_predict_mask", - return_value=( - torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]), - torch.tensor([1, 0, 0]), - torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]), - ), - ) - - mask = zero_shot_segment_anything._predict_target_mask( - image_embeddings=torch.rand(1), input_prompts={}, padding=(0, 0, 0, 0), original_size=(1, 1) - ) - - assert mask.shape == (3, 3) - - @e2e_pytest_unit - def test_predict_mask(self, mocker, set_zero_shot_segment_anything) -> None: - 
"""Test _predict_mask.""" - zero_shot_segment_anything = set_zero_shot_segment_anything() - mocker.patch.object(zero_shot_segment_anything, "postprocess_masks", return_value=torch.Tensor([[1]])) - - masks, scores, low_res_masks = zero_shot_segment_anything._predict_mask( - image_embeddings=torch.rand(1), input_prompts={}, padding=(0, 0, 0, 0), original_size=(1, 1) - ) - - assert masks.dtype == torch.bool - assert scores.shape[1] == 3 - assert low_res_masks.shape[1] == 3 diff --git a/tests/unit/algorithms/visual_prompting/tasks/test_inference.py b/tests/unit/algorithms/visual_prompting/tasks/test_inference.py index 996d1f97cd1..acd9d0c48ca 100644 --- a/tests/unit/algorithms/visual_prompting/tasks/test_inference.py +++ b/tests/unit/algorithms/visual_prompting/tasks/test_inference.py @@ -4,9 +4,13 @@ # SPDX-License-Identifier: Apache-2.0 # +import os +import torch +import numpy as np from typing import Optional, Dict, Any import pytest +from functools import wraps from omegaconf import DictConfig from otx.algorithms.visual_prompting.tasks.inference import InferenceTask, ZeroShotTask @@ -23,6 +27,7 @@ init_environment, MockImageEncoder, ) +import onnxruntime logger = get_logger() @@ -277,6 +282,91 @@ def test_infer(self, mocker): mocker_trainer.assert_called_once() + @e2e_pytest_unit + @pytest.mark.parametrize("export_type", [ExportType.ONNX, ExportType.OPENVINO]) + def test_export(self, mocker, export_type: ExportType): + """Test export.""" + model = self.zero_shot_task.load_model(otx_model=self.zero_shot_task.task_environment.model) + model.prompt_getter.reference_feats = torch.rand(3, 1, 256) + model.prompt_getter.reference_prompts = torch.rand(3, 720, 1280) + mocker.patch.object(self.zero_shot_task, "load_model", return_value=model) + + dataset = generate_visual_prompting_dataset() + output_model = ModelEntity(dataset, self.zero_shot_task.task_environment.get_model_configuration()) + + self.zero_shot_task.export(export_type, output_model, dump_features=False) + + if export_type == ExportType.ONNX: + assert output_model.model_format == ModelFormat.ONNX + assert "visual_prompting_image_encoder.onnx" in output_model.model_adapters + assert "visual_prompting_prompt_getter.onnx" in output_model.model_adapters + assert "visual_prompting_decoder.onnx" in output_model.model_adapters + + elif export_type == ExportType.OPENVINO: + assert output_model.model_format == ModelFormat.OPENVINO + assert "visual_prompting_image_encoder.bin" in output_model.model_adapters + assert "visual_prompting_image_encoder.xml" in output_model.model_adapters + assert "visual_prompting_prompt_getter.bin" in output_model.model_adapters + assert "visual_prompting_prompt_getter.xml" in output_model.model_adapters + assert "visual_prompting_decoder.bin" in output_model.model_adapters + assert "visual_prompting_decoder.xml" in output_model.model_adapters + + assert not output_model.has_xai + + @e2e_pytest_unit + def test_export_to_onnx(self): + """Test _export_to_onnx.""" + onnx_path = { + "visual_prompting_image_encoder": os.path.join( + self.zero_shot_task.output_path, "visual_prompting_image_encoder.onnx" + ), + "visual_prompting_prompt_getter": os.path.join( + self.zero_shot_task.output_path, "visual_prompting_prompt_getter.onnx" + ), + "visual_prompting_decoder": os.path.join(self.zero_shot_task.output_path, "visual_prompting_decoder.onnx"), + } + self.zero_shot_task.model = self.zero_shot_task.load_model(otx_model=self.zero_shot_task.task_environment.model) + self.zero_shot_task.model.prompt_getter.reference_feats = 
torch.randn(1, 1, 256) + self.zero_shot_task.model.prompt_getter.reference_feats /= ( + self.zero_shot_task.model.prompt_getter.reference_feats.norm(dim=-1, keepdim=True) + ) + + self.zero_shot_task._export_to_onnx(onnx_path) + + image_size = self.zero_shot_task.config.model.image_size + embed_dim = self.zero_shot_task.model.prompt_encoder.embed_dim + embed_size = self.zero_shot_task.model.prompt_encoder.image_embedding_size + mask_input_size = [4 * x for x in embed_size] + onnx_inputs = { + "visual_prompting_image_encoder": { + "images": np.random.random((1, 3, image_size, image_size)).astype(np.float32) + }, + "visual_prompting_prompt_getter": { + "image_embeddings": np.random.randn(1, embed_dim, *embed_size).astype(dtype=np.float32), + "original_size": np.random.randint(low=0, high=image_size * 2, size=(1, 2), dtype=np.int64), + "threshold": np.array([[0.1]], dtype=np.float32), + "num_bg_points": np.random.randint(low=1, high=image_size, size=(1, 1), dtype=np.int64), + }, + "visual_prompting_decoder": { + "image_embeddings": np.zeros((1, embed_dim, *embed_size), dtype=np.float32), + "point_coords": np.random.randint(low=0, high=1024, size=(1, 2, 2)).astype(np.float32), + "point_labels": np.random.randint(low=0, high=4, size=(1, 2)).astype(np.float32), + "mask_input": np.random.randn(1, 1, *mask_input_size).astype(np.float32), + "has_mask_input": np.array([[1]], dtype=np.float32), + }, + } + onnx_outputs = { + "visual_prompting_image_encoder": ["image_embeddings"], + "visual_prompting_prompt_getter": ["total_points_scores", "total_bg_coords"], + "visual_prompting_decoder": ["iou_predictions", "low_res_masks"], + } + + onnx_rt_models = { + k: onnxruntime.InferenceSession(v, providers=["CPUExecutionProvider"]) for k, v in onnx_path.items() + } + for name, onnx_model in onnx_rt_models.items(): + onnx_model.run(onnx_outputs.get(name), onnx_inputs.get(name)) + @e2e_pytest_unit def test_save_model(self, mocker): """Test save_model.""" diff --git a/tests/unit/algorithms/visual_prompting/tasks/test_openvino.py b/tests/unit/algorithms/visual_prompting/tasks/test_openvino.py index d228687f7ba..0f8ed5daf4b 100644 --- a/tests/unit/algorithms/visual_prompting/tasks/test_openvino.py +++ b/tests/unit/algorithms/visual_prompting/tasks/test_openvino.py @@ -5,7 +5,7 @@ # from copy import deepcopy -from typing import Optional +from typing import Optional, Dict, Tuple import os import numpy as np @@ -21,6 +21,7 @@ from otx.algorithms.visual_prompting.configs.base import VisualPromptingBaseConfig from otx.algorithms.visual_prompting.tasks.openvino import ( OpenVINOVisualPromptingInferencer, + OpenVINOZeroShotVisualPromptingInferencer, OpenVINOVisualPromptingTask, OTXOpenVinoDataLoader, ) @@ -130,7 +131,9 @@ def test_predict(self, mocker): ), ) mocker_forward = mocker.patch.object( - OpenVINOVisualPromptingInferencer, "forward", return_value={"image_embeddings": np.empty((4, 2, 2))} + OpenVINOVisualPromptingInferencer, + "forward_image_encoder", + return_value={"image_embeddings": np.empty((4, 2, 2))}, ) mocker_forward_decoder = mocker.patch.object( OpenVINOVisualPromptingInferencer, "forward_decoder", return_value=None @@ -149,12 +152,12 @@ def test_predict(self, mocker): assert returned_value == self.fake_annotation @e2e_pytest_unit - def test_forward(self): - """Test forward.""" + def test_forward_image_encoder(self): + """Test forward_image_encoder.""" fake_input = {"images": np.ones((1, 3, 2, 2))} fake_output = {"image_embeddings": np.ones((1, 1, 2, 2))} 
self.visual_prompting_ov_inferencer.model["image_encoder"].infer_sync.return_value = fake_output - returned_value = self.visual_prompting_ov_inferencer.forward(fake_input) + returned_value = self.visual_prompting_ov_inferencer.forward_image_encoder(fake_input) assert returned_value == fake_output @@ -169,6 +172,149 @@ def test_forward_decoder(self): assert returned_value == fake_output +class TestOpenVINOZeroShotVisualPromptingInferencer: + @pytest.fixture(autouse=True) + def setup(self, mocker): + self.fake_annotation = [ + Annotation( + Polygon(points=[Point(0, 0)]), + id=0, + labels=[ScoredLabel(LabelEntity(name="fake", domain="VISUALPROMPTING"), probability=1.0)], + ) + ] + mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.OpenvinoAdapter") + mocker.patch.object(Model, "create_model") + mocker.patch.object( + VisualPromptingToAnnotationConverter, "convert_to_annotation", return_value=self.fake_annotation + ) + self.task_environment = init_environment() + visual_prompting_hparams = self.task_environment.get_hyper_parameters(VisualPromptingBaseConfig) + label_schema = self.task_environment.label_schema + + self.visual_prompting_ov_inferencer = OpenVINOZeroShotVisualPromptingInferencer( + visual_prompting_hparams, + label_schema, + {"image_encoder": "", "prompt_getter": "", "decoder": ""}, + {"image_encoder": "", "prompt_getter": "", "decoder": ""}, + ) + self.visual_prompting_ov_inferencer.model["decoder"] = mocker.patch( + "otx.algorithms.visual_prompting.tasks.openvino.model_wrappers.Decoder", autospec=True + ) + self.visual_prompting_ov_inferencer.model["decoder"]._apply_coords.return_value = np.array([[1, 1]]) + + @e2e_pytest_unit + def test_predict(self, mocker): + """Test predict.""" + mocker_pre_process = mocker.patch.object( + OpenVINOZeroShotVisualPromptingInferencer, + "pre_process", + return_value=(torch.zeros((1, 3, 2, 2)), {"original_shape": (4, 4, 1)}), + ) + mocker_forward = mocker.patch.object( + OpenVINOZeroShotVisualPromptingInferencer, + "forward_image_encoder", + return_value={"image_embeddings": np.empty((4, 2, 2))}, + ) + mocker_forward_decoder = mocker.patch.object( + OpenVINOZeroShotVisualPromptingInferencer, + "forward_prompt_getter", + return_value={"total_points_scores": np.array([[[1, 1, 1]]]), "total_bg_coords": np.array([[[2, 2]]])}, + ) + mocker_forward_decoder = mocker.patch.object( + OpenVINOZeroShotVisualPromptingInferencer, "forward_decoder", return_value=None + ) + mocker_post_process = mocker.patch.object( + OpenVINOZeroShotVisualPromptingInferencer, "post_process", return_value=(self.fake_annotation, None, None) + ) + fake_input = mocker.Mock(spec=DatasetItemEntity) + + returned_value = self.visual_prompting_ov_inferencer.predict(fake_input) + + mocker_pre_process.assert_called_once() + mocker_forward.assert_called_once() + mocker_forward_decoder.assert_called_once() + mocker_post_process.assert_called_once() + assert returned_value == self.fake_annotation + + @e2e_pytest_unit + @pytest.mark.parametrize( + "postprocess_output,infer_sync_output,expected", + [ + ( + (np.ones((1, 1)), np.ones((3, 3)), 0.9), + {"iou_predictions": np.array([[0.9]]), "low_res_masks": np.ones((1, 1, 2, 2))}, + {"iou_predictions": np.array([[0.9]]), "low_res_masks": np.ones((1, 1, 2, 2))}, + ), + ( + (np.zeros((2, 2)), np.zeros((3, 3)), 0.0), + {"iou_predictions": np.array([[0.9]]), "low_res_masks": np.ones((1, 1, 2, 2))}, + {"iou_predictions": 0.0, "low_res_masks": np.zeros((2, 2))}, + ), + ], + ) + def test_forward_decoder( + self, + mocker, + 
postprocess_output: Tuple[torch.Tensor, torch.Tensor], + infer_sync_output: Dict[str, np.ndarray], + expected: Dict[str, torch.Tensor], + ): + """Test forward_decoder.""" + mocker.patch.object( + self.visual_prompting_ov_inferencer.model["decoder"], "infer_sync", return_value=infer_sync_output + ) + mocker.patch.object(self.visual_prompting_ov_inferencer, "_postprocess_masks", return_value=postprocess_output) + + result = self.visual_prompting_ov_inferencer.forward_decoder( + inputs={ + "image_embeddings": np.empty((1, 4, 2, 2)), + "point_coords": np.array([[[1, 1]]], dtype=np.float32), + "point_labels": np.array([[1]], dtype=np.float32), + }, + original_size=np.array([3, 3]), + ) + + assert np.all(result["iou_predictions"] == expected["iou_predictions"]) + assert np.all(result["low_res_masks"] == expected["low_res_masks"]) + + @e2e_pytest_unit + @pytest.mark.parametrize( + "high_res_masks,expected_masks,expected_scores", + [ + ( + np.repeat(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])[..., None], 4, axis=-1), + np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.bool_), + 0.9, + ), + ( + np.concatenate( + ( + np.repeat(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])[..., None], 3, axis=-1), + np.zeros((3, 3, 1)), + ), + axis=-1, + ), + np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.bool_), + 0.8, + ), + (np.zeros((3, 3, 4)), np.zeros((3, 3)), 0.0), + ], + ) + def test_postprocess_masks(self, high_res_masks: np.ndarray, expected_masks: np.ndarray, expected_scores: float): + """Test _postprocess_masks.""" + self.visual_prompting_ov_inferencer.model["decoder"].resize_and_crop.return_value = high_res_masks + self.visual_prompting_ov_inferencer.model["decoder"].mask_threshold = 0.0 + self.visual_prompting_ov_inferencer.model["decoder"].image_size = 3 + + _, result_masks, result_scores = self.visual_prompting_ov_inferencer._postprocess_masks( + logits=np.empty((1, 4, 2, 2)), scores=np.array([[0.5, 0.7, 0.8, 0.9]]), original_size=np.array([3, 3]) + ) + + assert result_masks.shape == (3, 3) + assert np.all(result_masks == expected_masks) + assert result_scores == expected_scores + + class TestOTXOpenVinoDataLoader: @pytest.fixture def load_dataloader(self, mocker): diff --git a/tests/unit/algorithms/visual_prompting/test_helpers.py b/tests/unit/algorithms/visual_prompting/test_helpers.py index a9f22c7bf95..c1be0ae3c89 100644 --- a/tests/unit/algorithms/visual_prompting/test_helpers.py +++ b/tests/unit/algorithms/visual_prompting/test_helpers.py @@ -153,7 +153,6 @@ def __init__(self, *args, **kwargs): self.backbone = nn.Linear(1, 1) def forward(self, *args, **kwargs): - # return torch.Tensor([[1]]) return torch.ones((1, 2, 4, 4)) @@ -182,6 +181,9 @@ def __init__(self, *args, **kwargs): def forward(self, *args, **kwargs): return torch.Tensor([[1]]), torch.Tensor([[1]]) + def predict_mask(self, *args, **kwargs): + return self(*args, **kwargs) + class MockScoredLabel: def __init__(self, label: int, name: str = "background"): @@ -199,7 +201,8 @@ def initialize(self): def set_default_thresholds(self, *args, **kwargs): pass + def get_prompt_candidates(self, *args, **kwargs): + return {1: (torch.Tensor([[0, 0, 0.5]]), torch.Tensor([[1, 1]]))} + def forward(self, *args, **kwargs): - return { - MockScoredLabel(label=1, name="label"): (torch.tensor([[0, 0, 0.5], [1, 1, 0.7]]), torch.tensor([[2, 2]])) - } + return torch.tensor([[[0, 0, 0.5], [1, 1, 0.7]]]), torch.tensor([[[2, 2]]])
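
Editor's illustration: the `_postprocess_masks` tests above exercise a simple selection rule — drop the first mask slot, discard masks that are entirely zero, then keep the highest-scoring remainder (falling back to an all-zero mask with score 0.0 when nothing survives). A minimal, self-contained NumPy sketch of that rule follows; `select_best_mask` is an illustrative helper, not the task's actual method, and it returns a mask of the input resolution rather than the decoder's `image_size`.

import numpy as np

def select_best_mask(masks: np.ndarray, scores: np.ndarray):
    # masks: (1, N, H, W) boolean, scores: (1, N); returns (best mask, its score).
    masks, scores = masks[:, 1:], scores[:, 1:]              # skip the first slot
    keep = masks.reshape(masks.shape[1], -1).sum(axis=-1) > 0
    if not keep.any():                                       # every candidate was empty
        return np.zeros(masks.shape[-2:], dtype=bool), 0.0
    masks, scores = masks[:, keep], scores[:, keep]          # drop all-zero masks
    best = int(np.argmax(scores[0]))
    return masks[0, best], float(scores[0, best])

if __name__ == "__main__":
    scores = np.array([[0.5, 0.7, 0.8, 0.9]])
    masks = np.zeros((1, 4, 3, 3), dtype=bool)
    masks[0, 1:] = True                                      # slots 1..3 are non-empty
    mask, score = select_best_mask(masks, scores)
    print(mask.shape, score)                                 # (3, 3) 0.9
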