Add missing tile recipes and various tile recipe changes #3942

Merged
19 changes: 12 additions & 7 deletions src/otx/algo/detection/base_models/detection_transformer.py
@@ -98,19 +98,24 @@ def export(
if explain_mode:
msg = "Explain mode is not supported for DETR models yet."
raise NotImplementedError(msg)
- return self.postprocess(self._forward_features(batch_inputs), deploy_mode=True)

+ return self.postprocess(
+ self._forward_features(batch_inputs),
+ [meta["img_shape"] for meta in batch_img_metas],
+ deploy_mode=True,
+ )

def postprocess(
self,
outputs: dict[str, Tensor],
- original_size: tuple[int, int] | None = None,
+ original_sizes: list[tuple[int, int]],
deploy_mode: bool = False,
) -> dict[str, Tensor] | tuple[list[Tensor], list[Tensor], list[Tensor]]:
"""Post-processes the model outputs.

Args:
outputs (dict[str, Tensor]): The model outputs.
- original_size (tuple[int, int], optional): The original size of the input images. Defaults to None.
+ original_sizes (list[tuple[int, int]]): The original image sizes.
deploy_mode (bool, optional): Whether to run in deploy mode. Defaults to False.

Returns:
@@ -120,9 +125,9 @@ def postprocess(

# convert bbox to xyxy and rescale back to original size (resize in OTX)
bbox_pred = box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
- if not deploy_mode and original_size is not None:
- original_size_tensor = torch.tensor(original_size).to(bbox_pred.device)
- bbox_pred *= original_size_tensor.repeat(1, 2).unsqueeze(1)
+ if not deploy_mode:
+ original_size_tensor = torch.tensor(original_sizes).to(bbox_pred.device)
+ bbox_pred *= original_size_tensor.flip(1).repeat(1, 2).unsqueeze(1)

# perform scores computation and gather topk results
scores = nn.functional.sigmoid(logits)
@@ -136,7 +141,7 @@

scores_list, boxes_list, labels_list = [], [], []

- for sc, bb, ll in zip(scores, boxes, labels):
+ for sc, bb, ll, original_size in zip(scores, boxes, labels, original_sizes):
scores_list.append(sc)
boxes_list.append(
BoundingBoxes(bb, format="xyxy", canvas_size=original_size),
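For reference, a minimal standalone sketch of the new per-image rescaling step, with invented shapes and sizes (not the OTX code itself):

```python
import torch

# bbox_pred: normalized xyxy boxes, shape (batch, num_queries, 4)
bbox_pred = torch.rand(2, 5, 4)
# original_sizes: one (height, width) tuple per image, as passed in from export()
original_sizes = [(480, 640), (720, 1280)]

sizes = torch.tensor(original_sizes)             # shape (2, 2), rows are (H, W)
scale = sizes.flip(1).repeat(1, 2).unsqueeze(1)  # (H, W) -> (W, H, W, H), shape (2, 1, 4)
bbox_pred = bbox_pred * scale                    # broadcasts over the query dimension
```

The `flip(1)` is what makes the repeat produce (W, H, W, H), matching the x1, y1, x2, y2 layout of the boxes after `box_convert`.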
18 changes: 10 additions & 8 deletions src/otx/algo/detection/rtdetr.py
@@ -75,13 +75,14 @@ def _customize_inputs(
# prepare bboxes for the model
for bb, ll in zip(entity.bboxes, entity.labels):
# convert to cxcywh if needed
- converted_bboxes = (
- box_convert(bb, in_fmt="xyxy", out_fmt="cxcywh") if bb.format == BoundingBoxFormat.XYXY else bb
- )
- # normalize the bboxes
- scaled_bboxes = converted_bboxes / torch.tensor(bb.canvas_size[::-1]).tile(2)[None].to(
- converted_bboxes.device,
- )
+ if len(scaled_bboxes := bb):
+ converted_bboxes = (
+ box_convert(bb, in_fmt="xyxy", out_fmt="cxcywh") if bb.format == BoundingBoxFormat.XYXY else bb
+ )
+ # normalize the bboxes
+ scaled_bboxes = converted_bboxes / torch.tensor(bb.canvas_size[::-1]).tile(2)[None].to(
+ converted_bboxes.device,
+ )
targets.append({"boxes": scaled_bboxes, "labels": ll})

return {
@@ -109,7 +110,8 @@ def _customize_outputs(
raise TypeError(msg)
return losses

- scores, bboxes, labels = self.model.postprocess(outputs, [img_info.img_shape for img_info in inputs.imgs_info])
+ original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
+ scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)

return DetBatchPredEntity(
batch_size=len(outputs),
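The walrus guard above skips conversion and normalization when an image has no ground-truth boxes, so an empty tensor is forwarded into the targets as-is. A self-contained sketch of the same pattern (the helper below is hypothetical, not OTX API):

```python
import torch

def normalize_boxes(bb: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor:
    """Normalize cxcywh boxes by (W, H, W, H); empty inputs pass through untouched."""
    if len(scaled_bboxes := bb):
        scaled_bboxes = bb / torch.tensor(canvas_size[::-1]).tile(2)[None].to(bb.device)
    return scaled_bboxes

print(normalize_boxes(torch.empty(0, 4), (480, 640)).shape)  # torch.Size([0, 4])
print(normalize_boxes(torch.tensor([[320.0, 240.0, 64.0, 32.0]]), (480, 640)))
# tensor([[0.5000, 0.5000, 0.1000, 0.0667]])
```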
7 changes: 5 additions & 2 deletions src/otx/algo/detection/yolox.py
@@ -5,7 +5,7 @@

from __future__ import annotations

- from typing import TYPE_CHECKING, Any
+ from typing import TYPE_CHECKING, Any, Literal

from otx.algo.common.losses import CrossEntropyLoss, L1Loss
from otx.algo.detection.backbones import CSPDarknet
@@ -76,13 +76,16 @@ def _exporter(self) -> OTXModelExporter:
raise ValueError(msg)

swap_rgb = not isinstance(self, YOLOXTINY) # only YOLOX-TINY uses RGB
+ resize_mode: Literal["standard", "fit_to_window_letterbox"] = "fit_to_window_letterbox"
+ if self.tile_config.enable_tiler:
+ resize_mode = "standard"

return OTXNativeModelExporter(
task_level_export_parameters=self._export_parameters,
input_size=(1, 3, *self.input_size),
mean=self.mean,
std=self.std,
resize_mode="fit_to_window_letterbox",
resize_mode=resize_mode,
pad_value=114,
swap_rgb=swap_rgb,
via_onnx=True,
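For context: letterbox resizing pads the image to preserve its aspect ratio, which is unnecessary once the tiler already produces uniform crops, hence the switch to "standard" when tiling is enabled. The explicit `Literal` annotation presumably keeps the narrowed string type compatible with the exporter's `resize_mode` parameter. A condensed sketch of the same switch (the helper name is invented):

```python
from typing import Literal

ResizeMode = Literal["standard", "fit_to_window_letterbox"]

def pick_resize_mode(tiler_enabled: bool) -> ResizeMode:
    # Tiles are already fixed-size crops, so plain resizing is enough;
    # full images keep the aspect-preserving letterbox path.
    resize_mode: ResizeMode = "fit_to_window_letterbox"
    if tiler_enabled:
        resize_mode = "standard"
    return resize_mode
```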
3 changes: 3 additions & 0 deletions src/otx/core/data/dataset/tile.py
@@ -218,6 +218,9 @@ def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
dataset.mem_cache_handler,
dataset.mem_cache_img_max_size,
dataset.max_refetch,
+ dataset.image_color_channel,
+ dataset.stack_images,
+ dataset.to_tv_image,
)
self.tile_config = tile_config
self._dataset = dataset
4 changes: 2 additions & 2 deletions src/otx/core/data/entity/tile.py
@@ -125,7 +125,7 @@ def unbind(self) -> list[tuple[TileAttrDictList, DetBatchDataEntity]]:
labels=[[] for _ in range(self.batch_size)],
),
)
- return list(zip(batch_tile_attr_list, batch_data_entities))
+ return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))

@classmethod
def collate_fn(cls, batch_entities: list[TileDetDataEntity]) -> TileBatchDetDataEntity:
@@ -218,7 +218,7 @@ def unbind(self) -> list[tuple[TileAttrDictList, InstanceSegBatchDataEntity]]:
)
for i in range(0, len(tiles), self.batch_size)
]
- return list(zip(batch_tile_attr_list, batch_data_entities))
+ return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))

@classmethod
def collate_fn(cls, batch_entities: list[TileInstSegDataEntity]) -> TileBatchInstSegDataEntity:
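`strict=True` (Python 3.10+) turns a silent truncation into a loud failure if the attribute and entity lists ever drift out of sync. For illustration:

```python
attrs = [{"roi": (0, 0, 256, 256)}, {"roi": (256, 0, 256, 256)}]
entities = ["entity_0"]  # deliberately one short

try:
    list(zip(attrs, entities, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```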
16 changes: 9 additions & 7 deletions src/otx/core/utils/tile_merge.py
@@ -29,7 +29,7 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]):
img_infos (list[ImageInfo]): Original image information before tiling.
num_classes (int): Number of classes.
tile_config (TileConfig): Tile configuration.
- explain_mode (bool): Whether or not tiles have explain features. Default: False.
+ explain_mode (bool, optional): Whether or not tiles have explain features. Default: False.
"""

def __init__(
@@ -119,8 +119,8 @@ def merge(
img_ids = []
explain_mode = self.explain_mode

- for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
- batch_size = tile_preds.batch_size
+ for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
+ batch_size = len(tile_attrs)
saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_s_map, tile_f_vect in zip(
@@ -131,6 +131,7 @@
tile_preds.scores,
saliency_maps,
feature_vectors,
+ strict=True,
):
offset_x, offset_y, _, _ = tile_attr["roi"]
tile_bboxes[:, 0::2] += offset_x
@@ -156,7 +157,7 @@

return [
self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
- for img_id, image_info in zip(img_ids, self.img_infos)
+ for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
]

def _merge_entities(
@@ -319,8 +320,8 @@ def merge(
img_ids = []
explain_mode = self.explain_mode

- for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
- feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(tile_preds.batch_size)]
+ for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
+ feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(len(tile_attrs))]
for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_masks, tile_f_vect in zip(
tile_attrs,
tile_preds.imgs_info,
@@ -329,6 +330,7 @@
tile_preds.scores,
tile_preds.masks,
feature_vectors,
+ strict=True,
):
keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
keep_indices = keep_indices.nonzero(as_tuple=True)[0]
@@ -363,7 +365,7 @@ def merge(

return [
self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
- for img_id, image_info in zip(img_ids, self.img_infos)
+ for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
]

def _merge_entities(
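Two things change in `merge`: the iteration now uses `strict=True`, and the batch size is derived from `len(tile_attrs)` rather than `tile_preds.batch_size`, so the placeholder saliency/feature lists match the actual number of tiles in a possibly smaller last batch. The ROI-offset step that follows is unchanged; a standalone sketch of it with invented numbers:

```python
import torch

tile_bboxes = torch.tensor([[10.0, 20.0, 50.0, 60.0]])  # xyxy within one tile
offset_x, offset_y, _, _ = (256, 512, 256, 256)          # tile_attr["roi"]

tile_bboxes[:, 0::2] += offset_x  # shift x1 and x2
tile_bboxes[:, 1::2] += offset_y  # shift y1 and y2
print(tile_bboxes)  # tensor([[266., 532., 306., 572.]]) -- full-image coordinates
```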
82 changes: 82 additions & 0 deletions src/otx/recipe/_base_/data/detection_tile.yaml
@@ -0,0 +1,82 @@
task: DETECTION
input_size:
- 800
- 800
mem_cache_size: 1GB
mem_cache_img_max_size: null
image_color_channel: RGB
stack_images: true
data_format: coco_instances
unannotated_items_ratio: 0.0
tile_config:
enable_tiler: true
enable_adaptive_tiling: true
train_subset:
subset_name: train
transform_lib_type: TORCHVISION
batch_size: 1
num_workers: 2
to_tv_image: false
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: false
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler

val_subset:
subset_name: val
transform_lib_type: TORCHVISION
batch_size: 1
num_workers: 2
to_tv_image: false
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: false
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler

test_subset:
subset_name: test
transform_lib_type: TORCHVISION
batch_size: 1
num_workers: 2
to_tv_image: false
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: false
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler
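The transform stacks in this new base recipe map roughly onto a torchvision v2 pipeline. A rough Python equivalent of the train stack, as an approximation only (the `Resize`/`RandomFlip` classes referenced in the YAML are OTX's own, not torchvision's):

```python
import torch
from torchvision.transforms import v2

train_transforms = v2.Compose([
    v2.Resize((800, 800)),           # scale: $(input_size), keep_ratio: false
    v2.RandomHorizontalFlip(p=0.5),  # RandomFlip, prob: 0.5
    v2.ToDtype(torch.float32),       # cast before normalization
    v2.Normalize(mean=[0.0, 0.0, 0.0], std=[255.0, 255.0, 255.0]),  # maps [0, 255] to [0, 1]
])
```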
6 changes: 1 addition & 5 deletions src/otx/recipe/detection/atss_mobilenetv2_tile.yaml
@@ -28,14 +28,10 @@ engine:

callback_monitor: val/map_50

- data: ../_base_/data/detection.yaml
+ data: ../_base_/data/detection_tile.yaml
overrides:
gradient_clip_val: 35.0
data:
- tile_config:
- enable_tiler: true
- enable_adaptive_tiling: true

train_subset:
batch_size: 8
sampler:
51 changes: 51 additions & 0 deletions src/otx/recipe/detection/atss_resnext101_tile.yaml
@@ -0,0 +1,51 @@
model:
class_path: otx.algo.detection.atss.ResNeXt101ATSS
init_args:
label_info: 80

optimizer:
class_path: torch.optim.SGD
init_args:
lr: 0.004
momentum: 0.9
weight_decay: 0.0001

scheduler:
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
init_args:
num_warmup_steps: 3
main_scheduler_callable:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: max
factor: 0.1
patience: 4
monitor: val/map_50

engine:
task: DETECTION
device: auto

callback_monitor: val/map_50

data: ../_base_/data/detection_tile.yaml
overrides:
gradient_clip_val: 35.0
callbacks:
- class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
init_args:
max_interval: 5
decay: -0.025
min_lrschedule_patience: 3

data:
train_subset:
batch_size: 4
sampler:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 4

test_subset:
batch_size: 4
2 changes: 0 additions & 2 deletions src/otx/recipe/detection/rtdetr_101.yaml
@@ -88,7 +88,6 @@ overrides:
init_args:
scale: $(input_size)
keep_ratio: false
- transform_bbox: true
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
@@ -103,7 +102,6 @@
init_args:
scale: $(input_size)
keep_ratio: false
- transform_bbox: true
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args: