Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support probability output for segmentation #1379

Merged
merged 6 commits into from
Dec 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/mmseg/segmentation_rknn-int8_static-320x320.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

onnx_config = dict(input_shape=[320, 320])

codebase_config = dict(model_type='rknn')
codebase_config = dict(model_type='rknn', with_argmax=False)

backend_config = dict(input_size_list=[[3, 320, 320]])
2 changes: 1 addition & 1 deletion configs/mmseg/segmentation_static.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
_base_ = ['../_base_/onnx_config.py']
codebase_config = dict(type='mmseg', task='Segmentation')
codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)
13 changes: 12 additions & 1 deletion csrc/mmdeploy/codebase/mmseg/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ResizeMask : public MMSegmentation {
explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) {
try {
classes_ = cfg["params"]["num_classes"].get<int>();
with_argmax_ = cfg["params"].value("with_argmax", true);
little_endian_ = IsLittleEndian();
} catch (const std::exception &e) {
MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg);
Expand All @@ -31,10 +32,19 @@ class ResizeMask : public MMSegmentation {
auto mask = inference_result["output"].get<Tensor>();
MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(),
mask.shape(), mask.data_type());
if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) {
if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) {
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape());
return Status(eNotSupported);
}
if ((mask.shape(1) != 1) && with_argmax_) {
MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`",
mask.shape());
return Status(eNotSupported);
}
if (!with_argmax_) {
MMDEPLOY_ERROR("TODO: SDK will support probability featmap soon.");
return Status(eNotSupported);
}

auto height = (int)mask.shape(2);
auto width = (int)mask.shape(3);
Expand Down Expand Up @@ -85,6 +95,7 @@ class ResizeMask : public MMSegmentation {

protected:
int classes_{};
bool with_argmax_{true};
bool little_endian_;
};

Expand Down
2 changes: 2 additions & 0 deletions docs/en/04-supported-codebases/mmseg.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,5 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl
- <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/97f9670c5a4a2a3b4cfb411bcc26db16b23745f7/mmseg/models/decode_heads/psp_head.py#L38) is not supported in most of backends dynamically.

- For models only supporting static shape, you should use the deployment config file of static shape such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.

- For users prefer deployed models generate probability feature map, put `codebase_config = dict(with_argmax=False)` in deploy configs.
2 changes: 2 additions & 0 deletions docs/zh_cn/04-supported-codebases/mmseg.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,5 @@ mmseg 是一个基于 PyTorch 的开源对象分割工具箱,也是 [OpenMMLab
- <i id=“static_shape”>PSPNet,Fast-SCNN</i> 仅支持静态输入,因为多数推理框架的 [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/97f9670c5a4a2a3b4cfb411bcc26db16b23745f7/mmseg/models/decode_heads/psp_head.py#L38) 不支持动态输入。

- 对于仅支持静态形状的模型,应使用静态形状的部署配置文件,例如 `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`

- 对于喜欢部署模型生成概率特征图的用户,将 `codebase_config = dict(with_argmax=False)` 放在部署配置中就足够了。
5 changes: 4 additions & 1 deletion mmdeploy/codebase/mmseg/deploy/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.utils.data import Dataset

from mmdeploy.codebase.base import BaseTask
from mmdeploy.utils import Task, get_input_shape
from mmdeploy.utils import Task, get_codebase_config, get_input_shape
from .mmsegmentation import MMSEG_TASK


Expand Down Expand Up @@ -286,6 +286,9 @@ def get_postprocess(self) -> Dict:
postprocess = self.model_cfg.model.decode_head
if isinstance(postprocess, list):
postprocess = postprocess[-1]
with_argmax = get_codebase_config(self.deploy_cfg).get(
'with_argmax', True)
postprocess['with_argmax'] = with_argmax
return postprocess

def get_model_name(self) -> str:
Expand Down
5 changes: 5 additions & 0 deletions mmdeploy/codebase/mmseg/deploy/segmentation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
"""
outputs = self.wrapper({self.input_name: imgs})
outputs = self.wrapper.output_to_list(outputs)
if get_codebase_config(self.deploy_cfg).get('with_argmax',
True) is False:
outputs = [
output.argmax(dim=1, keepdim=True) for output in outputs
]
outputs = [out.detach().cpu().numpy() for out in outputs]
return outputs

Expand Down
3 changes: 3 additions & 0 deletions mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn.functional as F

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils.config_utils import get_codebase_config
from mmdeploy.utils.constants import Backend


Expand All @@ -25,6 +26,8 @@ def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs):
"""
seg_logit = self.encode_decode(img, img_meta)
seg_logit = F.softmax(seg_logit, dim=1)
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
return seg_logit
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
return seg_pred

Expand Down