
[Bug] IndexError: max(): Expected reduction dim 2 to have non-zero size. #2534

Closed
3 tasks done
IECCLES4 opened this issue Nov 7, 2023 · 7 comments
Labels: awaiting response, mmdet, question, Stale

Comments

@IECCLES4

IECCLES4 commented Nov 7, 2023

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The bug has not been fixed in the latest version.

Describe the bug

When I try to deploy an SSD model trained with 1 class, I get the following error:
IndexError: max(): Expected reduction dim 2 to have non-zero size.
11/07 15:44:25 - mmengine - ERROR - /home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/core/pipeline_manager.py - pop_mp_output - 80 - mmdeploy.apis.pytorch2onnx.torch2onnx with Call id: 0 failed. exit.
This only happens with SSD; I have used other methods that do not give any errors. I have seen another bug report about this, but the answer was just a workaround, not an actual solution.

Reproduction

python tools/deploy.py configs/mmdet/detection/detection_tensorrt-fp16_static-320x320.py /home/dtl-admin/dev/railsight/mmdetection/configs/ssd/ssd300_fp16_coco_2.py /home/dtl-admin/dev/railsight/mmdetection/checkpoints/2023-08-14_TPE_Trained_Model-290b0e8e.pth /home/dtl-admin/dev/railsight/mmdeploy/data/image-148.png --work-dir mmdeploy_model/ssd_fp16 --device cuda --dump-info

Environment

11/07 15:54:38 - mmengine - INFO - 

11/07 15:54:38 - mmengine - INFO - **********Environmental information**********
11/07 15:54:40 - mmengine - INFO - sys.platform: linux
11/07 15:54:40 - mmengine - INFO - Python: 3.8.10 (default, May 26 2023, 14:05:08) [GCC 9.4.0]
11/07 15:54:40 - mmengine - INFO - CUDA available: True
11/07 15:54:40 - mmengine - INFO - numpy_random_seed: 2147483648
11/07 15:54:40 - mmengine - INFO - GPU 0: Orin
11/07 15:54:40 - mmengine - INFO - CUDA_HOME: /usr/local/cuda
11/07 15:54:40 - mmengine - INFO - NVCC: Cuda compilation tools, release 11.4, V11.4.315
11/07 15:54:40 - mmengine - INFO - GCC: aarch64-linux-gnu-gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
11/07 15:54:40 - mmengine - INFO - PyTorch: 1.11.0
11/07 15:54:40 - mmengine - INFO - PyTorch compiling details: PyTorch built with:
  - GCC 9.4
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: NO AVX
  - CUDA Runtime 11.4
  - NVCC architecture flags: -gencode;arch=compute_72,code=sm_72;-gencode;arch=compute_87,code=sm_87
  - CuDNN 8.6
    - Built with CuDNN 8.3.2
  - Build settings: BLAS_INFO=open, BUILD_TYPE=Release, CUDA_VERSION=11.4, CUDNN_VERSION=8.3.2, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, FORCE_FALLBACK_CUDA_MPI=1, LAPACK_INFO=open, TORCH_VERSION=1.11.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=ON, USE_NCCL=0, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

11/07 15:54:40 - mmengine - INFO - TorchVision: 0.11.1
11/07 15:54:40 - mmengine - INFO - OpenCV: 4.8.0
11/07 15:54:40 - mmengine - INFO - MMEngine: 0.8.5
11/07 15:54:40 - mmengine - INFO - MMCV: 2.0.0
11/07 15:54:40 - mmengine - INFO - MMCV Compiler: GCC 9.4
11/07 15:54:40 - mmengine - INFO - MMCV CUDA Compiler: 11.4
11/07 15:54:40 - mmengine - INFO - MMDeploy: 1.3.0+1132e82
11/07 15:54:40 - mmengine - INFO - 

11/07 15:54:40 - mmengine - INFO - **********Backend information**********
11/07 15:54:40 - mmengine - INFO - tensorrt:	8.5.2.2
11/07 15:54:40 - mmengine - INFO - tensorrt custom ops:	Available
11/07 15:54:40 - mmengine - INFO - ONNXRuntime:	None
11/07 15:54:40 - mmengine - INFO - pplnn:	None
11/07 15:54:40 - mmengine - INFO - ncnn:	None
11/07 15:54:40 - mmengine - INFO - snpe:	None
11/07 15:54:40 - mmengine - INFO - openvino:	None
11/07 15:54:40 - mmengine - INFO - torchscript:	1.11.0
11/07 15:54:40 - mmengine - INFO - torchscript custom ops:	NotAvailable
11/07 15:54:40 - mmengine - INFO - rknn-toolkit:	None
11/07 15:54:40 - mmengine - INFO - rknn-toolkit2:	None
11/07 15:54:40 - mmengine - INFO - ascend:	None
11/07 15:54:40 - mmengine - INFO - coreml:	None
11/07 15:54:40 - mmengine - INFO - tvm:	None
11/07 15:54:40 - mmengine - INFO - vacc:	None
11/07 15:54:40 - mmengine - INFO - 

11/07 15:54:40 - mmengine - INFO - **********Codebase information**********
11/07 15:54:40 - mmengine - INFO - mmdet:	3.0.0
11/07 15:54:40 - mmengine - INFO - mmseg:	None
11/07 15:54:40 - mmengine - INFO - mmpretrain:	None
11/07 15:54:40 - mmengine - INFO - mmocr:	None
11/07 15:54:40 - mmengine - INFO - mmagic:	None
11/07 15:54:40 - mmengine - INFO - mmdet3d:	None
11/07 15:54:40 - mmengine - INFO - mmpose:	None
11/07 15:54:40 - mmengine - INFO - mmrotate:	None
11/07 15:54:40 - mmengine - INFO - mmaction:	None
11/07 15:54:40 - mmengine - INFO - mmrazor:	None
11/07 15:54:40 - mmengine - INFO - mmyolo:	None

Error traceback

11/07 15:46:48 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
11/07 15:46:48 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
11/07 15:46:52 - mmengine - INFO - Start pipeline mmdeploy.apis.pytorch2onnx.torch2onnx in subprocess
11/07 15:46:54 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
11/07 15:46:54 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
Loads checkpoint by local backend from path: /home/dtl-admin/dev/railsight/mmdetection/checkpoints/2023-08-14_TPE_Trained_Model-290b0e8e.pth
11/07 15:46:56 - mmengine - WARNING - DeprecationWarning: get_onnx_config will be deprecated in the future. 
11/07 15:46:56 - mmengine - INFO - Export PyTorch model to ONNX: mmdeploy_model/ssd_fp16/end2end.onnx.
11/07 15:46:57 - mmengine - WARNING - Can not find torch.nn.functional.scaled_dot_product_attention, function rewrite will not be applied
11/07 15:46:57 - mmengine - WARNING - Can not find torch._C._jit_pass_onnx_autograd_function_process, function rewrite will not be applied
11/07 15:46:57 - mmengine - WARNING - Can not find torch._C._jit_pass_onnx_deduplicate_initializers, function rewrite will not be applied
11/07 15:46:57 - mmengine - WARNING - Can not find mmdet.models.utils.transformer.PatchMerging.forward, function rewrite will not be applied
/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:80: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  img_shape = [int(val) for val in img_shape]
/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:80: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  img_shape = [int(val) for val in img_shape]
/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/core/optimizers/function_marker.py:160: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  ys_shape = tuple(int(s) for s in ys.shape)
/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/nn/functional.py:780: UserWarning: Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.
  warnings.warn("Note that order of the arguments: ceil_mode and return_indices will change"
/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py:109: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
Process Process-2:
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
    ret = func(*args, **kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/pytorch2onnx.py", line 98, in torch2onnx
    export(
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 356, in _wrap
    return self.call_function(func_name_, *args, **kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 326, in call_function
    return self.call_function_local(func_name, *args, **kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 275, in call_function_local
    return pipe_caller(*args, **kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
    ret = func(*args, **kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/onnx/export.py", line 138, in export
    torch.onnx.export(
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 305, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/onnx/utils.py", line 118, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/onnx/utils.py", line 719, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/onnx/optimizer.py", line 27, in model_to_graph__custom_optimizer
    graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/onnx/utils.py", line 499, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/onnx/utils.py", line 440, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/onnx/utils.py", line 391, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/dtl-admin/evenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/onnx/export.py", line 123, in wrapper
    return forward(*arg, **kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py", line 85, in single_stage_detector__forward
    return __forward_impl(self, batch_inputs, data_samples=data_samples)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/core/optimizers/function_marker.py", line 266, in g
    rets = f(*args, **kwargs)
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py", line 23, in __forward_impl
    output = self.bbox_head.predict(x, data_samples, rescale=False)
  File "/home/dtl-admin/dev/railsight/mmdetection/mmdet/models/dense_heads/base_dense_head.py", line 197, in predict
    predictions = self.predict_by_feat(
  File "/home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py", line 145, in base_dense_head__predict_by_feat
    max_scores, _ = nms_pre_score[..., :-1].max(-1)
IndexError: max(): Expected reduction dim 2 to have non-zero size.
11/07 15:47:01 - mmengine - ERROR - /home/dtl-admin/dev/railsight/mmdeploy/mmdeploy/apis/core/pipeline_manager.py - pop_mp_output - 80 - `mmdeploy.apis.pytorch2onnx.torch2onnx` with Call id: 0 failed. exit.
@RunningLeon
Collaborator

Hi, is it because nms_pre_score has a zero-size dimension in its shape? Maybe you can change input_img when calling deploy.py to make sure this tensor has no zero-size dimension.
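A minimal sketch of the failure described above, using NumPy arrays as a stand-in for the torch tensors (an assumption for illustration: torch raises IndexError here, while NumPy raises ValueError). Reducing over a trailing axis of size 0 has no identity element, so `max` has nothing to return; "dim 2" in the error is the last axis of a 3-D `(batch, num_priors, num_score_columns)` tensor.

```python
import numpy as np


def max_over_last_axis(scores: np.ndarray) -> np.ndarray:
    """Reduce over the last (class) axis, like nms_pre_score.max(-1)."""
    return scores.max(axis=-1)


# Hypothetical shape (batch, num_priors, num_score_columns); values are arbitrary.
ok = np.random.rand(1, 4, 2)
print(max_over_last_axis(ok).shape)  # (1, 4)

# A zero-size last axis, which is what the error message reports for dim 2.
empty = np.zeros((1, 4, 0))
try:
    max_over_last_axis(empty)
except ValueError as err:
    # NumPy's analogue of torch's "Expected reduction dim 2 to have non-zero size"
    print("reduction failed:", err)
```

Note that the zero-size axis is the score (class) axis, not the image or prior axis, so the input image size alone would not normally change it.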

@RunningLeon RunningLeon self-assigned this Nov 13, 2023
@RunningLeon RunningLeon added awaiting response question Further information is requested mmdet labels Nov 13, 2023
@IECCLES4
Author

> hi, is it because of the nms_pre_score has zero dim in its shape? Maybe you can change input_img when calling deploy.py to make sure this tensor have no zero dim.

Hi, thank you for the suggestion. Just to double-check that I understood this correctly: do you mean the nms_pre in ssd300.py? I checked it, and its default value is 1000, so I don't think it needs changing, but I could be wrong.
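A sketch of why nms_pre is unlikely to matter here, assuming (as the failing line suggests) that the per-prior maximum over the class axis is computed first and the nms_pre top-k is then taken along the prior axis. `select_pre_topk` is a hypothetical helper written with NumPy as a stand-in for torch, not mmdeploy's code.

```python
import numpy as np


def select_pre_topk(scores: np.ndarray, nms_pre: int) -> np.ndarray:
    """Keep the nms_pre priors with the highest score (hypothetical helper).

    scores has shape (batch, num_priors, num_score_columns). The class axis
    (axis 2) is reduced first; top-k is then taken along the prior axis (axis 1).
    """
    max_scores = scores.max(axis=-1)                    # fails if axis 2 has size 0
    k = min(nms_pre, scores.shape[1])
    topk_inds = np.argsort(-max_scores, axis=1)[:, :k]  # (batch, k)
    return np.take_along_axis(scores, topk_inds[..., None], axis=1)


scores = np.random.rand(1, 2000, 1)  # hypothetical: 2000 priors, 1 score column
print(select_pre_topk(scores, nms_pre=1000).shape)  # (1, 1000, 1)

try:
    # With zero score columns, the failure happens before nms_pre is ever used.
    select_pre_topk(np.zeros((1, 2000, 0)), nms_pre=1000)
except ValueError as err:
    print("fails regardless of nms_pre:", err)
```

Under these assumptions, nms_pre only changes the size of axis 1, so a default of 1000 cannot cause or prevent a zero-size class axis.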


This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.

@github-actions github-actions bot added the Stale label Nov 21, 2023

This issue is closed because it has been stale for 5 days. Please open a new issue if you have similar issues or you have any new updates now.

@github-actions github-actions bot closed this as not planned Nov 27, 2023
@rohansaw

rohansaw commented May 19, 2024

I am having the same issue. Has anybody figured this out in the meantime?

@BRITISHBOI15

> I am having the same issue. Has anybody figured this out in the meantime?

Sorry for the late reply, but if I remember correctly, it was because the latest version of MMDeploy did not support a single class, or something along those lines.
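To make the "one class" observation concrete, here is a small sketch that counts how many score columns reach the failing max() call. It follows the two slicing steps quoted later in this thread (softmax branch: drop the last column after softmax, then drop the last column again before max; sigmoid branch: no slicing) and assumes a softmax head outputs num_classes + 1 columns with the background last. `columns_reaching_max` is a hypothetical helper, not part of mmdeploy.

```python
def columns_reaching_max(num_classes: int, use_sigmoid: bool) -> int:
    """Count score columns that nms_pre_score.max(-1) reduces over (sketch)."""
    if use_sigmoid:
        # Sigmoid heads predict only foreground classes and are not sliced.
        return num_classes
    columns = num_classes + 1  # assumed softmax output: foreground + background
    columns -= 1               # scores.softmax(-1)[:, :, :-1]
    columns -= 1               # nms_pre_score[..., :-1]
    return columns


for use_sigmoid in (True, False):
    for num_classes in (1, 2, 80):
        n = columns_reaching_max(num_classes, use_sigmoid)
        print(f"use_sigmoid={use_sigmoid}, num_classes={num_classes}: {n} column(s)")
```

Under these assumptions only the softmax case with one class reaches max() with zero columns. With more classes nothing crashes, but the last foreground class is left out of the ranking used for the pre-NMS top-k.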

@chuzhdontcode
Contributor

chuzhdontcode commented Jul 2, 2024

Hi @RunningLeon, I am facing an issue when exporting RetinaNet from mmdetection (https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/configs/_base_/models/retinanet_r50_fpn.py) for a single-class case.

The error message is attached below:

│ xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models │
│ /detectors/single_stage.py:85 in single_stage_detector__forward                                  │
│                                                                                                  │
│   82 │   # set the metainfo                                                                      │
│   83 │   data_samples = _set_metainfo(data_samples, img_shape)                                   │
│   84 │                                                                                           │
│ ❱ 85 │   return __forward_impl(self, batch_inputs, data_samples=data_samples)                    │
│   86                                                                                             │
│                                                                                                  │
│ xxx/lib/python3.10/site-packages/mmdeploy/core/optimizers/funct │
│ ion_marker.py:266 in g                                                                           │
│                                                                                                  │
│   263 │   │   │   args = mark_tensors(args, func, func_id, 'input', ctx, attrs,                  │
│   264 │   │   │   │   │   │   │   │   is_inspect, args_level)                                    │
│   265 │   │   │                                                                                  │
│ ❱ 266 │   │   │   rets = f(*args, **kwargs)                                                      │
│   267 │   │   │                                                                                  │
│   268 │   │   │   ctx = Context(output_names)                                                    │
│   269 │   │   │   func_ret = mark_tensors(rets, func, func_id, 'output', ctx, attrs,             │
│                                                                                                  │
│ xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models │
│ /detectors/single_stage.py:23 in __forward_impl                                                  │
│                                                                                                  │
│   20 │   """
│   21 │   x = self.extract_feat(batch_inputs)                                                     │
│   22 │                                                                                           │
│ ❱ 23 │   output = self.bbox_head.predict(x, data_samples, rescale=False)                         │
│   24 │   return output                                                                           │
│   25                                                                                             │
│   26                                                                                             │
│                                                                                                  │
│ xxx/lib/python3.10/site-packages/mmdet/models/dense_heads/base_ │
│ dense_head.py:197 in predict                                                                     │
│                                                                                                  │
│   194 │   │                                                                                      │
│   195 │   │   outs = self(x)                                                                     │
│   196 │   │                                                                                      │
│ ❱ 197 │   │   predictions = self.predict_by_feat(                                                │
│   198 │   │   │   *outs, batch_img_metas=batch_img_metas, rescale=rescale)                       │
│   199 │   │   return predictions                                                                 │
│   200                                                                                            │
│                                                                                                  │
│ xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models │
│ /dense_heads/base_dense_head.py:145 in base_dense_head__predict_by_feat                          │
│                                                                                                  │
│   142 │   │   │   if self.use_sigmoid_cls:                                                       │
│   143 │   │   │   │   max_scores, _ = nms_pre_score.max(-1)                                      │
│   144 │   │   │   else:                                                                          │
│ ❱ 145 │   │   │   │   max_scores, _ = nms_pre_score[..., :-1].max(-1)                            │
│   146 │   │   │   _, topk_inds = max_scores.topk(pre_topk)                                       │
│   147 │   │   │   bbox_pred, scores, score_factors = gather_topk(                                │
│   148 │   │   │   │   bbox_pred,                                                                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
IndexError: max(): Expected reduction dim 2 to have non-zero size.

I modified the classification loss function to use CrossEntropyLoss (type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), so the effective bbox_head config is as follows:

bbox_head=dict(
    type="RetinaHead",
    num_classes=80,
    in_channels=256,
    stacked_convs=4,
    feat_channels=256,
    anchor_generator=dict(
        type="AnchorGenerator",
        octave_base_scale=4,
        scales_per_octave=3,
        ratios=[0.5, 1.0, 2.0],
        strides=[8, 16, 32, 64, 128],
    ),
    bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]),
    loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
    loss_bbox=dict(type="L1Loss", loss_weight=1.0),
)
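A sketch of how, as I understand the AnchorHead constructor lines linked in this comment, the use_sigmoid flag in loss_cls decides how many classification channels the head predicts. `cls_out_channels` here is a hypothetical standalone function mirroring that check, and the FocalLoss dict is only an example of a sigmoid-based loss for comparison.

```python
def cls_out_channels(num_classes: int, loss_cls: dict) -> int:
    """Mirror the use_sigmoid check of the head constructor (sketch)."""
    use_sigmoid_cls = loss_cls.get("use_sigmoid", False)
    # Softmax classification adds one extra background channel.
    return num_classes if use_sigmoid_cls else num_classes + 1


focal = dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0)
ce = dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0)

print(cls_out_channels(1, focal))  # 1
print(cls_out_channels(1, ce))     # 2
```

With use_sigmoid=False the head therefore predicts a background column, which is what the softmax branch in the deploy-time rewrite expects to slice off.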

To export the model to ONNX, I called the export function from https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/apis/onnx/export.py. Based on my understanding, before torch.onnx.export is invoked, the model is patched with rewritten implementations; in this particular case, predict_by_feat() is replaced with base_dense_head__predict_by_feat() in mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py.
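The patching step described above can be pictured with a toy context manager that swaps a method only while export runs and restores it afterwards. This is a hypothetical illustration of the idea: `ToyHead`, `toy_predict_by_feat`, and `patched` are made up, and mmdeploy's actual rewriter machinery is more general than this.

```python
from contextlib import contextmanager


class ToyHead:
    """Hypothetical stand-in for a dense head."""

    def predict_by_feat(self, scores):
        return "original"


def toy_predict_by_feat(self, scores):
    """Hypothetical stand-in for a deploy-time rewrite."""
    return "rewritten"


@contextmanager
def patched(cls, name, replacement):
    """Temporarily replace a method on a class."""
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield
    finally:
        setattr(cls, name, original)  # always restore, even if export fails


head = ToyHead()
with patched(ToyHead, "predict_by_feat", toy_predict_by_feat):
    print(head.predict_by_feat(None))  # rewritten
print(head.predict_by_feat(None))      # original
```

This is why the traceback shows the mmdet predict() calling into mmdeploy's base_dense_head__predict_by_feat: the call site stays the same, but the method behind it is swapped during tracing.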

After reviewing the code of base_dense_head__predict_by_feat(), I noticed three places involving the use_sigmoid flag configured in CrossEntropyLoss:

  1. In the constructor of RetinaHead (inherited from AnchorHead): https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/dense_heads/anchor_head.py#L73-L78

  2. In base_dense_head__predict_by_feat(), the scores are sliced for the first time. I presume this is to exclude the background class (index num_classes):

    if self.use_sigmoid_cls:
        scores = scores.sigmoid()
    else:
        scores = scores.softmax(-1)[:, :, :-1]
    if with_score_factors:
        ...

  3. This is the confusing part: there is a second round of slicing when computing max_scores:

    # Get maximum scores for foreground classes.
    if self.use_sigmoid_cls:
        max_scores, _ = nms_pre_score.max(-1)
    else:
        max_scores, _ = nms_pre_score[..., :-1].max(-1)
    _, topk_inds = max_scores.topk(pre_topk)
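A minimal NumPy sketch of the softmax branch quoted above, with NumPy standing in for torch and logits assumed to have shape (batch, num_priors, num_classes + 1) with the background last. With num_classes = 1, the first slice leaves one column and the second leaves none, which matches the reported error; with more classes, the second slice hides the last foreground class from max_scores. The `slice_again=False` variant is only a hypothetical alternative for comparison, not a confirmed fix from the maintainers.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def foreground_max_scores(cls_logits: np.ndarray, slice_again: bool) -> np.ndarray:
    """Follow the quoted softmax branch (sketch).

    slice_again=True reproduces the second `[..., :-1]`; False skips it.
    """
    scores = softmax(cls_logits)[:, :, :-1]  # first slice: drop background column
    nms_pre_score = scores[..., :-1] if slice_again else scores
    return nms_pre_score.max(-1)


logits = np.random.randn(1, 6, 2)  # num_classes = 1, plus background
print(foreground_max_scores(logits, slice_again=False).shape)  # (1, 6)
try:
    foreground_max_scores(logits, slice_again=True)
except ValueError as err:
    print("second slice leaves zero columns:", err)
```

The pre-top-k ranking only uses max_scores, so for multi-class models the effect of the second slice is that priors whose best class is the last foreground class are ranked as if that class did not exist.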

Could you explain the reasoning behind this? It appears that the last object class is excluded when computing max_scores. Thank you!
