Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed a bug in YoloNASPose.export() that prevented to export model for BS>1 #1530

Merged
merged 4 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/super_gradients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.3.0"
__version__ = "3.3.1"
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved

from super_gradients.common import init_trainer, is_distributed, object_names
from super_gradients.training import losses, utils, datasets_utils, DataAugmentation, Trainer, KDTrainer, QATTrainer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]):
topk_candidates = torch.topk(pred_bboxes_conf, dim=1, k=nms_top_k, largest=True, sorted=True)

offsets = nms_top_k * torch.arange(pred_bboxes_conf.size(0), device=pred_bboxes_conf.device)
flat_indices = topk_candidates.indices + offsets.reshape(pred_bboxes_conf.size(0), 1)
flat_indices = topk_candidates.indices + offsets.reshape(pred_bboxes_conf.size(0), 1, 1)
flat_indices = torch.flatten(flat_indices)

pred_poses_and_scores = torch.cat([pred_pose_coords, pred_pose_scores.unsqueeze(3)], dim=3)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit_tests/export_pose_estimation_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from super_gradients.module_interfaces import ExportablePoseEstimationModel, PoseEstimationModelExportResult
from super_gradients.training import models
from super_gradients.training.dataloaders import coco2017_val # noqa
from super_gradients.training.models.pose_estimation_models.yolo_nas_pose.yolo_nas_pose_variants import YoloNASPoseDecodingModule
from super_gradients.training.processing.processing import (
default_yolo_nas_pose_coco_processing_params,
ComposeProcessing,
Expand Down Expand Up @@ -56,6 +57,23 @@ def setUp(self) -> None:
keypoint_colors=np.random.randint(0, 255, size=(20, 3)).tolist(),
)

def test_export_decoding_module_bs_3(self):
num_pre_nms_predictions = 1000
batch_size = 3
module = YoloNASPoseDecodingModule(num_pre_nms_predictions)

pred_bboxes_xyxy = torch.rand(batch_size, 8400, 4)
pred_bboxes_conf = torch.rand(batch_size, 8400, 1).sigmoid()
pred_pose_coords = torch.rand(batch_size, 8400, 20, 2)
pred_pose_scores = torch.rand(batch_size, 8400, 20).sigmoid()

inputs = (pred_bboxes_xyxy, pred_bboxes_conf, pred_pose_coords, pred_pose_scores)
_ = module([inputs]) # Check that normal forward() works

with tempfile.TemporaryDirectory() as tmpdirname:
out_path = os.path.join(tmpdirname, "model.onnx")
torch.onnx.export(module, (inputs,), out_path)

def test_export_model_on_small_size(self):
with tempfile.TemporaryDirectory() as tmpdirname:
for model_type in [
Expand All @@ -75,6 +93,26 @@ def test_export_model_on_small_size(self):
assert export_result.input_image_shape == (64, 64)
print(export_result.usage_instructions)

def test_export_model_with_batch_size_4(self):
with tempfile.TemporaryDirectory() as tmpdirname:
for model_type in [
Models.YOLO_NAS_POSE_S,
]:
out_path = os.path.join(tmpdirname, model_type + ".onnx")
model: ExportablePoseEstimationModel = models.get(model_type, num_classes=17)
model.set_dataset_processing_params(**default_yolo_nas_pose_coco_processing_params())
export_result = model.export(
out_path,
batch_size=4,
input_image_shape=(640, 640),
num_pre_nms_predictions=2000,
max_predictions_per_image=1000,
output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
)
assert export_result.input_image_dtype == torch.uint8
assert export_result.input_image_shape == (640, 640)
print(export_result.usage_instructions)

def test_the_most_common_export_use_case(self):
"""
Test the most common export use case - export to ONNX with all default parameters
Expand Down