Add visual prompting zero-shot learning (export, IR inference) #2706

Merged
70 commits
65b48f3
Add algobackend & temp configs
sungchul2 Nov 8, 2023
d95f5df
Update config
sungchul2 Nov 8, 2023
a842ff1
WIP
sungchul2 Nov 8, 2023
304952f
Fix to enable `algo_backend`
sungchul2 Nov 9, 2023
04c0abe
(WIP) Update dataset
sungchul2 Nov 10, 2023
6fb1f23
(WIP) Update configs
sungchul2 Nov 10, 2023
bb9f958
(WIP) Update tasks
sungchul2 Nov 10, 2023
5fd0635
(WIP) Update models
sungchul2 Nov 10, 2023
cd94169
Enable `learn` task through otx.train
sungchul2 Nov 13, 2023
ed80aaf
(WIP) enable `infer` (TODO : normalize points)
sungchul2 Nov 13, 2023
b83e959
Fix when `state_dict` is None
sungchul2 Nov 14, 2023
ac67107
Enable `ZeroShotInferenceCallback`
sungchul2 Nov 14, 2023
05c4012
Enable otx infer
sungchul2 Nov 14, 2023
5c3b605
Enable to independently use processor
sungchul2 Nov 20, 2023
cdf02aa
Revert max_steps
sungchul2 Nov 20, 2023
793b71a
Change `postprocess_masks` to `staticmethod`
sungchul2 Nov 29, 2023
d513c5c
Add `PromptGetter` & Enable `learn` and `infer`
sungchul2 Nov 29, 2023
6fc3826
precommit
sungchul2 Nov 29, 2023
8419ccf
Fix args
sungchul2 Nov 30, 2023
f28007e
Fix typo
sungchul2 Nov 30, 2023
3683aa0
Change `id` to `id_`
sungchul2 Nov 30, 2023
a9511bf
Fix import
sungchul2 Nov 30, 2023
ac5dd9a
Fix args
sungchul2 Nov 30, 2023
7aa499a
precommit
sungchul2 Nov 30, 2023
8539a4a
(WIP) Add unit tests
sungchul2 Nov 30, 2023
c75c346
Fix
sungchul2 Dec 5, 2023
40bca8c
Add unit tests
sungchul2 Dec 5, 2023
0e84ff6
Fix
sungchul2 Dec 5, 2023
38d60ba
Add integration tests
sungchul2 Dec 5, 2023
52c9ce0
Merge branch 'develop' into vpm-zsl-integration
sungchul2 Dec 5, 2023
94fe235
precommit
sungchul2 Dec 5, 2023
a53ba03
Update CHANGELOG.md
sungchul2 Dec 5, 2023
977c5d9
Update docstring and type annotations
sungchul2 Dec 6, 2023
9aac960
Fix
sungchul2 Dec 6, 2023
d4a8cd0
Merge branch 'develop' into vpm-zsl-integration
sungchul2 Dec 6, 2023
a67b86a
precommit
sungchul2 Dec 6, 2023
05c8908
Reuse SAM modules for `export` & Add dataset
sungchul2 Dec 7, 2023
3e96865
Fix
sungchul2 Dec 7, 2023
51b2b36
Enable `export`
sungchul2 Dec 7, 2023
46e73a4
Convert fp32
sungchul2 Dec 8, 2023
8db5930
Update logic & tests
sungchul2 Dec 8, 2023
eb2b799
Merge branch 'develop' into vpm-zsl-integration-export
sungchul2 Dec 11, 2023
bc058bd
Fix & Add prompt getter in `model_adapter_keys`
sungchul2 Dec 11, 2023
f7296c0
Initial `Inferencer`, `Task`, and `Model`
sungchul2 Dec 11, 2023
7ce5e6a
Fix to use original mask decoder during inference
sungchul2 Dec 11, 2023
516d8e1
Remove internal loop in `PromptGetter`
sungchul2 Dec 14, 2023
266713f
Update IO
sungchul2 Dec 14, 2023
8266f6e
(WIP) Add unit tests for export
sungchul2 Dec 14, 2023
c09df89
Update `PromptGetter` to use only tensor ops
sungchul2 Dec 15, 2023
ed2c465
Fix issue about `original_size` disappear in onnx graph
sungchul2 Dec 15, 2023
6de25d9
(WIP) Add export unit test
sungchul2 Dec 15, 2023
0c397cb
Update
sungchul2 Dec 15, 2023
e2ecf51
Fix typo
sungchul2 Dec 18, 2023
22d67f1
Update
sungchul2 Dec 18, 2023
457666d
Fix unexpected IF & Update inputs to avoid issues which OV on CPU doe…
sungchul2 Dec 18, 2023
de03105
Enable `PromptGetter` to handle #labels itself
sungchul2 Dec 18, 2023
0915031
Add ov inferencer
sungchul2 Dec 20, 2023
5a300c6
Fix overflow during casting dtype & duplicated cast
sungchul2 Dec 20, 2023
4cfede5
Fix
sungchul2 Dec 21, 2023
d564d38
Add unit&integration tests
sungchul2 Dec 21, 2023
35ce41a
pre-commit
sungchul2 Dec 21, 2023
5c165e8
Merge branch 'develop' into vpm-zsl-integration-export
sungchul2 Dec 21, 2023
f7a77c1
Fix original vpms
sungchul2 Dec 21, 2023
cbc709e
Fix intg & e2e tests
sungchul2 Dec 21, 2023
a2ded78
Change mo CLI to API
sungchul2 Dec 21, 2023
3837665
precommit
sungchul2 Dec 21, 2023
0754f3b
Remove blocks
sungchul2 Dec 21, 2023
5328884
Update CHANGELOG.md
sungchul2 Dec 22, 2023
d2891e0
Avoid repeatedly assigning constant tensors/arrays
sungchul2 Dec 22, 2023
743d7a2
Fix typo
sungchul2 Dec 22, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ All notable changes to this project will be documented in this file.

 ### New features

-- Add zero-shot visual prompting (https://github.com/openvinotoolkit/training_extensions/pull/2616)
+- Add zero-shot visual prompting (<https://github.com/openvinotoolkit/training_extensions/pull/2616>, <https://github.com/openvinotoolkit/training_extensions/pull/2706>)

 ### Enhancements

Original file line number Diff line number Diff line change
@@ -14,4 +14,4 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.

-from .openvino_models import Decoder, ImageEncoder  # noqa: F401
+from .openvino_models import Decoder, ImageEncoder, PromptGetter  # noqa: F401
Original file line number Diff line number Diff line change
@@ -59,6 +59,20 @@ def preprocess(
         return dict_inputs, meta


+class PromptGetter(ImageModel):
+    """PromptGetter class for zero-shot visual prompting of openvino model wrapper."""
+
+    __model__ = "prompt_getter"
+
+    @classmethod
+    def parameters(cls) -> Dict[str, Any]:  # noqa: D102
+        parameters = super().parameters()
+        parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)})
+        parameters.update({"sim_threshold": NumericalValue(value_type=float, default_value=0.5, min=0, max=1)})
+        parameters.update({"num_bg_points": NumericalValue(value_type=int, default_value=1, min=0, max=1024)})
+        return parameters
+
+
 class Decoder(SegmentationModel):
     """Decoder class for visual prompting of openvino model wrapper."""

@@ -76,6 +90,7 @@ def __init__(
     def parameters(cls):  # noqa: D102
         parameters = super().parameters()
         parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)})
+        parameters.update({"mask_threshold": NumericalValue(value_type=float, default_value=0.0, min=0, max=1)})
         return parameters

     def _get_outputs(self):
@@ -174,7 +189,7 @@ def resize_and_crop(self, soft_prediction: np.ndarray, original_size: np.ndarray
         )

         prepadded_size = self.get_padded_size(original_size, self.image_size).astype(np.int64)
-        resized_cropped_soft_prediction = resized_soft_prediction[..., : prepadded_size[0], : prepadded_size[1]]
+        resized_cropped_soft_prediction = resized_soft_prediction[: prepadded_size[0], : prepadded_size[1], ...]

         original_size = original_size.astype(np.int64)
         h, w = original_size
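The `resize_and_crop` change above moves the ellipsis from the front of the index to the back. A minimal sketch of why the position matters, assuming the soft prediction is laid out as (H, W, C) rather than (C, H, W); the array shape and the `prepadded_size` value below are made up for illustration and are not taken from the PR:

```python
import numpy as np

# Hypothetical soft prediction in (H, W, C) layout: 1024x1024 with one channel.
soft_prediction = np.zeros((1024, 1024, 1), dtype=np.float32)
# Hypothetical pre-padding size (height, width) of the resized image.
prepadded_size = np.array([768, 1024], dtype=np.int64)

# Leading ellipsis: the two slices apply to the LAST two axes (here W and C).
old_crop = soft_prediction[..., : prepadded_size[0], : prepadded_size[1]]
# Trailing ellipsis: the two slices apply to the FIRST two axes (H and W).
new_crop = soft_prediction[: prepadded_size[0], : prepadded_size[1], ...]

print(old_crop.shape)  # (1024, 768, 1)
print(new_crop.shape)  # (768, 1024, 1)
```

Under that layout assumption, only the trailing-ellipsis form crops the padded rows and columns; for a (C, H, W) array the leading-ellipsis form would be the right one.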
Original file line number Diff line number Diff line change
@@ -207,7 +207,7 @@ def get_prompts(dataset_item: DatasetItemEntity, dataset_labels: List[LabelEntit

         bboxes = np.array(bboxes)
         return dict(
-            original_size=(height, width),
+            original_size=np.array((height, width), dtype=np.int64),
             gt_masks=gt_masks,
             bboxes=bboxes,
             points=points,  # TODO (sungchul): update point information
@@ -247,6 +247,20 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]:
 class OTXZeroShotVisualPromptingDataset(OTXVisualPromptingDataset):
     """Visual Prompting for Zero-shot learning Dataset Adaptor."""

+    def __init__(
+        self,
+        dataset: DatasetEntity,
+        image_size: int,
+        mean: List[float],
+        std: List[float],
+        generate_point: bool = False,
+        generate_bbox: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(dataset, image_size, mean, std, offset_bbox=0)
+        self.generate_point = generate_point
+        self.generate_bbox = generate_bbox
+
     def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]:
         """Get dataset item.

@@ -288,7 +302,7 @@ def __init__(
         self.config = config
         self.dataset = dataset
         self.train_type = train_type
-        # self.kwargs = {}
+        self.kwargs = {}
         if self.train_type == TrainType.Zeroshot:
             # check zero-shot configs
             if self.config.get("train_batch_size", 1) != 1:
@@ -300,12 +314,12 @@ def __init__(
                 )
                 self.config["train_batch_size"] = 1

-            # self.kwargs.update(
-            # {
-            # "generate_point": self.config.get("generate_point", False),
-            # "generate_bbox": self.config.get("generate_bbox", False),
-            # }
-            # )
+            self.kwargs.update(
+                {
+                    "generate_point": self.config.get("generate_point", False),
+                    "generate_bbox": self.config.get("generate_bbox", False),
+                }
+            )

         self.train_otx_dataset: DatasetEntity
         self.val_otx_dataset: DatasetEntity
@@ -331,7 +345,7 @@ def setup(self, stage: Optional[str] = None) -> None:
                 mean=mean,
                 std=std,
                 offset_bbox=self.config.offset_bbox,
-                # **self.kwargs,
+                **self.kwargs,
             )

             # self.val_dataset = None
@@ -347,11 +361,7 @@ def setup(self, stage: Optional[str] = None) -> None:

         if stage == "predict":
             self.predict_dataset = self.DATASETS[self.train_type](
-                dataset=self.dataset,
-                image_size=image_size,
-                mean=mean,
-                std=std,
-                # **self.kwargs
+                dataset=self.dataset, image_size=image_size, mean=mean, std=std, **self.kwargs
             )

     def summary(self):
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
 # All rights reserved.
 #

-from copy import deepcopy
 from typing import Any, Dict, List, Tuple, Union

 import numpy as np
@@ -76,9 +75,9 @@ def apply_coords(
         old_h, old_w = original_size
         new_h, new_w = cls.get_preprocess_shape(original_size[0], original_size[1], target_length)
         if isinstance(coords, np.ndarray):
-            coords = deepcopy(coords).astype(np.float32)
+            coords = coords.astype(float)
         else:
-            coords = deepcopy(coords).to(torch.float32)
+            coords = coords.to(torch.float)
         coords[..., 0] = coords[..., 0] * (new_w / old_w)
         coords[..., 1] = coords[..., 1] * (new_h / old_h)
         return coords
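The `apply_coords` hunk above drops `deepcopy` and relies on the dtype cast instead. A minimal NumPy-only sketch of that coordinate rescaling follows. `get_preprocess_shape` is not shown in the diff, so the helper here is an assumption that follows the usual resize-longest-side convention, and the free-function form (no `cls`) is for illustration only:

```python
import numpy as np


def get_preprocess_shape(old_h: int, old_w: int, target_length: int) -> tuple:
    """Assumed helper: scale (old_h, old_w) so the longest side equals target_length."""
    scale = target_length / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)


def apply_coords(coords: np.ndarray, original_size: tuple, target_length: int) -> np.ndarray:
    """Rescale (x, y) coordinates from the original image to the resized image."""
    old_h, old_w = original_size
    new_h, new_w = get_preprocess_shape(old_h, old_w, target_length)
    # astype() returns a new array by default, so the caller's array is not modified.
    coords = coords.astype(float)
    coords[..., 0] = coords[..., 0] * (new_w / old_w)
    coords[..., 1] = coords[..., 1] * (new_h / old_h)
    return coords


points = np.array([[320, 240]], dtype=np.int64)
scaled = apply_coords(points, original_size=(480, 640), target_length=1024)
print(scaled)  # [[512. 384.]]
print(points)  # [[320 240]]
```

This covers only the NumPy branch. `torch.Tensor.to` returns the tensor itself when the dtype and device already match, so in the tensor branch the in-place assignments may write into the caller's data if it is already `torch.float`.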
Original file line number Diff line number Diff line change
@@ -23,25 +23,26 @@ def collate_fn(batch: List[Any]) -> Dict:
         Dict: Collated batch data.
     """

-    def _convert_empty_to_none(x: str) -> List:
+    def _convert_empty_to_none(x: str, dtype: torch.dtype = torch.float32) -> List:
         """Convert empty list to None.

         Args:
             x (str): Key of batch data.
+            dtype (torch.dtype): Dtype to be applied to tensors.

         Returns:
             List: List of batch data.
         """
         func = torch.stack if x == "gt_masks" else torch.tensor
-        items = [func(item[x]) for item in batch if item[x] is not None]
+        items = [func(item[x]).to(dtype) for item in batch if item[x] is not None]
         return None if len(items) == 0 else items

     index = [item["index"] for item in batch]
     images = torch.stack([item["images"] for item in batch])
     bboxes = _convert_empty_to_none("bboxes")
     points = None  # TBD
-    gt_masks = _convert_empty_to_none("gt_masks")
-    original_size = [item["original_size"] for item in batch]
+    gt_masks = _convert_empty_to_none("gt_masks", torch.int32)
+    original_size = _convert_empty_to_none("original_size")
     padding = [item["padding"] for item in batch]
     path = [item["path"] for item in batch]
     labels = [item["labels"] for item in batch]
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@

 import re
 from collections import OrderedDict
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from omegaconf import DictConfig
@@ -334,47 +334,36 @@ def select_masks(self, masks: Tensor, iou_preds: Tensor, num_points: int) -> Tup

         return masks, iou_preds

-    def mask_postprocessing(self, masks: Tensor, orig_size: Tensor) -> Tensor:
+    @staticmethod
+    def mask_postprocessing(masks: Tensor, input_size: int, orig_size: Tensor) -> Tensor:
         """Postprocesses the predicted masks.

         Args:
             masks (Tensor): A batch of predicted masks with shape Bx1xHxW.
+            input_size (int): The size of the image input to the model, in (H, W) format.
+                Used to remove padding.
             orig_size (Tensor): The original image size with shape Bx2.

         Returns:
             masks (Tensor): The postprocessed masks with shape Bx1xHxW.
         """
-        masks = F.interpolate(
-            masks,
-            size=(self.config.model.image_size, self.config.model.image_size),
-            mode="bilinear",
-            align_corners=False,
-        )

-        prepadded_size = self.resize_longest_image_size(orig_size, self.config.model.image_size).to(torch.int64)
+        def resize_longest_image_size(input_image_size: Tensor, longest_side: int) -> Tensor:
+            scale = longest_side / torch.max(input_image_size)
+            transformed_size = scale * input_image_size
+            transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+            return transformed_size
+
+        masks = F.interpolate(masks, size=(input_size, input_size), mode="bilinear", align_corners=False)
+
+        prepadded_size = resize_longest_image_size(orig_size, input_size)
         masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

         orig_size = orig_size.to(torch.int64)
         h, w = orig_size[0], orig_size[1]
         masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
         return masks

-    def resize_longest_image_size(self, input_image_size: Tensor, longest_side: int) -> Tensor:
-        """Resizes the longest side of the image to the given size.
-
-        Args:
-            input_image_size (Tensor): The original image size with shape Bx2.
-            longest_side (int): The size of the longest side.
-
-        Returns:
-            transformed_size (Tensor): The transformed image size with shape Bx2.
-        """
-        input_image_size = input_image_size.to(torch.float32)
-        scale = longest_side / torch.max(input_image_size)
-        transformed_size = scale * input_image_size
-        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
-        return transformed_size
-
     ######################################################
     # forward for training/validation/prediction #
     ######################################################
@@ -556,26 +545,26 @@ def predict_step(self, batch, batch_idx) -> Dict[str, Tensor]:
     def postprocess_masks(
         masks: Tensor,
         input_size: Tuple[int, int],
-        padding: Tuple[int, ...],
-        original_size: Tuple[int, int],
+        padding: Union[Tuple[int, ...], Tensor],
+        original_size: Union[Tuple[int, int], Tensor],
     ) -> Tensor:
         """Remove padding and upscale masks to the original image size.

         Args:
             masks (Tensor): Predicted masks from the mask_decoder with (N, 1, H/downsized_ratio, W/downsized_ratio).
             input_size (tuple(int, int)): The size of the image input to the model, in (H, W) format.
                 Used to remove padding.
-            padding (tuple(int, int, int, int), optional): The padding applied to the image before input to the model,
+            padding (tuple(int, int, int, int), Tensor): The padding applied to the image before input to the model,
                 in (left, top, right, bottom) format.
-            original_size (tuple(int, int)): The original size of the image before resizing for input to the model,
-                in (H, W) format.
+            original_size (tuple(int, int), Tensor): The original size of the image before resizing
+                for input to the model, in (H, W) format.

         Returns:
             (Tensor): Postprocessed masks in NxHxW format, where (H, W) is given by original_size.
         """
         masks = F.interpolate(masks, input_size, mode="bilinear", align_corners=False)
         masks = masks[..., : input_size[0] - padding[3], : input_size[1] - padding[2]]
-        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+        masks = F.interpolate(masks, [int(o) for o in original_size], mode="bilinear", align_corners=False)
         return masks.squeeze(1)

     def configure_optimizers(self) -> optim:
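Both `mask_postprocessing` and `postprocess_masks` above crop away padding before resizing masks back to the original image size: the first recomputes the pre-padding size from the original size, the second subtracts the recorded padding. A pure-Python sketch of the two size computations, using plain tuples instead of tensors; `crop_extent` is a hypothetical helper named here for illustration, and the sizes are example values, not data from the PR:

```python
import math
from typing import Tuple


def resize_longest_image_size(input_image_size: Tuple[int, int], longest_side: int) -> Tuple[int, int]:
    """Return (H, W) after scaling so the longest side equals longest_side."""
    scale = longest_side / max(input_image_size)
    # floor(x + 0.5) rounds half up, mirroring the tensor code in the diff.
    return tuple(math.floor(scale * s + 0.5) for s in input_image_size)


def crop_extent(input_size: Tuple[int, int], padding: Tuple[int, int, int, int]) -> Tuple[int, int]:
    """Hypothetical helper: (H, W) kept after removing (left, top, right, bottom) padding."""
    return input_size[0] - padding[3], input_size[1] - padding[2]


original_size = (480, 640)
prepadded = resize_longest_image_size(original_size, 1024)
print(prepadded)  # (768, 1024)

# If the image was padded only at the bottom to reach 1024x1024,
# subtracting the padding keeps the same region.
padding = (0, 0, 0, 1024 - prepadded[0])
print(crop_extent((1024, 1024), padding))  # (768, 1024)
```

The two routes agree only when the padding was applied on the right and bottom edges, which is what the `padding[2]` and `padding[3]` indices in `postprocess_masks` suggest.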