From 2e751609749ee339612e58a10f1791e5147631df Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 16 Oct 2023 15:34:03 +0300
Subject: [PATCH] Fixed a bug in export() that prevented exporting models with
 batch size > 1

---
 src/super_gradients/__init__.py             |  2 +-
 .../yolo_nas_pose/yolo_nas_pose_variants.py |  2 +-
 .../export_pose_estimation_model_test.py    | 38 +++++++++++++++++++
 3 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/src/super_gradients/__init__.py b/src/super_gradients/__init__.py
index d02105deb2..0dc1232d3d 100755
--- a/src/super_gradients/__init__.py
+++ b/src/super_gradients/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "3.3.0"
+__version__ = "3.3.1"
 
 from super_gradients.common import init_trainer, is_distributed, object_names
 from super_gradients.training import losses, utils, datasets_utils, DataAugmentation, Trainer, KDTrainer, QATTrainer
diff --git a/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_variants.py b/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_variants.py
index 93c898b7d3..14f3880423 100644
--- a/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_variants.py
+++ b/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_variants.py
@@ -71,7 +71,7 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]):
         topk_candidates = torch.topk(pred_bboxes_conf, dim=1, k=nms_top_k, largest=True, sorted=True)
 
         offsets = nms_top_k * torch.arange(pred_bboxes_conf.size(0), device=pred_bboxes_conf.device)
-        flat_indices = topk_candidates.indices + offsets.reshape(pred_bboxes_conf.size(0), 1)
+        flat_indices = topk_candidates.indices + offsets.reshape(pred_bboxes_conf.size(0), 1, 1)
         flat_indices = torch.flatten(flat_indices)
 
         pred_poses_and_scores = torch.cat([pred_pose_coords, pred_pose_scores.unsqueeze(3)], dim=3)
diff --git a/tests/unit_tests/export_pose_estimation_model_test.py b/tests/unit_tests/export_pose_estimation_model_test.py
index b993e0be0c..98f912f410 100644
--- a/tests/unit_tests/export_pose_estimation_model_test.py
+++ b/tests/unit_tests/export_pose_estimation_model_test.py
@@ -16,6 +16,7 @@
 from super_gradients.module_interfaces import ExportablePoseEstimationModel, PoseEstimationModelExportResult
 from super_gradients.training import models
 from super_gradients.training.dataloaders import coco2017_val  # noqa
+from super_gradients.training.models.pose_estimation_models.yolo_nas_pose.yolo_nas_pose_variants import YoloNASPoseDecodingModule
 from super_gradients.training.processing.processing import (
     default_yolo_nas_pose_coco_processing_params,
     ComposeProcessing,
@@ -56,6 +57,23 @@ def setUp(self) -> None:
             keypoint_colors=np.random.randint(0, 255, size=(20, 3)).tolist(),
         )
 
+    def test_export_decoding_module_bs_3(self):
+        num_pre_nms_predictions = 1000
+        batch_size = 3
+        module = YoloNASPoseDecodingModule(num_pre_nms_predictions)
+
+        pred_bboxes_xyxy = torch.rand(batch_size, 8400, 4)
+        pred_bboxes_conf = torch.rand(batch_size, 8400, 1).sigmoid()
+        pred_pose_coords = torch.rand(batch_size, 8400, 20, 2)
+        pred_pose_scores = torch.rand(batch_size, 8400, 20).sigmoid()
+
+        inputs = (pred_bboxes_xyxy, pred_bboxes_conf, pred_pose_coords, pred_pose_scores)
+        _ = module([inputs])  # Check that normal forward() works
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            out_path = os.path.join(tmpdirname, "model.onnx")
+            torch.onnx.export(module, (inputs,), out_path)
+
     def test_export_model_on_small_size(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             for model_type in [
@@ -75,6 +93,26 @@ def test_export_model_on_small_size(self):
             assert export_result.input_image_shape == (64, 64)
             print(export_result.usage_instructions)
 
+    def test_export_model_with_batch_size_4(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            for model_type in [
+                Models.YOLO_NAS_POSE_S,
+            ]:
+                out_path = os.path.join(tmpdirname, model_type + ".onnx")
+                model: ExportablePoseEstimationModel = models.get(model_type, num_classes=17)
+                model.set_dataset_processing_params(**default_yolo_nas_pose_coco_processing_params())
+                export_result = model.export(
+                    out_path,
+                    batch_size=4,
+                    input_image_shape=(640, 640),
+                    num_pre_nms_predictions=2000,
+                    max_predictions_per_image=1000,
+                    output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
+                )
+                assert export_result.input_image_dtype == torch.uint8
+                assert export_result.input_image_shape == (640, 640)
+                print(export_result.usage_instructions)
+
     def test_the_most_common_export_use_case(self):
         """
         Test the most common export use case - export to ONNX with all default parameters
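
The root cause of the BS>1 failure is a broadcasting mismatch in the decoding module:
torch.topk is taken over pred_bboxes_conf of shape [batch, num_anchors, 1], so
topk_candidates.indices has shape [batch, nms_top_k, 1]. Offsets reshaped to
(batch, 1) only broadcast against that shape when batch == 1; for larger batches the
addition raises a size-mismatch error (or, if nms_top_k happened to equal the batch
size, would silently add the wrong offsets). The snippet below is a minimal standalone
sketch of the shape arithmetic, not library code; the tensor sizes mirror the new unit
test above:

    import torch

    batch_size, num_anchors, nms_top_k = 3, 8400, 1000
    pred_bboxes_conf = torch.rand(batch_size, num_anchors, 1).sigmoid()

    # topk over the anchor dimension keeps the trailing singleton dim,
    # so indices have shape [batch_size, nms_top_k, 1].
    indices = torch.topk(pred_bboxes_conf, dim=1, k=nms_top_k, largest=True, sorted=True).indices
    offsets = nms_top_k * torch.arange(batch_size)

    # Before the fix: shape [3, 1] cannot broadcast against [3, 1000, 1]
    # (right-aligned broadcasting pits 1000 against 3), so forward() and
    # ONNX export failed for any batch size > 1. With batch_size == 1 the
    # shape [1, 1] broadcasts fine, which is why BS=1 export always worked.
    try:
        flat_indices = indices + offsets.reshape(batch_size, 1)
    except RuntimeError as e:
        print("pre-fix reshape fails:", e)

    # After the fix: shape [3, 1, 1] broadcasts cleanly over [3, 1000, 1],
    # adding one flat offset per image in the batch.
    flat_indices = indices + offsets.reshape(batch_size, 1, 1)
    print(flat_indices.shape)  # torch.Size([3, 1000, 1])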