diff --git a/configs/mmseg/segmentation_rknn-int8_static-320x320.py b/configs/mmseg/segmentation_rknn-int8_static-320x320.py index 2bb9082342..73ab4173c5 100644 --- a/configs/mmseg/segmentation_rknn-int8_static-320x320.py +++ b/configs/mmseg/segmentation_rknn-int8_static-320x320.py @@ -2,6 +2,6 @@ onnx_config = dict(input_shape=[320, 320]) -codebase_config = dict(model_type='rknn') +codebase_config = dict(model_type='rknn', with_argmax=False) backend_config = dict(input_size_list=[[3, 320, 320]]) diff --git a/configs/mmseg/segmentation_static.py b/configs/mmseg/segmentation_static.py index 434a8fae98..416b781ae9 100644 --- a/configs/mmseg/segmentation_static.py +++ b/configs/mmseg/segmentation_static.py @@ -1,2 +1,2 @@ _base_ = ['../_base_/onnx_config.py'] -codebase_config = dict(type='mmseg', task='Segmentation') +codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True) diff --git a/csrc/mmdeploy/codebase/mmseg/segment.cpp b/csrc/mmdeploy/codebase/mmseg/segment.cpp index b1128886c2..c337d09873 100644 --- a/csrc/mmdeploy/codebase/mmseg/segment.cpp +++ b/csrc/mmdeploy/codebase/mmseg/segment.cpp @@ -18,6 +18,7 @@ class ResizeMask : public MMSegmentation { explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) { try { classes_ = cfg["params"]["num_classes"].get(); + with_argmax_ = cfg["params"].value("with_argmax", true); little_endian_ = IsLittleEndian(); } catch (const std::exception &e) { MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg); @@ -31,10 +32,19 @@ class ResizeMask : public MMSegmentation { auto mask = inference_result["output"].get(); MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(), mask.data_type()); - if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) { + if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) { MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape()); return Status(eNotSupported); } + if 
((mask.shape(1) != 1) && with_argmax_) { + MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`", + mask.shape()); + return Status(eNotSupported); + } + if (!with_argmax_) { + MMDEPLOY_ERROR("TODO: SDK will support probability featmap soon."); + return Status(eNotSupported); + } auto height = (int)mask.shape(2); auto width = (int)mask.shape(3); @@ -85,6 +95,7 @@ class ResizeMask : public MMSegmentation { protected: int classes_{}; + bool with_argmax_{true}; bool little_endian_; }; diff --git a/docs/en/04-supported-codebases/mmseg.md b/docs/en/04-supported-codebases/mmseg.md index 1bbff04cc9..e475501f1f 100644 --- a/docs/en/04-supported-codebases/mmseg.md +++ b/docs/en/04-supported-codebases/mmseg.md @@ -51,3 +51,5 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl - PSPNet, Fast-SCNN only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/97f9670c5a4a2a3b4cfb411bcc26db16b23745f7/mmseg/models/decode_heads/psp_head.py#L38) is not supported in most of backends dynamically. - For models only supporting static shape, you should use the deployment config file of static shape such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`. + +- For users who prefer that the deployed model generates a probability feature map instead of a label map, put `codebase_config = dict(with_argmax=False)` in the deploy config. 
diff --git a/docs/zh_cn/04-supported-codebases/mmseg.md b/docs/zh_cn/04-supported-codebases/mmseg.md index 66a5b28cb9..a627f2dcee 100644 --- a/docs/zh_cn/04-supported-codebases/mmseg.md +++ b/docs/zh_cn/04-supported-codebases/mmseg.md @@ -51,3 +51,5 @@ mmseg 是一个基于 PyTorch 的开源对象分割工具箱,也是 [OpenMMLab - PSPNet,Fast-SCNN 仅支持静态输入,因为多数推理框架的 [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/97f9670c5a4a2a3b4cfb411bcc26db16b23745f7/mmseg/models/decode_heads/psp_head.py#L38) 不支持动态输入。 - 对于仅支持静态形状的模型,应使用静态形状的部署配置文件,例如 `configs/mmseg/segmentation_tensorrt_static-1024x2048.py` + +- 若希望部署后的模型输出概率特征图而非标签图,在部署配置中设置 `codebase_config = dict(with_argmax=False)` 即可。 diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py index b47c1bf9df..5c5a994e6f 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py @@ -8,7 +8,7 @@ from torch.utils.data import Dataset from mmdeploy.codebase.base import BaseTask -from mmdeploy.utils import Task, get_input_shape +from mmdeploy.utils import Task, get_codebase_config, get_input_shape from .mmsegmentation import MMSEG_TASK @@ -286,6 +286,9 @@ def get_postprocess(self) -> Dict: postprocess = self.model_cfg.model.decode_head if isinstance(postprocess, list): postprocess = postprocess[-1] + with_argmax = get_codebase_config(self.deploy_cfg).get( + 'with_argmax', True) + postprocess['with_argmax'] = with_argmax return postprocess def get_model_name(self) -> str: diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py index de7c486b3b..72809360d9 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py @@ -109,6 +109,11 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ """ outputs = self.wrapper({self.input_name: imgs}) outputs = self.wrapper.output_to_list(outputs) + if 
get_codebase_config(self.deploy_cfg).get('with_argmax', + True) is False: + outputs = [ + output.argmax(dim=1, keepdim=True) for output in outputs + ] outputs = [out.detach().cpu().numpy() for out in outputs] return outputs diff --git a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py index 05e9ea96de..dab6b8fba1 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils.config_utils import get_codebase_config from mmdeploy.utils.constants import Backend @@ -25,6 +26,8 @@ def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs): """ seg_logit = self.encode_decode(img, img_meta) seg_logit = F.softmax(seg_logit, dim=1) + if get_codebase_config(ctx.cfg).get('with_argmax', True) is False: + return seg_logit seg_pred = seg_logit.argmax(dim=1, keepdim=True) return seg_pred