diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 290c5c8f0e..8f8e28983f 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -1,16 +1,20 @@
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
## Motivation
+
Please describe the motivation of this PR and the goal you want to achieve through this PR.
## Modification
+
Please briefly describe what modification is made in this PR.
## BC-breaking (Optional)
-Does the modification introduce changes that break the back-compatibility of the downstream repos?
+
+Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
## Use cases (Optional)
+
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
## Checklist
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 12f7c87f4c..0b90cbf215 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -57,10 +57,14 @@ jobs:
strategy:
matrix:
python-version: [3.7]
- torch: [1.7.0]
+ torch: [1.7.0, 1.8.0, 1.9.0]
include:
- torch: 1.7.0
torchvision: 0.8.1
+ - torch: 1.8.0
+ torchvision: 0.9.0
+ - torch: 1.9.0
+ torchvision: 0.10.0
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
@@ -85,7 +89,7 @@ jobs:
strategy:
matrix:
python-version: [3.7]
- torch: [1.3.1, 1.4.0, 1.5.1, 1.6.0, 1.7.0]
+ torch: [1.3.1, 1.4.0, 1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
include:
- torch: 1.3.1
torchvision: 0.4.2
@@ -97,6 +101,10 @@ jobs:
torchvision: 0.7.0
- torch: 1.7.0
torchvision: 0.8.1
+ - torch: 1.8.0
+ torchvision: 0.9.0
+ - torch: 1.9.0
+ torchvision: 0.10.0
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
@@ -121,7 +129,7 @@ jobs:
coverage xml
coverage report -m
- build_cuda:
+ build_cu101:
runs-on: ubuntu-18.04
env:
CUDA: 10.1.105-1
@@ -132,7 +140,7 @@ jobs:
strategy:
matrix:
python-version: [3.7]
- torch: [1.3.1, 1.5.1+cu101, 1.6.0+cu101, 1.7.0+cu101]
+ torch: [1.3.1, 1.5.1+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
include:
- torch: 1.3.1
torchvision: 0.4.2
@@ -142,12 +150,14 @@ jobs:
torchvision: 0.7.0+cu101
- torch: 1.7.0+cu101
torchvision: 0.8.1+cu101
+ - torch: 1.8.0+cu101
+ torchvision: 0.9.0+cu101
- python-version: 3.6
- torch: 1.7.0+cu101
- torchvision: 0.8.1+cu101
+ torch: 1.8.0+cu101
+ torchvision: 0.9.0+cu101
- python-version: 3.8
- torch: 1.7.0+cu101
- torchvision: 0.8.1+cu101
+ torch: 1.8.0+cu101
+ torchvision: 0.9.0+cu101
steps:
- uses: actions/checkout@v2
@@ -199,11 +209,81 @@ jobs:
name: codecov-umbrella
fail_ci_if_error: false
+ build_cu102:
+ runs-on: ubuntu-18.04
+ env:
+ CUDA: 10.2.89-1
+ CUDA_SHORT: 10.2
+ UBUNTU_VERSION: ubuntu1804
+ FORCE_CUDA: 1
+ MMCV_CUDA_ARGS: -gencode=arch=compute_61,code=sm_61
+ strategy:
+ matrix:
+ python-version: [3.7]
+ torch: [1.9.0+cu102]
+ include:
+ - torch: 1.9.0+cu102
+ torchvision: 0.10.0+cu102
+ - python-version: 3.6
+ torch: 1.9.0+cu102
+ torchvision: 0.10.0+cu102
+ - python-version: 3.8
+ torch: 1.9.0+cu102
+ torchvision: 0.10.0+cu102
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install CUDA
+ run: |
+ export INSTALLER=cuda-repo-${UBUNTU_VERSION}_${CUDA}_amd64.deb
+ wget http://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/${INSTALLER}
+ sudo dpkg -i ${INSTALLER}
+ wget https://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/7fa2af80.pub
+ sudo apt-key add 7fa2af80.pub
+ sudo apt update -qq
+ sudo apt install -y cuda-${CUDA_SHORT/./-} cuda-cufft-dev-${CUDA_SHORT/./-}
+ sudo apt clean
+ export CUDA_HOME=/usr/local/cuda-${CUDA_SHORT}
+ export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${CUDA_HOME}/include:${LD_LIBRARY_PATH}
+ export PATH=${CUDA_HOME}/bin:${PATH}
+ sudo apt-get install -y ninja-build
+ - name: Install Pillow
+ run: pip install Pillow==6.2.2
+ if: ${{matrix.torchvision == '0.4.2'}}
+ - name: Install PyTorch
+ run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
+ - name: Install system dependencies
+ run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
+ - name: Build and install
+ run: rm -rf .eggs && pip install -e .
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests and generate coverage report
+ run: |
+ pip install -r requirements/test.txt
+ coverage run --branch --source=mmcv -m pytest tests/
+ coverage xml
+ coverage report -m
+ # Only upload coverage report for python3.7 && pytorch1.6
+ - name: Upload coverage to Codecov
+ if: ${{matrix.torch == '1.6.0+cu102' && matrix.python-version == '3.7'}}
+ uses: codecov/codecov-action@v1.0.14
+ with:
+ file: ./coverage.xml
+ flags: unittests
+ env_vars: OS,PYTHON
+ name: codecov-umbrella
+ fail_ci_if_error: false
+
build_macos:
runs-on: macos-latest
strategy:
matrix:
- torch: [1.3.1, 1.5.1, 1.6.0, 1.7.0]
+ torch: [1.3.1, 1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
include:
- torch: 1.3.1
torchvision: 0.4.2
@@ -213,6 +293,10 @@ jobs:
torchvision: 0.7.0
- torch: 1.7.0
torchvision: 0.8.1
+ - torch: 1.8.0
+ torchvision: 0.9.0
+ - torch: 1.9.0
+ torchvision: 0.10.0
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
diff --git a/.github/workflows/build_pat.yml b/.github/workflows/build_pat.yml
index ce72e78088..bc45ff2a2b 100644
--- a/.github/workflows/build_pat.yml
+++ b/.github/workflows/build_pat.yml
@@ -9,9 +9,9 @@ jobs:
build_parrots:
runs-on: ubuntu-18.04
container:
- image: ghcr.io/sunnyxiaohu/parrots-mmcv:1.2.1
+ image: ghcr.io/zhouzaida/parrots-mmcv:1.3.4
credentials:
- username: sunnyxiaohu
+ username: zhouzaida
password: ${{ secrets.CR_PAT }}
steps:
diff --git a/.gitignore b/.gitignore
index 43e4a4082f..b8e4f612f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -65,6 +65,7 @@ instance/
# Sphinx documentation
docs/_build/
+docs_zh_CN/_build/
# PyBuilder
target/
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 7a987d9b1b..f347c6c10e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,11 +1,11 @@
-# Contributing to OpenMMLab
+## Contributing to OpenMMLab
All kinds of contributions are welcome, including but not limited to the following.
- Fixes (typo, bugs)
- New features and components
-## Workflow
+### Workflow
1. fork and pull the latest OpenMMLab repository
2. checkout a new branch (do not use master branch for PRs)
@@ -14,9 +14,9 @@ All kinds of contributions are welcome, including but not limited to the followi
Note: If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
-## Code style
+### Code style
-### Python
+#### Python
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
@@ -64,6 +64,6 @@ After this on every commit check code linters and formatter will be enforced.
>Before you create a PR, make sure that your code lints and is formatted by yapf.
-### C++ and CUDA
+#### C++ and CUDA
We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
diff --git a/README.md b/README.md
index a8045e9513..1bfccb32b0 100644
--- a/README.md
+++ b/README.md
@@ -170,14 +170,27 @@ pip install mmcv
c. Install full version with custom operators for onnxruntime
-- Check [here](docs/onnxruntime_op.md) for detailed instruction.
+- Check [here](docs/deployment/onnxruntime_op.md) for detailed instruction.
-If you would like to build MMCV from source, please refer to the [guide](https://mmcv.readthedocs.io/en/latest/build.html).
+If you would like to build MMCV from source, please refer to the [guide](https://mmcv.readthedocs.io/en/latest/get_started/build.html).
## FAQ
If you face some installation issues, CUDA related issues or RuntimeErrors,
-you may first refer to this [Trouble Shooting Page](https://mmcv.readthedocs.io/en/latest/trouble_shooting.html).
+you may first refer to this [Frequently Asked Questions](https://mmcv.readthedocs.io/en/latest/faq.html).
+
+## Citation
+
+If you find this project useful in your research, please consider cite:
+
+```latex
+@misc{mmcv,
+ title={{MMCV: OpenMMLab} Computer Vision Foundation},
+ author={MMCV Contributors},
+ howpublished = {\url{https://github.com/open-mmlab/mmcv}},
+ year={2018}
+}
+```
## Contributing
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f12f5e449a..376901147a 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -167,13 +167,13 @@ pip install mmcv
c. 安装完整版并且编译 onnxruntime 的自定义算子
-- 详细的指南请查看 [这里](docs/onnxruntime_op.md)。
+- 详细的指南请查看 [这里](docs/deployment/onnxruntime_op.md)。
-如果想从源码编译 MMCV,请参考[该文档](https://mmcv.readthedocs.io/en/latest/build.html)。
+如果想从源码编译 MMCV,请参考[该文档](https://mmcv.readthedocs.io/en/latest/get_started/build.html)。
## FAQ
-如果你遇到了安装问题,CUDA 相关的问题或者 RuntimeErrors,可以首先参考[问题解决页面](https://mmcv.readthedocs.io/en/latest/trouble_shooting.html) 看是否已经有解决方案。
+如果你遇到了安装问题,CUDA 相关的问题或者 RuntimeErrors,可以首先参考[问题解决页面](https://mmcv.readthedocs.io/en/latest/faq.html) 看是否已经有解决方案。
## 贡献指南
diff --git a/docs/api.rst b/docs/api.rst
index 36eed8269a..daa3e65263 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -1,4 +1,4 @@
-API Documentation
+API Reference
=================
diff --git a/docs/community.rst b/docs/community.rst
new file mode 100644
index 0000000000..33a24f671d
--- /dev/null
+++ b/docs/community.rst
@@ -0,0 +1,7 @@
+Community
+===========
+
+.. toctree::
+ :maxdepth: 2
+
+ community/contributing.md
diff --git a/docs/community/contributing.md b/docs/community/contributing.md
new file mode 120000
index 0000000000..f939e75f21
--- /dev/null
+++ b/docs/community/contributing.md
@@ -0,0 +1 @@
+../../CONTRIBUTING.md
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index 6c389f551f..2307a980db 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -27,8 +27,8 @@
# -- Project information -----------------------------------------------------
project = 'mmcv'
-copyright = '2018-2019, Kai Chen'
-author = 'Kai Chen'
+copyright = '2018-2021, OpenMMLab'
+author = 'MMCV Authors'
# The short X.Y version
version = __version__
@@ -54,9 +54,7 @@
'sphinx_markdown_tables'
] # yapf: disable
-autodoc_mock_imports = [
- 'cv2', 'mmcv._ext', 'mmcv.utils.ext_loader', 'torchvision'
-]
+autodoc_mock_imports = ['mmcv._ext', 'mmcv.utils.ext_loader', 'torchvision']
autosectionlabel_prefix_document = True
# Add any paths that contain templates here, relative to this directory.
diff --git a/docs/deployment.rst b/docs/deployment.rst
index 68f81f9520..bfbf776ac0 100644
--- a/docs/deployment.rst
+++ b/docs/deployment.rst
@@ -1,11 +1,11 @@
Deployment
-========
+================
.. toctree::
:maxdepth: 2
- onnx.md
- onnxruntime_op.md
- onnxruntime_custom_ops.md
- tensorrt_plugin.md
- tensorrt_custom_ops.md
+ deployment/onnx.md
+ deployment/onnxruntime_op.md
+ deployment/onnxruntime_custom_ops.md
+ deployment/tensorrt_plugin.md
+ deployment/tensorrt_custom_ops.md
diff --git a/docs/onnx.md b/docs/deployment/onnx.md
similarity index 83%
rename from docs/onnx.md
rename to docs/deployment/onnx.md
index c561622379..90c5540071 100644
--- a/docs/onnx.md
+++ b/docs/deployment/onnx.md
@@ -1,4 +1,4 @@
-# Introduction of `onnx` module in MMCV (Experimental)
+# Introduction of onnx module in MMCV (Experimental)
## register_extra_symbolics
diff --git a/docs/onnxruntime_custom_ops.md b/docs/deployment/onnxruntime_custom_ops.md
similarity index 100%
rename from docs/onnxruntime_custom_ops.md
rename to docs/deployment/onnxruntime_custom_ops.md
diff --git a/docs/onnxruntime_op.md b/docs/deployment/onnxruntime_op.md
similarity index 95%
rename from docs/onnxruntime_op.md
rename to docs/deployment/onnxruntime_op.md
index e43ce70fc6..e8956fd7f5 100644
--- a/docs/onnxruntime_op.md
+++ b/docs/deployment/onnxruntime_op.md
@@ -20,10 +20,10 @@
| [SoftNMS](onnxruntime_custom_ops.md#softnms) | Y | N | 1.2.3 |
| [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
| [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
-| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
-| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
-| [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
-| [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |
+| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | 1.3.1 |
+| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | 1.3.4 |
+| [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
+| [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |
## How to build custom operators for ONNX Runtime
diff --git a/docs/tensorrt_custom_ops.md b/docs/deployment/tensorrt_custom_ops.md
similarity index 65%
rename from docs/tensorrt_custom_ops.md
rename to docs/deployment/tensorrt_custom_ops.md
index da696f03e9..1ef48ece06 100644
--- a/docs/tensorrt_custom_ops.md
+++ b/docs/deployment/tensorrt_custom_ops.md
@@ -33,6 +33,30 @@
- [Inputs](#inputs-4)
- [Outputs](#outputs-4)
- [Type Constraints](#type-constraints-4)
+ - [cummax](#cummax)
+ - [Description](#description-5)
+ - [Parameters](#parameters-5)
+ - [Inputs](#inputs-5)
+ - [Outputs](#outputs-5)
+ - [Type Constraints](#type-constraints-5)
+ - [cummin](#cummin)
+ - [Description](#description-6)
+ - [Parameters](#parameters-6)
+ - [Inputs](#inputs-6)
+ - [Outputs](#outputs-6)
+ - [Type Constraints](#type-constraints-6)
+ - [MMCVInstanceNormalization](#mmcvinstancenormalization)
+ - [Description](#description-7)
+ - [Parameters](#parameters-7)
+ - [Inputs](#inputs-7)
+ - [Outputs](#outputs-7)
+ - [Type Constraints](#type-constraints-7)
+ - [MMCVModulatedDeformConv2d](#mmcvmodulateddeformconv2d)
+ - [Description](#description-8)
+ - [Parameters](#parameters-8)
+ - [Inputs](#inputs-8)
+ - [Outputs](#outputs-8)
+ - [Type Constraints](#type-constraints-8)
@@ -227,3 +251,145 @@ Perform sample from `input` with pixel locations from `grid`.
### Type Constraints
- T:tensor(float32, Linear)
+
+## cummax
+
+### Description
+
+Returns a namedtuple (`values`, `indices`) where `values` is the cumulative maximum of elements of `input` in the dimension `dim`. And `indices` is the index location of each maximum value found in the dimension `dim`.
+
+### Parameters
+
+| Type | Parameter | Description |
+| ----- | --------- | --------------------------------------- |
+| `int` | `dim` | The dimension to do the operation over. |
+
+### Inputs
+
+
+- inputs[0]: T
+- The input tensor.
+
+
+### Outputs
+
+
+- outputs[0]: T
+- Output values.
+- outputs[1]: (int32, Linear)
+- Output indices.
+
+
+### Type Constraints
+
+- T:tensor(float32, Linear)
+
+## cummin
+
+### Description
+
+Returns a namedtuple (`values`, `indices`) where `values` is the cumulative minimum of elements of `input` in the dimension `dim`. And `indices` is the index location of each minimum value found in the dimension `dim`.
+
+### Parameters
+
+| Type | Parameter | Description |
+| ----- | --------- | --------------------------------------- |
+| `int` | `dim` | The dimension to do the operation over. |
+
+### Inputs
+
+
+- inputs[0]: T
+- The input tensor.
+
+
+### Outputs
+
+
+- outputs[0]: T
+- Output values.
+- outputs[1]: (int32, Linear)
+- Output indices.
+
+
+### Type Constraints
+
+- T:tensor(float32, Linear)
+
+## MMCVInstanceNormalization
+
+### Description
+
+Carries out instance normalization as described in the paper https://arxiv.org/abs/1607.08022.
+
+y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance are computed per instance per channel.
+
+### Parameters
+
+| Type | Parameter | Description |
+| ------- | --------- | -------------------------------------------------------------------- |
+| `float` | `epsilon` | The epsilon value to use to avoid division by zero. Default is 1e-05 |
+
+### Inputs
+
+
+- input: T
+- Input data tensor from the previous operator; dimensions for image case are (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and the width of the data. For non image case, the dimensions are in the form of (N x C x D1 x D2 ... Dn), where N is the batch size.
+- scale: T
+- The input 1-dimensional scale tensor of size C.
+- B: T
+- The input 1-dimensional bias tensor of size C.
+
+
+### Outputs
+
+
+- output: T
+- The output tensor of the same shape as input.
+
+
+### Type Constraints
+
+- T:tensor(float32, Linear)
+
+## MMCVModulatedDeformConv2d
+
+### Description
+
+Perform Modulated Deformable Convolution on input feature, read [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/abs/1811.11168?from=timeline) for detail.
+
+### Parameters
+
+| Type | Parameter | Description |
+| -------------- | ------------------ | ------------------------------------------------------------------------------------- |
+| `list of ints` | `stride` | The stride of the convolving kernel. (sH, sW) |
+| `list of ints` | `padding` | Paddings on both sides of the input. (padH, padW) |
+| `list of ints` | `dilation` | The spacing between kernel elements. (dH, dW) |
+| `int` | `deformable_group` | Groups of deformable offset. |
+| `int` | `group` | Split input into groups. `input_channel` should be divisible by the number of groups. |
+
+### Inputs
+
+
+- inputs[0]: T
+- Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the number of channels, inH and inW are the height and width of the data.
+- inputs[1]: T
+- Input offset; 4-D tensor of shape (N, deformable_group* 2* kH* kW, outH, outW), where kH and kW is the height and width of weight, outH and outW is the height and width of offset and output.
+- inputs[2]: T
+- Input mask; 4-D tensor of shape (N, deformable_group* kH* kW, outH, outW), where kH and kW is the height and width of weight, outH and outW is the height and width of offset and output.
+- inputs[3]: T
+- Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).
+- inputs[4]: T, optional
+- Input weight; 1-D tensor of shape (output_channel).
+
+
+### Outputs
+
+
+- outputs[0]: T
+- Output feature; 4-D tensor of shape (N, output_channel, outH, outW).
+
+
+### Type Constraints
+
+- T:tensor(float32, Linear)
diff --git a/docs/tensorrt_plugin.md b/docs/deployment/tensorrt_plugin.md
similarity index 78%
rename from docs/tensorrt_plugin.md
rename to docs/deployment/tensorrt_plugin.md
index 5ed62d1ba3..325c79762e 100644
--- a/docs/tensorrt_plugin.md
+++ b/docs/deployment/tensorrt_plugin.md
@@ -24,13 +24,17 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
## List of TensorRT plugins supported in MMCV
-| ONNX Operator | TensorRT Plugin | MMCV Releases |
-| :---------------: | :-------------------------------------------------------------: | :-----------: |
-| MMCVRoiAlign | [MMCVRoiAlign](./tensorrt_custom_ops.md#mmcvroialign) | 1.2.6 |
-| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 |
-| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
-| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
-| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | master |
+| ONNX Operator | TensorRT Plugin | MMCV Releases |
+| :-----------------------: | :-----------------------------------------------------------------------------: | :-----------: |
+| MMCVRoiAlign | [MMCVRoiAlign](./tensorrt_custom_ops.md#mmcvroialign) | 1.2.6 |
+| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 |
+| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
+| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
+| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
+| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | 1.3.5 |
+| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | 1.3.5 |
+| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | 1.3.5 |
+| MMCVModulatedDeformConv2d | [MMCVModulatedDeformConv2d](./tensorrt_custom_ops.md#mmcvmodulateddeformconv2d) | master |
Notes
@@ -86,7 +90,7 @@ Here is an example.
import torch
import onnx
-from mmcv.tensorrt import (TRTWraper, onnx2trt, save_trt_engine,
+from mmcv.tensorrt import (TRTWrapper, onnx2trt, save_trt_engine,
is_tensorrt_plugin_loaded)
assert is_tensorrt_plugin_loaded(), 'Requires to complie TensorRT plugins in mmcv'
@@ -115,7 +119,7 @@ trt_engine = onnx2trt(
save_trt_engine(trt_engine, trt_file)
# Run inference with TensorRT
-trt_model = TRTWraper(trt_file, ['input'], ['output'])
+trt_model = TRTWrapper(trt_file, ['input'], ['output'])
with torch.no_grad():
trt_outputs = trt_model({'input': inputs})
@@ -159,7 +163,7 @@ Below are the main steps:
### Reminders
-- Some of the [custom ops](https://mmcv.readthedocs.io/en/latest/ops.html) in `mmcv` have their cuda implementations, which could be refered.
+- Some of the [custom ops](https://mmcv.readthedocs.io/en/latest/ops.html) in `mmcv` have their cuda implementations, which could be referred.
## Known Issues
diff --git a/docs/trouble_shooting.md b/docs/faq.md
similarity index 98%
rename from docs/trouble_shooting.md
rename to docs/faq.md
index fb0976d072..ab0dd135f9 100644
--- a/docs/trouble_shooting.md
+++ b/docs/faq.md
@@ -1,4 +1,4 @@
-## Trouble Shooting
+## Frequently Asked Questions
We list some common troubles faced by many users and their corresponding solutions here.
Feel free to enrich the list if you find any frequent issues and have ways to help others to solve them.
diff --git a/docs/get_started.rst b/docs/get_started.rst
new file mode 100644
index 0000000000..e8366a887a
--- /dev/null
+++ b/docs/get_started.rst
@@ -0,0 +1,9 @@
+Get started
+===================
+
+.. toctree::
+ :maxdepth: 2
+
+ get_started/introduction.md
+ get_started/installation.md
+ get_started/build.md
diff --git a/docs/build.md b/docs/get_started/build.md
similarity index 100%
rename from docs/build.md
rename to docs/get_started/build.md
diff --git a/docs/get_started/installation.md b/docs/get_started/installation.md
new file mode 100644
index 0000000000..115270eda7
--- /dev/null
+++ b/docs/get_started/installation.md
@@ -0,0 +1,137 @@
+## Installation
+
+There are two versions of MMCV:
+
+- **mmcv-full**: comprehensive, with full features and various CUDA ops out of box. It takes longer time to build.
+- **mmcv**: lite, without CUDA ops but all other features, similar to mmcv<1.0.0. It is useful when you do not need those CUDA ops.
+
+**Note**: Do not install both versions in the same environment, otherwise you may encounter errors like `ModuleNotFound`. You need to uninstall one before installing the other. `Installing the full verion is highly recommended if CUDA is avaliable`.
+
+a. Install the full version.
+
+Before installing mmcv-full, make sure that PyTorch has been successfully installed following the [official guide](https://pytorch.org/).
+
+We provide pre-built mmcv packages (recommended) with different PyTorch and CUDA versions to simplify the building.
+
+i. Install the latest version.
+
+The rule for installing the latest ``mmcv-full`` is as follows:
+
+```shell
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
+```
+
+Please replace ``{cu_version}`` and ``{torch_version}`` in the url to your desired one. For example,
+to install the latest ``mmcv-full`` with ``CUDA 11`` and ``PyTorch 1.7.0``, use the following command:
+
+```shell
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
+```
+
+For more details, please refer the the following tables and delete ``=={mmcv_version}``.
+
+ii. Install a specified version.
+
+The rule for installing a specified ``mmcv-full`` is as follows:
+
+```shell
+pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
+```
+
+First of all, please refer to the Releases and replace ``{mmcv_version}`` a specified one. e.g. ``1.2.2``.
+Then replace ``{cu_version}`` and ``{torch_version}`` in the url to your desired versions. For example,
+to install ``mmcv-full==1.2.2`` with ``CUDA 11`` and ``PyTorch 1.7.0``, use the following command:
+
+```shell
+pip install mmcv-full==1.2.2 -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
+```
+
+For more details, please refer the the following tables.
+
+
+
+
+ CUDA |
+ torch 1.8 |
+ torch 1.7 |
+ torch 1.6 |
+ torch 1.5 |
+ torch 1.4 |
+ torch 1.3 |
+
+
+ 11.1 |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html |
+ |
+ |
+ |
+ |
+ |
+
+
+ 11.0 |
+ |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html |
+ |
+ |
+ |
+ |
+
+
+ 10.2 |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.7.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.6.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.5.0/index.html |
+ |
+ |
+
+
+ 10.1 |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.5.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.4.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.3.0/index.html |
+
+
+ 9.2 |
+ |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu92/torch1.7.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu92/torch1.6.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu92/torch1.5.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu92/torch1.4.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu92/torch1.3.0/index.html |
+
+
+ cpu |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.8.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.7.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.6.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.5.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.4.0/index.html |
+ install pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.3.0/index.html |
+
+
+
+
+Another way is to compile locally by running
+
+```python
+pip install mmcv-full
+```
+
+Note that the local compiling may take up to 10 mins.
+
+b. Install the lite version.
+
+```python
+pip install mmcv
+```
+
+c. Install full version with custom operators for onnxruntime
+
+- Check [here](docs/onnxruntime_op.md) for detailed instruction.
+
+If you would like to build MMCV from source, please refer to the [guide](build.md).
diff --git a/docs/get_started/introduction.md b/docs/get_started/introduction.md
new file mode 100644
index 0000000000..2a0f1564f8
--- /dev/null
+++ b/docs/get_started/introduction.md
@@ -0,0 +1,33 @@
+## Introduction
+
+
+
+
+
+[![PyPI](https://img.shields.io/pypi/v/mmcv)](https://pypi.org/project/mmcv) [![badge](https://github.com/open-mmlab/mmcv/workflows/build/badge.svg)](https://github.com/open-mmlab/mmcv/actions) [![codecov](https://codecov.io/gh/open-mmlab/mmcv/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmcv) [![license](https://img.shields.io/github/license/open-mmlab/mmcv.svg)](https://github.com/open-mmlab/mmcv/blob/master/LICENSE)
+
+MMCV is a foundational library for computer vision research and supports many
+research projects as below:
+
+- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
+- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
+- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
+- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
+- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
+- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
+- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
+- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition and understanding toolbox.
+- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
+
+It provides the following functionalities.
+
+- Universal IO APIs
+- Image/Video processing
+- Image and annotation visualization
+- Useful utilities (progress bar, timer, ...)
+- PyTorch runner with hooking mechanism
+- Various CNN architectures
+- High-quality implementation of common CUDA ops
+
+Note: MMCV requires Python 3.6+.
diff --git a/docs/index.rst b/docs/index.rst
index 444ba1f2ca..64e796f9b1 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,26 +1,17 @@
+Welcome to MMCV's documentation!
+================================
-.. mdinclude:: readme.md
-
-Contents
-========
+You can switch between Chinese and English documents in the lower-left corner of the layout.
.. toctree::
:maxdepth: 2
- io.md
- image.md
- video.md
- visualization.md
- utils.md
- runner.md
- registry.md
- cnn.md
- ops.md
- build.md
+ get_started.rst
deployment.rst
- trouble_shooting.md
+ understand_mmcv.rst
api.rst
-
+ faq.md
+ community.rst
Indices and tables
diff --git a/docs/readme.md b/docs/readme.md
deleted file mode 120000
index 94389aee61..0000000000
--- a/docs/readme.md
+++ /dev/null
@@ -1 +0,0 @@
-../README.md
diff --git a/docs/runner.md b/docs/runner.md
deleted file mode 100644
index 95dbe31637..0000000000
--- a/docs/runner.md
+++ /dev/null
@@ -1,6 +0,0 @@
-## Runner
-
-The runner module aims to help users to start training with less code, while stays
-flexible and configurable.
-
-Documentation and examples are still on going.
diff --git a/docs/understand_mmcv.rst b/docs/understand_mmcv.rst
new file mode 100644
index 0000000000..ef26d386f5
--- /dev/null
+++ b/docs/understand_mmcv.rst
@@ -0,0 +1,15 @@
+Understand MMCV
+=================
+
+.. toctree::
+ :maxdepth: 2
+
+ understand_mmcv/config.md
+ understand_mmcv/registry.md
+ understand_mmcv/runner.md
+ understand_mmcv/io.md
+ understand_mmcv/data_process.md
+ understand_mmcv/visualization.md
+ understand_mmcv/cnn.md
+ understand_mmcv/ops.md
+ understand_mmcv/utils.md
diff --git a/docs/cnn.md b/docs/understand_mmcv/cnn.md
similarity index 98%
rename from docs/cnn.md
rename to docs/understand_mmcv/cnn.md
index 41fddc8179..8b7d485ae5 100644
--- a/docs/cnn.md
+++ b/docs/understand_mmcv/cnn.md
@@ -370,9 +370,9 @@ Let us introduce the usage of `initialize` in detail.
`BaseModule` is inherited from `torch.nn.Module`, and the only different between them is that `BaseModule` implements `init_weight`.
- `Sequential` is inhertied from `BaseModule` and `torch.nn.Sequential`.
+ `Sequential` is inherited from `BaseModule` and `torch.nn.Sequential`.
- `ModuleList` is inhertied from `BaseModule` and `torch.nn.ModuleList`.
+ `ModuleList` is inherited from `BaseModule` and `torch.nn.ModuleList`.
`````python
import torch.nn as nn
@@ -534,5 +534,5 @@ The following types are supported for `filename` argument of `mmcv.load_checkpoi
- filepath: The filepath of the checkpoint.
- `http://xxx` and `https://xxx`: The link to download the checkpoint. The `SHA256` postfix should be contained in the filename.
-- `torchvison://xxx`: The model links in `torchvision.models`.Please refer to [torchvision](https://pytorch.org/docs/stable/torchvision/models.html) for details.
+- `torchvision://xxx`: The model links in `torchvision.models`.Please refer to [torchvision](https://pytorch.org/docs/stable/torchvision/models.html) for details.
- `open-mmlab://xxx`: The model links or filepath provided in default and additional json files.
diff --git a/docs/utils.md b/docs/understand_mmcv/config.md
similarity index 63%
rename from docs/utils.md
rename to docs/understand_mmcv/config.md
index bcc71bfdff..2d0447f2f0 100644
--- a/docs/utils.md
+++ b/docs/understand_mmcv/config.md
@@ -1,6 +1,4 @@
-## Utils
-
-### Config
+## Config
`Config` class is used for manipulating config and config files. It supports
loading configs from multiple file formats including **python**, **json** and **yaml**.
@@ -69,7 +67,7 @@ a = 1
b = dict(b1=[0, 1, 2], b2=None)
```
-#### Inherit from base config without overlaped keys
+### Inherit from base config without overlapped keys
`config_b.py`
@@ -90,7 +88,7 @@ d = 'string'
New fields in `config_b.py` are combined with old fields in `config_a.py`
-#### Inherit from base config with overlaped keys
+### Inherit from base config with overlapped keys
`config_c.py`
@@ -110,7 +108,7 @@ c = (1, 2)
`b.b2=None` in `config_a` is replaced with `b.b2=1` in `config_c.py`.
-#### Inherit from base config with ignored fields
+### Inherit from base config with ignored fields
`config_d.py`
@@ -130,7 +128,7 @@ c = (1, 2)
You may also set `_delete_=True` to ignore some fields in base configs. All old keys `b1, b2, b3` in `b` are replaced with new keys `b2, b3`.
-#### Inherit from multiple base configs (the base configs should not contain the same keys)
+### Inherit from multiple base configs (the base configs should not contain the same keys)
`config_e.py`
@@ -154,74 +152,28 @@ _base_ = ['./config_a.py', './config_e.py']
... d='string')
```
-### ProgressBar
-
-If you want to apply a method to a list of items and track the progress, `track_progress`
-is a good choice. It will display a progress bar to tell the progress and ETA.
-
-```python
-import mmcv
-
-def func(item):
- # do something
- pass
+### Reference variables from base
-tasks = [item_1, item_2, ..., item_n]
+You can reference variables defined in base using the following grammar.
-mmcv.track_progress(func, tasks)
-```
-
-The output is like the following.
-![progress](_static/progress.gif)
-
-There is another method `track_parallel_progress`, which wraps multiprocessing and
-progress visualization.
+`base.py`
```python
-mmcv.track_parallel_progress(func, tasks, 8) # 8 workers
+item1 = 'a'
+item2 = dict(item3 = 'b')
```
-![progress](_static/parallel_progress.gif)
-
-If you want to iterate or enumerate a list of items and track the progress, `track_iter_progress`
-is a good choice. It will display a progress bar to tell the progress and ETA.
+`config_g.py`
```python
-import mmcv
-
-tasks = [item_1, item_2, ..., item_n]
-
-for task in mmcv.track_iter_progress(tasks):
- # do something like print
- print(task)
-
-for i, task in enumerate(mmcv.track_iter_progress(tasks)):
- # do something like print
- print(i)
- print(task)
+_base_ = ['./base.py']
+item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }})
```
-### Timer
-
-It is convinient to compute the runtime of a code block with `Timer`.
-
-```python
-import time
-
-with mmcv.Timer():
- # simulate some code block
- time.sleep(1)
-```
-
-or try with `since_start()` and `since_last_check()`. This former can
-return the runtime since the timer starts and the latter will return the time
-since the last time checked.
-
```python
-timer = mmcv.Timer()
-# code block 1 here
-print(timer.since_start())
-# code block 2 here
-print(timer.since_last_check())
-print(timer.since_start())
+>>> cfg = Config.fromfile('./config_g.py')
+>>> print(cfg.pretty_text)
+item1 = 'a'
+item2 = dict(item3='b')
+item = dict(a='a', b='b')
```
diff --git a/docs/image.md b/docs/understand_mmcv/data_process.md
similarity index 55%
rename from docs/image.md
rename to docs/understand_mmcv/data_process.md
index c6e9bbef45..79e9281b6c 100644
--- a/docs/image.md
+++ b/docs/understand_mmcv/data_process.md
@@ -1,8 +1,10 @@
-## Image
+## Data Process
+
+### Image
This module provides some image processing methods, which requires `opencv` to be installed.
-### Read/Write/Show
+#### Read/Write/Show
To read or write images files, use `imread` or `imwrite`.
@@ -11,7 +13,7 @@ import mmcv
img = mmcv.imread('test.jpg')
img = mmcv.imread('test.jpg', flag='grayscale')
-img_ = mmcv.imread(img) # nothing will happen, img_ = img
+img_ = mmcv.imread(img) # nothing will happen, img_ = img
mmcv.imwrite(img, 'out.jpg')
```
@@ -34,7 +36,7 @@ for i in range(10):
mmcv.imshow(img, win_name='test image', wait_time=200)
```
-### Color space conversion
+#### Color space conversion
Supported conversion methods:
@@ -52,7 +54,7 @@ img2 = mmcv.rgb2gray(img1)
img3 = mmcv.bgr2hsv(img)
```
-### Resize
+#### Resize
There are three resize methods. All `imresize_*` methods have an argument `return_scale`,
if this argument is `False`, then the return value is merely the resized image, otherwise
@@ -73,7 +75,7 @@ mmcv.imrescale(img, 0.5)
mmcv.imrescale(img, (1000, 800))
```
-### Rotate
+#### Rotate
To rotate an image by some angle, use `imrotate`. The center can be specified,
which is the center of original image by default. There are two modes of rotating,
@@ -100,7 +102,7 @@ img_ = mmcv.imrotate(img, 30, center=(100, 100))
img_ = mmcv.imrotate(img, 30, auto_bound=True)
```
-### Flip
+#### Flip
To flip an image, use `imflip`.
@@ -114,7 +116,7 @@ mmcv.imflip(img)
mmcv.imflip(img, direction='vertical')
```
-### Crop
+#### Crop
`imcrop` can crop the image with one or some regions, represented as (x1, y1, x2, y2).
@@ -136,7 +138,7 @@ patches = mmcv.imcrop(img, bboxes)
patches = mmcv.imcrop(img, bboxes, scale_ratio=1.2)
```
-### Padding
+#### Padding
There are two methods `impad` and `impad_to_multiple` to pad an image to the
specific size with given values.
@@ -160,3 +162,125 @@ img_ = mmcv.impad(img, padding=(10, 20, 30, 40), pad_val=[100, 50, 200])
# pad an image so that each edge is a multiple of some value.
img_ = mmcv.impad_to_multiple(img, 32)
```
+
+### Video
+
+This module provides the following functionalities.
+
+- A `VideoReader` class with friendly apis to read and convert videos.
+- Some methods for editing (cut, concat, resize) videos.
+- Optical flow read/write/warp.
+
+#### VideoReader
+
+The `VideoReader` class provides sequence like apis to access video frames.
+It will internally cache the frames which have been visited.
+
+```python
+video = mmcv.VideoReader('test.mp4')
+
+# obtain basic information
+print(len(video))
+print(video.width, video.height, video.resolution, video.fps)
+
+# iterate over all frames
+for frame in video:
+ print(frame.shape)
+
+# read the next frame
+img = video.read()
+
+# read a frame by index
+img = video[100]
+
+# read some frames
+img = video[5:10]
+```
+
+To convert a video to images or generate a video from a image directory.
+
+```python
+# split a video into frames and save to a folder
+video = mmcv.VideoReader('test.mp4')
+video.cvt2frames('out_dir')
+
+# generate video from frames
+mmcv.frames2video('out_dir', 'test.avi')
+```
+
+#### Editing utils
+
+There are also some methods for editing videos, which wraps the commands of ffmpeg.
+
+```python
+# cut a video clip
+mmcv.cut_video('test.mp4', 'clip1.mp4', start=3, end=10, vcodec='h264')
+
+# join a list of video clips
+mmcv.concat_video(['clip1.mp4', 'clip2.mp4'], 'joined.mp4', log_level='quiet')
+
+# resize a video with the specified size
+mmcv.resize_video('test.mp4', 'resized1.mp4', (360, 240))
+
+# resize a video with a scaling ratio of 2
+mmcv.resize_video('test.mp4', 'resized2.mp4', ratio=2)
+```
+
+#### Optical flow
+
+`mmcv` provides the following methods to operate on optical flows.
+
+- IO
+- Visualization
+- Flow warpping
+
+We provide two options to dump optical flow files: uncompressed and compressed.
+The uncompressed way just dumps the floating numbers to a binary file. It is
+lossless but the dumped file has a larger size.
+The compressed way quantizes the optical flow to 0-255 and dumps it as a
+jpeg image. The flow of x-dim and y-dim will be concatenated into a single image.
+
+1. IO
+
+```python
+flow = np.random.rand(800, 600, 2).astype(np.float32)
+# dump the flow to a flo file (~3.7M)
+mmcv.flowwrite(flow, 'uncompressed.flo')
+# dump the flow to a jpeg file (~230K)
+# the shape of the dumped image is (800, 1200)
+mmcv.flowwrite(flow, 'compressed.jpg', quantize=True, concat_axis=1)
+
+# read the flow file, the shape of loaded flow is (800, 600, 2) for both ways
+flow = mmcv.flowread('uncompressed.flo')
+flow = mmcv.flowread('compressed.jpg', quantize=True, concat_axis=1)
+```
+
+2. Visualization
+
+It is possible to visualize optical flows with `mmcv.flowshow()`.
+
+```python
+mmcv.flowshow(flow)
+```
+
+![progress](../_static/flow_visualization.png)
+
+3. Flow warpping
+
+```python
+img1 = mmcv.imread('img1.jpg')
+flow = mmcv.flowread('flow.flo')
+warpped_img2 = mmcv.flow_warp(img1, flow)
+```
+
+img1 (left) and img2 (right)
+
+![raw images](../_static/flow_raw_images.png)
+
+optical flow (img2 -> img1)
+
+![optical flow](../_static/flow_img2toimg1.png)
+
+warpped image and difference with ground truth
+
+![warpped image](../_static/flow_warp_diff.png)
diff --git a/docs/io.md b/docs/understand_mmcv/io.md
similarity index 96%
rename from docs/io.md
rename to docs/understand_mmcv/io.md
index c1cef2ab12..50314d13d0 100644
--- a/docs/io.md
+++ b/docs/understand_mmcv/io.md
@@ -105,7 +105,7 @@ Then use `list_from_file` to load the list from a.txt.
['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']
```
-For example `b.txt` is a text file with 5 lines.
+For example `b.txt` is a text file with 3 lines.
```
1 cat
@@ -113,7 +113,7 @@ For example `b.txt` is a text file with 5 lines.
3 panda
```
-Then use `dict_from_file` to load the list from a.txt.
+Then use `dict_from_file` to load the dict from `b.txt` .
```python
>>> mmcv.dict_from_file('b.txt')
diff --git a/docs/ops.md b/docs/understand_mmcv/ops.md
similarity index 100%
rename from docs/ops.md
rename to docs/understand_mmcv/ops.md
diff --git a/docs/registry.md b/docs/understand_mmcv/registry.md
similarity index 99%
rename from docs/registry.md
rename to docs/understand_mmcv/registry.md
index 3793224b6d..242a962a20 100644
--- a/docs/registry.md
+++ b/docs/understand_mmcv/registry.md
@@ -62,7 +62,7 @@ converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = CONVERTERS.build(converter_cfg)
```
-## Customize Build Function
+### Customize Build Function
Suppose we would like to customize how `converters` are built, we could implement a customized `build_func` and pass it into the registry.
@@ -89,7 +89,7 @@ Note: in this example, we demonstrate how to use the `build_func` argument to cu
The functionality is similar to the default `build_from_cfg`. In most cases, default one would be sufficient.
`build_model_from_cfg` is also implemented to build PyTorch module in `nn.Sequentail`, you may directly use them instead of implementing by yourself.
-## Hierarchy Registry
+### Hierarchy Registry
You could also build modules from more than one OpenMMLab frameworks, e.g. you could use all backbones in [MMClassification](https://github.com/open-mmlab/mmclassification) for object detectors in [MMDetection](https://github.com/open-mmlab/mmdetection), you may also combine an object detection model in [MMDetection](https://github.com/open-mmlab/mmdetection) and semantic segmentation model in [MMSegmentation](https://github.com/open-mmlab/mmsegmentation).
diff --git a/docs/understand_mmcv/runner.md b/docs/understand_mmcv/runner.md
new file mode 100644
index 0000000000..8cf0385709
--- /dev/null
+++ b/docs/understand_mmcv/runner.md
@@ -0,0 +1,163 @@
+## Runner
+
+The runner class is designed to manage the training. It eases the training process with less code demanded from users while staying flexible and configurable. The main features are as listed:
+
+- Support `EpochBasedRunner` and `IterBasedRunner` for different scenarios. Implementing customized runners is also allowed to meet customized needs.
+- Support customized workflow to allow switching between different modes while training. Currently, supported modes are train and val.
+- Enable extensibility through various hooks, including hooks defined in MMCV and customized ones.
+
+### EpochBasedRunner
+
+As its name indicates, workflow in `EpochBasedRunner` should be set based on epochs. For example, [('train', 2), ('val', 1)] means running 2 epochs for training and 1 epoch for validation, iteratively. And each epoch may contain multiple iterations. Currently, MMDetection uses `EpochBasedRunner` by default.
+
+Let's take a look at its core logic:
+
+```python
+# the condition to stop training
+while curr_epoch < max_epochs:
+ # traverse the workflow.
+ # e.g. workflow = [('train', 2), ('val', 1)]
+ for i, flow in enumerate(workflow):
+ # mode(e.g. train) determines which function to run
+ mode, epochs = flow
+ # epoch_runner will be either self.train() or self.val()
+ epoch_runner = getattr(self, mode)
+ # execute the corresponding function
+ for _ in range(epochs):
+ epoch_runner(data_loaders[i], **kwargs)
+```
+
+Currently, we support 2 modes: train and val. Let's take a train function for example and have a look at its core logic:
+
+```python
+# Currently, epoch_runner could be either train or val
+def train(self, data_loader, **kwargs):
+ # traverse the dataset and get batch data for 1 epoch
+ for i, data_batch in enumerate(data_loader):
+ # it will execute all before_train_iter function in the hooks registered. You may want to watch out for the order.
+ self.call_hook('before_train_iter')
+ # set train_mode as False in val function
+ self.run_iter(data_batch, train_mode=True, **kwargs)
+ self.call_hook('after_train_iter')
+ self.call_hook('after_train_epoch')
+```
+
+### IterBasedRunner
+
+Different from `EpochBasedRunner`, workflow in `IterBasedRunner` should be set based on iterations. For example, [('train', 2), ('val', 1)] means running 2 iters for training and 1 iter for validation, iteratively. Currently, MMSegmentation uses `IterBasedRunner` by default.
+
+Let's take a look at its core logic:
+
+```python
+# Although we set workflow by iters here, we might also need info on the epochs in some using cases. That can be provided by IterLoader.
+iter_loaders = [IterLoader(x) for x in data_loaders]
+# the condition to stop training
+while curr_iter < max_iters:
+ # traverse the workflow.
+ # e.g. workflow = [('train', 2), ('val', 1)]
+ for i, flow in enumerate(workflow):
+ # mode(e.g. train) determines which function to run
+ mode, iters = flow
+ # epoch_runner will be either self.train() or self.val()
+ iter_runner = getattr(self, mode)
+ # execute the corresponding function
+ for _ in range(iters):
+ iter_runner(iter_loaders[i], **kwargs)
+```
+
+Currently, we support 2 modes: train and val. Let's take a val function for example and have a look at its core logic:
+
+```python
+# Currently, iter_runner could be either train or val
+def val(self, data_loader, **kwargs):
+ # get batch data for 1 iter
+ data_batch = next(data_loader)
+ # it will execute all before_val_iter function in the hooks registered. You may want to watch out for the order.
+ self.call_hook('before_val_iter')
+ outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+ self.outputs = outputs
+ self.call_hook('after_val_iter')
+```
+
+Other than the basic functionalities explained above, `EpochBasedRunner` and `IterBasedRunner` provide methods such as `resume`, `save_checkpoint` and `register_hook`. In case you are not familiar with the term Hook mentioned earlier, we will also provide a tutorial about it.(coming soon...) Essentially, a hook is functionality to alter or augment the code behaviors through predefined api. It allows users to have their own code called under certain circumstances. It makes code extensible in a non-intrusive manner.
+
+### A Simple Example
+
+We will walk you through the usage of runner with a classification task. The following code only contains essential steps for demonstration purposes. The following steps are necessary for any training tasks.
+
+**(1) Initialize dataloader, model, optimizer, etc.**
+
+```python
+# initialize model
+model=...
+# initialize optimizer, typically, we set: cfg.optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
+optimizer = build_optimizer(model, cfg.optimizer)
+# intialize the dataloader corresponding to the workflow(train/val)
+data_loaders = [
+ build_dataloader(
+ ds,
+ cfg.data.samples_per_gpu,
+ cfg.data.workers_per_gpu,
+ ...) for ds in dataset
+ ]
+```
+
+**(2) Initialize runner**
+
+```python
+runner = build_runner(
+ # cfg.runner is typically set as:
+ # runner = dict(type='EpochBasedRunner', max_epochs=200)
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ batch_processor=None,
+ optimizer=optimizer,
+ logger=logger))
+```
+
+**(3) Register training hooks and customized hooks.**
+
+```python
+# register defalt hooks neccesary for traning
+runner.register_training_hooks(
+ # configs of learning rate,it is typically set as:
+ # lr_config = dict(policy='step', step=[100, 150])
+ cfg.lr_config,
+ # configuration of optimizer, e.g. grad_clip
+ optimizer_config,
+ # configuration of saving checkpoints, it is typically set as:
+ # checkpoint_config = dict(interval=1),saving checkpoints every epochs
+ cfg.checkpoint_config,
+ # configuration of logs
+ cfg.log_config,
+ ...)
+
+# register customized hooks
+# say we want to enable ema, then we could set custom_hooks=[dict(type='EMAHook')]
+if cfg.get('custom_hooks', None):
+ custom_hooks = cfg.custom_hooks
+ for hook_cfg in cfg.custom_hooks:
+ hook_cfg = hook_cfg.copy()
+ priority = hook_cfg.pop('priority', 'NORMAL')
+ hook = build_from_cfg(hook_cfg, HOOKS)
+ runner.register_hook(hook, priority=priority)
+```
+
+Then, we can use `resume` or `load_checkpoint` to load existing weights.
+
+**(4) Start training**
+
+```python
+# workflow is typically set as: workflow = [('train', 1)]
+# here the training begins.
+runner.run(data_loaders, cfg.workflow)
+```
+
+Let's take `EpochBasedRunner` for example and go a little bit into details about setting workflow:
+
+- Say we only want to put train in the workflow, then we can set: workflow = [('train', 1)]. The runner will only execute train iteratively in this case.
+- Say we want to put both train and val in the workflow, then we can set: workflow = [('train', 3), ('val',1)]. The runner will first execute train for 3 epochs and then switch to val mode and execute val for 1 epoch. The workflow will be repeated until the current epoch hit the max_epochs.
+- Workflow is highly flexible. Therefore, you can set workflow = [('val', 1), ('train',1)] if you would like the runner to validate first and train after.
+
+The code we demonstrated above is already in `train.py` in MM repositories. Simply modify the corresponding keys in the configuration files and the script will execute the expected workflow automatically.
diff --git a/docs/understand_mmcv/utils.md b/docs/understand_mmcv/utils.md
new file mode 100644
index 0000000000..6936688b3b
--- /dev/null
+++ b/docs/understand_mmcv/utils.md
@@ -0,0 +1,73 @@
+## Utils
+
+### ProgressBar
+
+If you want to apply a method to a list of items and track the progress, `track_progress`
+is a good choice. It will display a progress bar to tell the progress and ETA.
+
+```python
+import mmcv
+
+def func(item):
+ # do something
+ pass
+
+tasks = [item_1, item_2, ..., item_n]
+
+mmcv.track_progress(func, tasks)
+```
+
+The output is like the following.
+![progress](../_static/progress.gif)
+
+There is another method `track_parallel_progress`, which wraps multiprocessing and
+progress visualization.
+
+```python
+mmcv.track_parallel_progress(func, tasks, 8) # 8 workers
+```
+
+![progress](../_static/parallel_progress.gif)
+
+If you want to iterate or enumerate a list of items and track the progress, `track_iter_progress`
+is a good choice. It will display a progress bar to tell the progress and ETA.
+
+```python
+import mmcv
+
+tasks = [item_1, item_2, ..., item_n]
+
+for task in mmcv.track_iter_progress(tasks):
+ # do something like print
+ print(task)
+
+for i, task in enumerate(mmcv.track_iter_progress(tasks)):
+ # do something like print
+ print(i)
+ print(task)
+```
+
+### Timer
+
+It is convenient to compute the runtime of a code block with `Timer`.
+
+```python
+import time
+
+with mmcv.Timer():
+ # simulate some code block
+ time.sleep(1)
+```
+
+or try with `since_start()` and `since_last_check()`. This former can
+return the runtime since the timer starts and the latter will return the time
+since the last time checked.
+
+```python
+timer = mmcv.Timer()
+# code block 1 here
+print(timer.since_start())
+# code block 2 here
+print(timer.since_last_check())
+print(timer.since_start())
+```
diff --git a/docs/visualization.md b/docs/understand_mmcv/visualization.md
similarity index 100%
rename from docs/visualization.md
rename to docs/understand_mmcv/visualization.md
diff --git a/docs/video.md b/docs/video.md
deleted file mode 100644
index a01f377164..0000000000
--- a/docs/video.md
+++ /dev/null
@@ -1,117 +0,0 @@
-## Video
-
-This module provides the following functionalities.
-
-- A `VideoReader` class with friendly apis to read and convert videos.
-- Some methods for editing (cut, concat, resize) videos.
-- Optical flow read/write/warp.
-
-### VideoReader
-
-The `VideoReader` class provides sequence like apis to access video frames.
-It will internally cache the frames which have been visited.
-
-```python
-video = mmcv.VideoReader('test.mp4')
-
-# obtain basic information
-print(len(video))
-print(video.width, video.height, video.resolution, video.fps)
-
-# iterate over all frames
-for frame in video:
- print(frame.shape)
-
-# read the next frame
-img = video.read()
-
-# read a frame by index
-img = video[100]
-
-# read some frames
-img = video[5:10]
-```
-
-To convert a video to images or generate a video from a image directory.
-
-```python
-# split a video into frames and save to a folder
-video = mmcv.VideoReader('test.mp4')
-video.cvt2frames('out_dir')
-
-# generate video from frames
-mmcv.frames2video('out_dir', 'test.avi')
-```
-
-### Editing utils
-
-There are also some methods for editing videos, which wraps the commands of ffmpeg.
-
-```python
-# cut a video clip
-mmcv.cut_video('test.mp4', 'clip1.mp4', start=3, end=10, vcodec='h264')
-
-# join a list of video clips
-mmcv.concat_video(['clip1.mp4', 'clip2.mp4'], 'joined.mp4', log_level='quiet')
-
-# resize a video with the specified size
-mmcv.resize_video('test.mp4', 'resized1.mp4', (360, 240))
-
-# resize a video with a scaling ratio of 2
-mmcv.resize_video('test.mp4', 'resized2.mp4', ratio=2)
-```
-
-### Optical flow
-
-`mmcv` provides the following methods to operate on optical flows.
-
-- IO
-- Visualization
-- Flow warpping
-
-We provide two options to dump optical flow files: uncompressed and compressed.
-The uncompressed way just dumps the floating numbers to a binary file. It is
-lossless but the dumped file has a larger size.
-The compressed way quantizes the optical flow to 0-255 and dumps it as a
-jpeg image. The flow of x-dim and y-dim will be concatenated into a single image.
-
-```python
-flow = np.random.rand(800, 600, 2).astype(np.float32)
-# dump the flow to a flo file (~3.7M)
-mmcv.flowwrite(flow, 'uncompressed.flo')
-# dump the flow to a jpeg file (~230K)
-# the shape of the dumped image is (800, 1200)
-mmcv.flowwrite(flow, 'compressed.jpg', quantize=True, concat_axis=1)
-
-# read the flow file, the shape of loaded flow is (800, 600, 2) for both ways
-flow = mmcv.flowread('uncompressed.flo')
-flow = mmcv.flowread('compressed.jpg', quantize=True, concat_axis=1)
-```
-
-It is possible to visualize optical flows with `mmcv.flowshow()`.
-
-```python
-mmcv.flowshow(flow)
-```
-
-![progress](_static/flow_visualization.png)
-
-3. Flow warpping
-
-```python
-img1 = mmcv.imread('img1.jpg')
-flow = mmcv.flowread('flow.flo')
-warpped_img2 = mmcv.flow_warp(img1, flow)
-```
-
-img1 (left) and img2 (right)
-
-![raw images](_static/flow_raw_images.png)
-
-optical flow (img2 -> img1)
-
-![optical flow](_static/flow_img2toimg1.png)
-
-warpped image and difference with ground truth
-
-![warpped image](_static/flow_warp_diff.png)
diff --git a/docs_zh_CN/Makefile b/docs_zh_CN/Makefile
new file mode 100644
index 0000000000..51285967a7
--- /dev/null
+++ b/docs_zh_CN/Makefile
@@ -0,0 +1,19 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS =
+SPHINXBUILD = sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs_zh_CN/_static b/docs_zh_CN/_static
new file mode 120000
index 0000000000..ead5849d0e
--- /dev/null
+++ b/docs_zh_CN/_static
@@ -0,0 +1 @@
+../docs/_static
\ No newline at end of file
diff --git a/docs_zh_CN/api.rst b/docs_zh_CN/api.rst
new file mode 100644
index 0000000000..fb77ebaa94
--- /dev/null
+++ b/docs_zh_CN/api.rst
@@ -0,0 +1,48 @@
+API 文档
+=========
+
+
+fileio
+-------
+.. automodule:: mmcv.fileio
+ :members:
+
+image
+------
+.. automodule:: mmcv.image
+ :members:
+
+video
+------
+.. automodule:: mmcv.video
+ :members:
+
+arraymisc
+---------
+.. automodule:: mmcv.arraymisc
+ :members:
+
+visualization
+--------------
+.. automodule:: mmcv.visualization
+ :members:
+
+utils
+-----
+.. automodule:: mmcv.utils
+ :members:
+
+cnn
+----
+.. automodule:: mmcv.cnn
+ :members:
+
+runner
+------
+.. automodule:: mmcv.runner
+ :members:
+
+ops
+------
+.. automodule:: mmcv.ops
+ :members:
diff --git a/docs_zh_CN/community.rst b/docs_zh_CN/community.rst
new file mode 100644
index 0000000000..6ff519a7b0
--- /dev/null
+++ b/docs_zh_CN/community.rst
@@ -0,0 +1,7 @@
+社区
+===========
+
+.. toctree::
+ :maxdepth: 2
+
+ community/contributing.md
diff --git a/docs_zh_CN/community/contributing.md b/docs_zh_CN/community/contributing.md
new file mode 100644
index 0000000000..51df51aedf
--- /dev/null
+++ b/docs_zh_CN/community/contributing.md
@@ -0,0 +1,3 @@
+## 贡献代码
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/conf.py b/docs_zh_CN/conf.py
new file mode 100644
index 0000000000..ab1db4a5c7
--- /dev/null
+++ b/docs_zh_CN/conf.py
@@ -0,0 +1,195 @@
+#
+# Configuration file for the Sphinx documentation builder.
+#
+# This file does only contain a selection of the most common options. For a
+# full list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+
+from m2r import MdInclude
+from recommonmark.transform import AutoStructify
+
+sys.path.insert(0, os.path.abspath('..'))
+
+version_file = '../mmcv/version.py'
+with open(version_file, 'r') as f:
+ exec(compile(f.read(), version_file, 'exec'))
+__version__ = locals()['__version__']
+
+# -- Project information -----------------------------------------------------
+
+project = 'mmcv'
+copyright = '2018-2021, OpenMMLab'
+author = 'MMCV Authors'
+
+# The short X.Y version
+version = __version__
+# The full version, including alpha/beta/rc tags
+release = __version__
+
+# -- General configuration ---------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+# needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.viewcode',
+ 'recommonmark',
+ 'sphinx.ext.autosectionlabel',
+ 'sphinx_markdown_tables'
+] # yapf: disable
+
+autodoc_mock_imports = ['mmcv._ext', 'mmcv.utils.ext_loader', 'torchvision']
+autosectionlabel_prefix_document = True
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+source_suffix = {
+ '.rst': 'restructuredtext',
+ '.md': 'markdown',
+}
+
+# The master toctree document.
+master_doc = 'index'
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = 'zh_CN'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = 'sphinx'
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_rtd_theme'
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+#
+# html_theme_options = {}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+
+# Custom sidebar templates, must be a dictionary that maps document names
+# to template names.
+#
+# The default sidebars (for documents that don't match any pattern) are
+# defined by theme itself. Builtin themes are using these templates by
+# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
+# 'searchbox.html']``.
+#
+# html_sidebars = {}
+
+# -- Options for HTMLHelp output ---------------------------------------------
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = 'mmcvdoc'
+
+# -- Options for LaTeX output ------------------------------------------------
+
+latex_elements = {
+ # The paper size ('letterpaper' or 'a4paper').
+ #
+ # 'papersize': 'letterpaper',
+
+ # The font size ('10pt', '11pt' or '12pt').
+ #
+ # 'pointsize': '10pt',
+
+ # Additional stuff for the LaTeX preamble.
+ #
+ # 'preamble': '',
+
+ # Latex figure (float) alignment
+ #
+ # 'figure_align': 'htbp',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title,
+# author, documentclass [howto, manual, or own class]).
+latex_documents = [
+ (master_doc, 'mmcv.tex', 'mmcv Documentation', 'Kai Chen', 'manual'),
+]
+
+# -- Options for manual page output ------------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [(master_doc, 'mmcv', 'mmcv Documentation', [author], 1)]
+
+# -- Options for Texinfo output ----------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+# dir menu entry, description, category)
+texinfo_documents = [
+ (master_doc, 'mmcv', 'mmcv Documentation', author, 'mmcv',
+ 'One line description of project.', 'Miscellaneous'),
+]
+
+# -- Options for Epub output -------------------------------------------------
+
+# Bibliographic Dublin Core info.
+epub_title = project
+
+# The unique identifier of the text. This can be a ISBN number
+# or the project homepage.
+#
+# epub_identifier = ''
+
+# A unique identification for the text.
+#
+# epub_uid = ''
+
+# A list of files that should not be packed into the epub file.
+epub_exclude_files = ['search.html']
+
+# -- Extension configuration -------------------------------------------------
+
+
+def setup(app):
+ app.add_config_value('no_underscore_emphasis', False, 'env')
+ app.add_config_value('m2r_parse_relative_links', False, 'env')
+ app.add_config_value('m2r_anonymous_references', False, 'env')
+ app.add_config_value('m2r_disable_inline_math', False, 'env')
+ app.add_directive('mdinclude', MdInclude)
+ app.add_config_value('recommonmark_config', {
+ 'auto_toc_tree_section': 'Contents',
+ 'enable_eval_rst': True,
+ }, True)
+ app.add_transform(AutoStructify)
diff --git a/docs_zh_CN/deployment.rst b/docs_zh_CN/deployment.rst
new file mode 100644
index 0000000000..c9e150a98a
--- /dev/null
+++ b/docs_zh_CN/deployment.rst
@@ -0,0 +1,11 @@
+部署
+========
+
+.. toctree::
+ :maxdepth: 2
+
+ deployment/onnx.md
+ deployment/onnxruntime_op.md
+ deployment/onnxruntime_custom_ops.md
+ deployment/tensorrt_plugin.md
+ deployment/tensorrt_custom_ops.md
diff --git a/docs_zh_CN/deployment/onnx.md b/docs_zh_CN/deployment/onnx.md
new file mode 100644
index 0000000000..5268926d44
--- /dev/null
+++ b/docs_zh_CN/deployment/onnx.md
@@ -0,0 +1,3 @@
+# MMCV 中的 onnx 模块 (实验性质)
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/deployment/onnxruntime_custom_ops.md b/docs_zh_CN/deployment/onnxruntime_custom_ops.md
new file mode 100644
index 0000000000..5b76dfeac5
--- /dev/null
+++ b/docs_zh_CN/deployment/onnxruntime_custom_ops.md
@@ -0,0 +1,3 @@
+# Onnxruntime 自定义算子
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/deployment/onnxruntime_op.md b/docs_zh_CN/deployment/onnxruntime_op.md
new file mode 100644
index 0000000000..845f30f55b
--- /dev/null
+++ b/docs_zh_CN/deployment/onnxruntime_op.md
@@ -0,0 +1,3 @@
+# MMCV 中用于 ONNX Runtime 的自定义算子
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/deployment/tensorrt_custom_ops.md b/docs_zh_CN/deployment/tensorrt_custom_ops.md
new file mode 100644
index 0000000000..1b876e91e0
--- /dev/null
+++ b/docs_zh_CN/deployment/tensorrt_custom_ops.md
@@ -0,0 +1,3 @@
+# TensorRT 自定义算子
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/deployment/tensorrt_plugin.md b/docs_zh_CN/deployment/tensorrt_plugin.md
new file mode 100644
index 0000000000..60df06a517
--- /dev/null
+++ b/docs_zh_CN/deployment/tensorrt_plugin.md
@@ -0,0 +1,3 @@
+# MMCV 中用于自定义算子的 TensorRT 插件 (实验性质)
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/faq.md b/docs_zh_CN/faq.md
new file mode 100644
index 0000000000..4a1a21a377
--- /dev/null
+++ b/docs_zh_CN/faq.md
@@ -0,0 +1,3 @@
+## 常见问题
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/get_started.rst b/docs_zh_CN/get_started.rst
new file mode 100644
index 0000000000..6187d31ebc
--- /dev/null
+++ b/docs_zh_CN/get_started.rst
@@ -0,0 +1,9 @@
+介绍及安装
+===================
+
+.. toctree::
+ :maxdepth: 2
+
+ get_started/introduction.md
+ get_started/installation.md
+ get_started/build.md
diff --git a/docs_zh_CN/get_started/build.md b/docs_zh_CN/get_started/build.md
new file mode 100644
index 0000000000..9e1e99d404
--- /dev/null
+++ b/docs_zh_CN/get_started/build.md
@@ -0,0 +1,3 @@
+## 从源码编译 MMCV
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/get_started/installation.md b/docs_zh_CN/get_started/installation.md
new file mode 100644
index 0000000000..c9370ded87
--- /dev/null
+++ b/docs_zh_CN/get_started/installation.md
@@ -0,0 +1,3 @@
+## 安装 MMCV
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/get_started/introduction.md b/docs_zh_CN/get_started/introduction.md
new file mode 100644
index 0000000000..ad07681288
--- /dev/null
+++ b/docs_zh_CN/get_started/introduction.md
@@ -0,0 +1,3 @@
+## 介绍 MMCV
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/index.rst b/docs_zh_CN/index.rst
new file mode 100644
index 0000000000..f4a26fe924
--- /dev/null
+++ b/docs_zh_CN/index.rst
@@ -0,0 +1,21 @@
+欢迎来到 MMCV 的中文文档!
+=============================
+
+您可以在页面左下角切换中英文文档。
+
+.. toctree::
+ :maxdepth: 2
+
+ get_started.rst
+ deployment.rst
+ understand_mmcv.rst
+ api.rst
+ faq.md
+ community.rst
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`search`
diff --git a/docs_zh_CN/make.bat b/docs_zh_CN/make.bat
new file mode 100644
index 0000000000..7893348a1b
--- /dev/null
+++ b/docs_zh_CN/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+
+:end
+popd
diff --git a/docs_zh_CN/mmcv-logo.png b/docs_zh_CN/mmcv-logo.png
new file mode 120000
index 0000000000..7dcca035f6
--- /dev/null
+++ b/docs_zh_CN/mmcv-logo.png
@@ -0,0 +1 @@
+../docs/mmcv-logo.png
\ No newline at end of file
diff --git a/docs_zh_CN/understand_mmcv.rst b/docs_zh_CN/understand_mmcv.rst
new file mode 100644
index 0000000000..073ac4770b
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv.rst
@@ -0,0 +1,15 @@
+深入理解 MMCV
+=================
+
+.. toctree::
+ :maxdepth: 2
+
+ understand_mmcv/config.md
+ understand_mmcv/registry.md
+ understand_mmcv/runner.md
+ understand_mmcv/io.md
+ understand_mmcv/data_process.md
+ understand_mmcv/visualization.md
+ understand_mmcv/cnn.md
+ understand_mmcv/ops.md
+ understand_mmcv/utils.md
diff --git a/docs_zh_CN/understand_mmcv/cnn.md b/docs_zh_CN/understand_mmcv/cnn.md
new file mode 100644
index 0000000000..99dfa6cc00
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/cnn.md
@@ -0,0 +1,3 @@
+## 卷积神经网络
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/understand_mmcv/config.md b/docs_zh_CN/understand_mmcv/config.md
new file mode 100644
index 0000000000..bdbdb607f8
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/config.md
@@ -0,0 +1,3 @@
+## 配置
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/understand_mmcv/data_process.md b/docs_zh_CN/understand_mmcv/data_process.md
new file mode 100644
index 0000000000..3aab943273
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/data_process.md
@@ -0,0 +1,275 @@
+## 数据处理
+
+### 图像
+
+图像模块提供了一些图像预处理的函数,该模块依赖 `opencv` 。
+
+#### 读取/保存/显示
+
+使用 `imread` 和 `imwrite` 函数可以读取和保存图像。
+
+```python
+import mmcv
+
+img = mmcv.imread('test.jpg')
+img = mmcv.imread('test.jpg', flag='grayscale')
+img_ = mmcv.imread(img) # 相当于什么也没做
+mmcv.imwrite(img, 'out.jpg')
+```
+
+从二进制中读取图像
+
+```python
+with open('test.jpg', 'rb') as f:
+ data = f.read()
+img = mmcv.imfrombytes(data)
+```
+
+显示图像文件或已读取的图像
+
+```python
+mmcv.imshow('tests/data/color.jpg')
+
+for i in range(10):
+ img = np.random.randint(256, size=(100, 100, 3), dtype=np.uint8)
+ mmcv.imshow(img, win_name='test image', wait_time=200)
+```
+
+#### 色彩空间转换
+
+支持的转换函数:
+
+- bgr2gray
+- gray2bgr
+- bgr2rgb
+- rgb2bgr
+- bgr2hsv
+- hsv2bgr
+
+```python
+img = mmcv.imread('tests/data/color.jpg')
+img1 = mmcv.bgr2rgb(img)
+img2 = mmcv.rgb2gray(img1)
+img3 = mmcv.bgr2hsv(img)
+```
+
+#### 缩放
+
+有三种缩放图像的方法。所有以 `imresize_*` 开头的函数都有一个 `return_scale` 参数,如果
+该参数为 `False` ,函数的返回值只有调整之后的图像,否则是一个元组 `(resized_img, scale)` 。
+
+```python
+# 缩放图像至给定的尺寸
+mmcv.imresize(img, (1000, 600), return_scale=True)
+
+# 缩放图像至与给定的图像同样的尺寸
+mmcv.imresize_like(img, dst_img, return_scale=False)
+
+# 以一定的比例缩放图像
+mmcv.imrescale(img, 0.5)
+
+# 缩放图像至最长的边不大于1000、最短的边不大于800并且没有改变图像的长宽比
+mmcv.imrescale(img, (1000, 800))
+```
+
+#### 旋转
+
+我们可以使用 `imrotate` 旋转图像一定的角度。旋转的中心需要指定,默认值是原始图像的中心。有
+两种旋转的模式,一种保持图像的尺寸不变,因此旋转后原始图像中的某些部分会被裁剪,另一种是扩大
+图像的尺寸进而保留完整的原始图像。
+
+```python
+img = mmcv.imread('tests/data/color.jpg')
+
+# 顺时针旋转图像30度
+img_ = mmcv.imrotate(img, 30)
+
+# 逆时针旋转图像90度
+img_ = mmcv.imrotate(img, -90)
+
+# 顺时针旋转图像30度并且缩放图像为原始图像的1.5倍
+img_ = mmcv.imrotate(img, 30, scale=1.5)
+
+# 以坐标(100, 100)为中心顺时针旋转图像30度
+img_ = mmcv.imrotate(img, 30, center=(100, 100))
+
+# 顺时针旋转图像30度并扩大图像的尺寸
+img_ = mmcv.imrotate(img, 30, auto_bound=True)
+```
+
+#### 翻转
+
+我们可以使用 `imflip` 翻转图像。
+
+```python
+img = mmcv.imread('tests/data/color.jpg')
+
+# 水平翻转图像
+mmcv.imflip(img)
+
+# 垂直翻转图像
+mmcv.imflip(img, direction='vertical')
+```
+
+#### 裁剪
+
+`imcrop` 可以裁剪图像的一个或多个区域,每个区域用左上角和右下角坐标表示,形如(x1, y1, x2, y2)
+
+```python
+import mmcv
+import numpy as np
+
+img = mmcv.imread('tests/data/color.jpg')
+
+# 裁剪区域 (10, 10, 100, 120)
+bboxes = np.array([10, 10, 100, 120])
+patch = mmcv.imcrop(img, bboxes)
+
+# 裁剪两个区域,分别是 (10, 10, 100, 120) 和 (0, 0, 50, 50)
+bboxes = np.array([[10, 10, 100, 120], [0, 0, 50, 50]])
+patches = mmcv.imcrop(img, bboxes)
+
+# 裁剪两个区域并且缩放区域1.2倍
+patches = mmcv.imcrop(img, bboxes, scale_ratio=1.2)
+```
+
+#### 填充
+
+`impad` and `impad_to_multiple` 可以用给定的值将图像填充至给定的尺寸。
+
+```python
+img = mmcv.imread('tests/data/color.jpg')
+
+# 用给定值将图像填充至 (1000, 1200)
+img_ = mmcv.impad(img, shape=(1000, 1200), pad_val=0)
+
+# 用给定值分别填充图像的3个通道至 (1000, 1200)
+img_ = mmcv.impad(img, shape=(1000, 1200), pad_val=[100, 50, 200])
+
+# 用给定值填充图像的左、右、上、下四条边
+img_ = mmcv.impad(img, padding=(10, 20, 30, 40), pad_val=0)
+
+# 用3个值分别填充图像的左、右、上、下四条边的3个通道
+img_ = mmcv.impad(img, padding=(10, 20, 30, 40), pad_val=[100, 50, 200])
+
+# 将图像的四条边填充至能够被给定值整除
+img_ = mmcv.impad_to_multiple(img, 32)
+```
+
+### 视频
+
+视频模块提供了以下的功能:
+
+- 一个 `VideoReader` 类,具有友好的 API 接口可以读取和转换视频
+- 一些编辑视频的方法,包括 `cut` , `concat` , `resize`
+- 光流的读取/保存/变换
+
+#### VideoReader
+
+`VideoReader` 类提供了和序列一样的接口去获取视频帧。该类会缓存所有被访问过的帧。
+
+```python
+video = mmcv.VideoReader('test.mp4')
+
+# 获取基本的信息
+print(len(video))
+print(video.width, video.height, video.resolution, video.fps)
+
+# 遍历所有的帧
+for frame in video:
+ print(frame.shape)
+
+# 读取下一帧
+img = video.read()
+
+# 使用索引获取帧
+img = video[100]
+
+# 获取指定范围的帧
+img = video[5:10]
+```
+
+将视频切成帧并保存至给定目录或者从给定目录中生成视频。
+
+```python
+# 将视频切成帧并保存至目录
+video = mmcv.VideoReader('test.mp4')
+video.cvt2frames('out_dir')
+
+# 从给定目录中生成视频
+mmcv.frames2video('out_dir', 'test.avi')
+```
+
+#### 编辑函数
+
+有几个用于编辑视频的函数,这些函数是对 `ffmpeg` 的封装。
+
+```python
+# 裁剪视频
+mmcv.cut_video('test.mp4', 'clip1.mp4', start=3, end=10, vcodec='h264')
+
+# 将多个视频拼接成一个视频
+mmcv.concat_video(['clip1.mp4', 'clip2.mp4'], 'joined.mp4', log_level='quiet')
+
+# 将视频缩放至给定的尺寸
+mmcv.resize_video('test.mp4', 'resized1.mp4', (360, 240))
+
+# 将视频缩放至给定的倍率
+mmcv.resize_video('test.mp4', 'resized2.mp4', ratio=2)
+```
+
+#### 光流
+
+`mmcv` 提供了以下用于操作光流的函数:
+
+- 读取/保存
+- 可视化
+- 流变换
+
+我们提供了两种将光流dump到文件的方法,分别是非压缩和压缩的方法。非压缩的方法直接将浮点数值的光流
+保存至二进制文件,虽然光流无损但文件会比较大。而压缩的方法先量化光流至 0-255 整形数值再保存为
+jpeg图像。光流的x维度和y维度会被拼接到图像中。
+
+1. 读取/保存
+
+```python
+flow = np.random.rand(800, 600, 2).astype(np.float32)
+# 保存光流到flo文件 (~3.7M)
+mmcv.flowwrite(flow, 'uncompressed.flo')
+# 保存光流为jpeg图像 (~230K),图像的尺寸为 (800, 1200)
+mmcv.flowwrite(flow, 'compressed.jpg', quantize=True, concat_axis=1)
+
+# 读取光流文件,以下两种方式读取的光流尺寸均为 (800, 600, 2)
+flow = mmcv.flowread('uncompressed.flo')
+flow = mmcv.flowread('compressed.jpg', quantize=True, concat_axis=1)
+```
+
+2. 可视化
+
+使用 `mmcv.flowshow()` 可视化光流
+
+```python
+mmcv.flowshow(flow)
+```
+
+![progress](../_static/flow_visualization.png)
+
+3. 流变换
+
+```python
+img1 = mmcv.imread('img1.jpg')
+flow = mmcv.flowread('flow.flo')
+warpped_img2 = mmcv.flow_warp(img1, flow)
+```
+
+img1 (左) and img2 (右)
+
+![raw images](../_static/flow_raw_images.png)
+
+光流 (img2 -> img1)
+
+![optical flow](../_static/flow_img2toimg1.png)
+
+变换后的图像和真实图像的差异
+
+![warpped image](../_static/flow_warp_diff.png)
diff --git a/docs_zh_CN/understand_mmcv/io.md b/docs_zh_CN/understand_mmcv/io.md
new file mode 100644
index 0000000000..8d3844f77c
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/io.md
@@ -0,0 +1,119 @@
+## 文件输入输出
+
+文件输入输出模块提供了两个通用的 API 接口用于读取和保存不同格式的文件。
+
+### 读取和保存数据
+
+`mmcv` 提供了一个通用的 api 用于读取和保存数据,目前支持的格式有 json、yaml 和 pickle。
+
+```python
+import mmcv
+
+# 从文件中读取数据
+data = mmcv.load('test.json')
+data = mmcv.load('test.yaml')
+data = mmcv.load('test.pkl')
+# 从文件对象中读取数据
+with open('test.json', 'r') as f:
+ data = mmcv.load(f, file_format='json')
+
+# 将数据序列化为字符串
+json_str = mmcv.dump(data, file_format='json')
+
+# 将数据保存至文件 (根据文件名后缀反推文件类型)
+mmcv.dump(data, 'out.pkl')
+
+# 将数据保存至文件对象
+with open('test.yaml', 'w') as f:
+ data = mmcv.dump(data, f, file_format='yaml')
+```
+
+我们提供了易于拓展的方式以支持更多的文件格式。我们只需要创建一个继承自 `BaseFileHandler` 的
+文件句柄类并将其注册到 `mmcv` 中即可。句柄类至少需要重写三个方法。
+
+```python
+import mmcv
+
+# 支持为文件句柄类注册多个文件格式
+# @mmcv.register_handler(['txt', 'log'])
+@mmcv.register_handler('txt')
+class TxtHandler1(mmcv.BaseFileHandler):
+
+ def load_from_fileobj(self, file):
+ return file.read()
+
+ def dump_to_fileobj(self, obj, file):
+ file.write(str(obj))
+
+ def dump_to_str(self, obj, **kwargs):
+ return str(obj)
+```
+
+举 `PickleHandler` 为例。
+
+```python
+import pickle
+
+class PickleHandler(mmcv.BaseFileHandler):
+
+ def load_from_fileobj(self, file, **kwargs):
+ return pickle.load(file, **kwargs)
+
+ def load_from_path(self, filepath, **kwargs):
+ return super(PickleHandler, self).load_from_path(
+ filepath, mode='rb', **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ return pickle.dumps(obj, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ pickle.dump(obj, file, **kwargs)
+
+ def dump_to_path(self, obj, filepath, **kwargs):
+ super(PickleHandler, self).dump_to_path(
+ obj, filepath, mode='wb', **kwargs)
+```
+
+### 读取文件并返回列表或字典
+
+例如, `a.txt` 是文本文件,一共有5行内容。
+
+```
+a
+b
+c
+d
+e
+```
+
+使用 `list_from_file` 读取 `a.txt` 。
+
+```python
+>>> mmcv.list_from_file('a.txt')
+['a', 'b', 'c', 'd', 'e']
+>>> mmcv.list_from_file('a.txt', offset=2)
+['c', 'd', 'e']
+>>> mmcv.list_from_file('a.txt', max_num=2)
+['a', 'b']
+>>> mmcv.list_from_file('a.txt', prefix='/mnt/')
+['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']
+```
+
+同样, `b.txt` 也是文本文件,一共有3行内容。
+
+```
+1 cat
+2 dog cow
+3 panda
+```
+
+使用 `dict_from_file` 读取 `b.txt` 。
+
+```python
+>>> mmcv.dict_from_file('b.txt')
+{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
+>>> mmcv.dict_from_file('b.txt', key_type=int)
+{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
+```
diff --git a/docs_zh_CN/understand_mmcv/ops.md b/docs_zh_CN/understand_mmcv/ops.md
new file mode 100644
index 0000000000..db8d8966da
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/ops.md
@@ -0,0 +1,3 @@
+## CUDA 算子
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/understand_mmcv/registry.md b/docs_zh_CN/understand_mmcv/registry.md
new file mode 100644
index 0000000000..4fbbcb3e7f
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/registry.md
@@ -0,0 +1,3 @@
+## 注册器
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/understand_mmcv/runner.md b/docs_zh_CN/understand_mmcv/runner.md
new file mode 100644
index 0000000000..c729c7acee
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/runner.md
@@ -0,0 +1,3 @@
+## 执行器
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/understand_mmcv/utils.md b/docs_zh_CN/understand_mmcv/utils.md
new file mode 100644
index 0000000000..7b8755a952
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/utils.md
@@ -0,0 +1,3 @@
+## 辅助函数
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/docs_zh_CN/understand_mmcv/visualization.md b/docs_zh_CN/understand_mmcv/visualization.md
new file mode 100644
index 0000000000..968631bc6e
--- /dev/null
+++ b/docs_zh_CN/understand_mmcv/visualization.md
@@ -0,0 +1,3 @@
+## 可视化
+
+欢迎有兴趣的朋友一起翻译 MMCV 文档。如有兴趣,请在 [MMCV issue](https://github.com/open-mmlab/mmcv/issues) 提 issue 确定翻译的文档。
diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py
index 71d2b69357..f7522fa784 100644
--- a/mmcv/cnn/__init__.py
+++ b/mmcv/cnn/__init__.py
@@ -15,25 +15,27 @@
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
- NormalInit, PretrainedInit, UniformInit, XavierInit,
- bias_init_with_prob, caffe2_xavier_init, constant_init,
- fuse_conv_bn, get_model_complexity_info, initialize,
- kaiming_init, normal_init, uniform_init, xavier_init)
+ NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+ XavierInit, bias_init_with_prob, caffe2_xavier_init,
+ constant_init, fuse_conv_bn, get_model_complexity_info,
+ initialize, kaiming_init, normal_init, trunc_normal_init,
+ uniform_init, xavier_init)
from .vgg import VGG, make_vgg_layer
__all__ = [
'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
- 'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
- 'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
- 'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
- 'build_padding_layer', 'build_upsample_layer', 'build_plugin_layer',
- 'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
- 'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
- 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
- 'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
- 'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
- 'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
- 'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
- 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+ 'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+ 'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+ 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+ 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+ 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+ 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+ 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+ 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+ 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
]
diff --git a/mmcv/cnn/bricks/__init__.py b/mmcv/cnn/bricks/__init__.py
index 7f9a99c714..78da6f39a1 100644
--- a/mmcv/cnn/bricks/__init__.py
+++ b/mmcv/cnn/bricks/__init__.py
@@ -5,6 +5,7 @@
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
@@ -29,5 +30,5 @@
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
- 'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
+ 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
]
diff --git a/mmcv/cnn/bricks/activation.py b/mmcv/cnn/bricks/activation.py
index f50241b192..89d54980e8 100644
--- a/mmcv/cnn/bricks/activation.py
+++ b/mmcv/cnn/bricks/activation.py
@@ -1,3 +1,5 @@
+from distutils.version import LooseVersion
+
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -70,7 +72,8 @@ def forward(self, input):
return F.gelu(input)
-if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.4':
+if (TORCH_VERSION == 'parrots'
+ or LooseVersion(TORCH_VERSION) < LooseVersion('1.4')):
ACTIVATION_LAYERS.register_module(module=GELU)
else:
ACTIVATION_LAYERS.register_module(module=nn.GELU)
diff --git a/mmcv/cnn/bricks/drop.py b/mmcv/cnn/bricks/drop.py
new file mode 100644
index 0000000000..dd380c2162
--- /dev/null
+++ b/mmcv/cnn/bricks/drop.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+
+from mmcv import build_from_cfg
+from .registry import DROPOUT_LAYERS
+
+
+def drop_path(x, drop_prob=0., training=False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ # handle tensors with different dimensions, not just 4D tensors.
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device)
+ output = x.div(keep_prob) * random_tensor.floor()
+ return output
+
+
+@DROPOUT_LAYERS.register_module()
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+
+ Args:
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+ """
+
+ def __init__(self, drop_prob=0.1):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+@DROPOUT_LAYERS.register_module()
+class Dropout(nn.Dropout):
+ """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
+ ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
+ ``DropPath``
+
+ Args:
+ drop_prob (float): Probability of the elements to be
+ zeroed. Default: 0.5.
+ inplace (bool): Do the operation inplace or not. Default: False.
+ """
+
+ def __init__(self, drop_prob=0.5, inplace=False):
+ super().__init__(p=drop_prob, inplace=inplace)
+
+
+def build_dropout(cfg, default_args=None):
+ """Builder for drop out layers."""
+ return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
diff --git a/mmcv/cnn/bricks/generalized_attention.py b/mmcv/cnn/bricks/generalized_attention.py
index 8a779bf07d..c6e4f00d35 100644
--- a/mmcv/cnn/bricks/generalized_attention.py
+++ b/mmcv/cnn/bricks/generalized_attention.py
@@ -170,18 +170,23 @@ def get_position_embedding(self,
q_stride,
kv_stride,
device,
+ dtype,
feat_dim,
wave_length=1000):
- h_idxs = torch.linspace(0, h - 1, h).to(device)
+ # the default type of Tensor is float32, leading to type mismatch
+ # in fp16 mode. Cast it to support fp16 mode.
+ h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
h_idxs = h_idxs.view((h, 1)) * q_stride
- w_idxs = torch.linspace(0, w - 1, w).to(device)
+ w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
w_idxs = w_idxs.view((w, 1)) * q_stride
- h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(device)
+ h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+ device=device, dtype=dtype)
h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
- w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(device)
+ w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+ device=device, dtype=dtype)
w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
# (h, h_kv, 1)
@@ -192,9 +197,10 @@ def get_position_embedding(self,
w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
w_diff *= self.position_magnitude
- feat_range = torch.arange(0, feat_dim / 4).to(device)
+ feat_range = torch.arange(0, feat_dim / 4).to(
+ device=device, dtype=dtype)
- dim_mat = torch.Tensor([wave_length]).to(device)
+ dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
dim_mat = dim_mat**((4. / feat_dim) * feat_range)
dim_mat = dim_mat.view((1, 1, -1))
@@ -234,7 +240,7 @@ def forward(self, x_input):
if self.attention_type[1] or self.attention_type[3]:
position_embed_x, position_embed_y = self.get_position_embedding(
h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
- x_input.device, self.position_embedding_dim)
+ x_input.device, x_input.dtype, self.position_embedding_dim)
# (n, num_heads, w, w_kv, dim)
position_feat_x = self.appr_geom_fc_x(position_embed_x).\
view(1, w, w_kv, num_heads, self.qk_embed_dim).\
diff --git a/mmcv/cnn/bricks/norm.py b/mmcv/cnn/bricks/norm.py
index 0035225853..88cd671f36 100644
--- a/mmcv/cnn/bricks/norm.py
+++ b/mmcv/cnn/bricks/norm.py
@@ -106,7 +106,7 @@ def build_norm_layer(cfg, num_features, postfix=''):
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
layer = norm_layer(num_features, **cfg_)
- if layer_type == 'SyncBN':
+ if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
layer._specify_ddp_gpu_num(1)
else:
assert 'num_groups' in cfg_
diff --git a/mmcv/cnn/bricks/registry.py b/mmcv/cnn/bricks/registry.py
index 12ced7ff6b..31c1ccc196 100644
--- a/mmcv/cnn/bricks/registry.py
+++ b/mmcv/cnn/bricks/registry.py
@@ -7,7 +7,9 @@
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')
-POSITIONAL_ENCODING = Registry('Position encoding')
-ATTENTION = Registry('Attention')
-TRANSFORMER_LAYER = Registry('TransformerLayer')
-TRANSFORMER_LAYER_SEQUENCE = Registry('TransformerLayerSequence')
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index fb064f302d..06715cde60 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -1,19 +1,32 @@
import copy
-import math
import warnings
import torch
import torch.nn as nn
-from mmcv import ConfigDict
-from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
- constant_init, xavier_init)
-from mmcv.ops.multi_scale_deform_attn import (
- MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
-from mmcv.runner.base_module import BaseModule
+from mmcv import ConfigDict, deprecated_api_warning
+from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import build_from_cfg
-from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER,
- TRANSFORMER_LAYER_SEQUENCE)
+from .drop import build_dropout
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+ TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+ from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
+ warnings.warn(
+ ImportWarning(
+ '``MultiScaleDeformableAttention`` has been moved to '
+ '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
+ '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
+ 'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
+ ))
+
+except ImportError:
+ warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
+ '``mmcv.ops.multi_scale_deform_attn``, '
+ 'You should install ``mmcv-full`` if you need this module. ')
def build_positional_encoding(cfg, default_args=None):
@@ -26,6 +39,11 @@ def build_attention(cfg, default_args=None):
return build_from_cfg(cfg, ATTENTION, default_args)
+def build_feedforward_network(cfg, default_args=None):
+ """Builder for feed-forward network (FFN)."""
+ return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
+
+
def build_transformer_layer(cfg, default_args=None):
"""Builder for transformer layer."""
return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
@@ -38,39 +56,84 @@ def build_transformer_layer_sequence(cfg, default_args=None):
@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
- """A warpper for torch.nn.MultiheadAttention.
+ """A wrapper for ``torch.nn.MultiheadAttention``.
- This module implements MultiheadAttention with residual connection,
- and positional encoding used in DETR is also passed as input.
+ This module implements MultiheadAttention with identity connection,
+ and positional encoding is also passed as input.
Args:
embed_dims (int): The embedding dimension.
- num_heads (int): Parallel attention heads. Same as
- `nn.MultiheadAttention`.
- dropout (float):w A Dropout layer on attn_output_weights. Default: 0..
+ num_heads (int): Parallel attention heads.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+ Default: 0.0.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
+ batch_first (bool): When it is True, Key, Query and Value are shape of
+ (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+ Default to False.
"""
def __init__(self,
embed_dims,
num_heads,
- dropout=0.,
+ attn_drop=0.,
+ proj_drop=0.,
+ dropout_layer=dict(type='Dropout', drop_prob=0.),
init_cfg=None,
+ batch_first=False,
**kwargs):
super(MultiheadAttention, self).__init__(init_cfg)
+ if 'dropout' in kwargs:
+ warnings.warn('The arguments `dropout` in MultiheadAttention '
+ 'has been deprecated, now you can separately '
+ 'set `attn_drop`(float), proj_drop(float), '
+ 'and `dropout_layer`(dict) ')
+ attn_drop = kwargs['dropout']
+ dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
self.embed_dims = embed_dims
self.num_heads = num_heads
- self.dropout = dropout
- self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout,
- **kwargs)
- self.dropout = nn.Dropout(dropout)
+ self.batch_first = batch_first
+ self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
+ **kwargs)
+ if self.batch_first:
+
+ def _bnc_to_nbc(forward):
+ """Because the dataflow('key', 'query', 'value') of
+ ``torch.nn.MultiheadAttention`` is (num_query, batch,
+ embed_dims), We should adjust the shape of dataflow from
+ batch_first (batch, num_query, embed_dims) to num_query_first
+ (num_query ,batch, embed_dims), and recover ``attn_output``
+ from num_query_first to batch_first."""
+
+ def forward_wrapper(**kwargs):
+ convert_keys = ('key', 'query', 'value')
+ for key in kwargs.keys():
+ if key in convert_keys:
+ kwargs[key] = kwargs[key].transpose(0, 1)
+ attn_output, attn_output_weights = forward(**kwargs)
+ return attn_output.transpose(0, 1), attn_output_weights
+
+ return forward_wrapper
+
+ self.attn.forward = _bnc_to_nbc(self.attn.forward)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else nn.Identity()
+
+ @deprecated_api_warning({'residual': 'identity'},
+ cls_name='MultiheadAttention')
def forward(self,
query,
key=None,
value=None,
- residual=None,
+ identity=None,
query_pos=None,
key_pos=None,
attn_mask=None,
@@ -83,15 +146,17 @@ def forward(self,
Args:
query (Tensor): The input query with shape [num_queries, bs,
- embed_dims]. Same in `nn.MultiheadAttention.forward`.
+ embed_dims] if self.batch_first is False, else
+ [bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
- embed_dims]. Same in `nn.MultiheadAttention.forward`.
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
If None, the ``query`` will be used. Defaults to None.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Defaults to None.
If None, the `key` will be used.
- residual (Tensor): This tensor, with the same shape as x,
- will be used for the residual link.
+ identity (Tensor): This tensor, with the same shape as x,
+ will be used for the identity link.
If None, `x` will be used. Defaults to None.
query_pos (Tensor): The positional encoding for query, with
the same shape as `x`. If not None, it will
@@ -105,18 +170,21 @@ def forward(self,
num_keys]. Same in `nn.MultiheadAttention.forward`.
Defaults to None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
- Same in `nn.MultiheadAttention.forward`. Defaults to None.
+ Defaults to None.
Returns:
- Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+ Tensor: forwarded results with shape
+ [num_queries, bs, embed_dims]
+ if self.batch_first is False, else
+ [bs, num_queries embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
- if residual is None:
- residual = query
+ if identity is None:
+ identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
@@ -129,238 +197,56 @@ def forward(self,
query = query + query_pos
if key_pos is not None:
key = key + key_pos
+
out = self.attn(
- query,
- key,
+ query=query,
+ key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]
- return residual + self.dropout(out)
-
-
-@ATTENTION.register_module()
-class MultiScaleDeformableAttention(BaseModule):
- """An attention module used in Deformable-Detr. `Deformable DETR:
- Deformable Transformers for End-to-End Object Detection.
-
- `_.
-
- Args:
- embed_dims (int): The embedding dimension of Attention.
- Default: 256.
- num_heads (int): Parallel attention heads. Default: 64.
- num_levels (int): The number of feature map used in
- Attention. Default: 4.
- num_points (int): The number of sampling points for
- each query in each head. Default: 4.
- im2col_step (int): The step used in image_to_column.
- Default: 64.
- dropout (float): A Dropout layer on `inp_residual`.
- Default: 0..
- init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
- Default: None.
- """
-
- def __init__(self,
- embed_dims=256,
- num_heads=8,
- num_levels=4,
- num_points=4,
- im2col_step=64,
- dropout=0.1,
- norm_cfg=None,
- init_cfg=None):
- super().__init__(init_cfg)
- if embed_dims % num_heads != 0:
- raise ValueError(f'embed_dims must be divisible by num_heads, '
- f'but got {embed_dims} and {num_heads}')
- dim_per_head = embed_dims // num_heads
- self.norm_cfg = norm_cfg
- self.init_cfg = init_cfg
- self.dropout = nn.Dropout(dropout)
-
- # you'd better set dim_per_head to a power of 2
- # which is more efficient in the CUDA implementation
- def _is_power_of_2(n):
- if (not isinstance(n, int)) or (n < 0):
- raise ValueError(
- 'invalid input for _is_power_of_2: {} (type: {})'.format(
- n, type(n)))
- return (n & (n - 1) == 0) and n != 0
-
- if not _is_power_of_2(dim_per_head):
- warnings.warn(
- "You'd better set embed_dims in "
- 'MultiScaleDeformAttention to make '
- 'the dimension of each attention head a power of 2 '
- 'which is more efficient in our CUDA implementation.')
-
- self.im2col_step = im2col_step
- self.embed_dims = embed_dims
- self.num_levels = num_levels
- self.num_heads = num_heads
- self.num_points = num_points
- self.sampling_offsets = nn.Linear(
- embed_dims, num_heads * num_levels * num_points * 2)
- self.attention_weights = nn.Linear(embed_dims,
- num_heads * num_levels * num_points)
- self.value_proj = nn.Linear(embed_dims, embed_dims)
- self.output_proj = nn.Linear(embed_dims, embed_dims)
- self.init_weight()
-
- def init_weight(self):
- """Default initialization for Parameters of Module."""
- constant_init(self.sampling_offsets, 0.)
- thetas = torch.arange(
- self.num_heads,
- dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
- grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
- grid_init = (grid_init /
- grid_init.abs().max(-1, keepdim=True)[0]).view(
- self.num_heads, 1, 1,
- 2).repeat(1, self.num_levels, self.num_points, 1)
- for i in range(self.num_points):
- grid_init[:, :, i, :] *= i + 1
-
- self.sampling_offsets.bias.data = grid_init.view(-1)
- constant_init(self.attention_weights, val=0., bias=0.)
- xavier_init(self.value_proj, distribution='uniform', bias=0.)
- xavier_init(self.output_proj, distribution='uniform', bias=0.)
-
- def forward(self,
- query,
- key,
- value,
- residual=None,
- query_pos=None,
- key_padding_mask=None,
- reference_points=None,
- spatial_shapes=None,
- level_start_index=None,
- **kwargs):
- """Forward Function of MultiScaleDeformAttention.
-
- Args:
- query (Tensor): Query of Transformer with shape
- (num_query, bs, embed_dims).
- key (Tensor): The key tensor with shape
- `(num_key, bs, embed_dims)`.
- value (Tensor): The value tensor with shape
- `(num_key, bs, embed_dims)`.
- residual (Tensor): The tensor used for addition, with the
- same shape as `x`. Default None. If None, `x` will be used.
- query_pos (Tensor): The positional encoding for `query`.
- Default: None.
- key_pos (Tensor): The positional encoding for `key`. Default
- None.
- reference_points (Tensor): The normalized reference
- points with shape (bs, num_query, num_levels, 2),
- all elements is range in [0, 1], top-left (0,0),
- bottom-right (1, 1), including padding area.
- or (N, Length_{query}, num_levels, 4), add
- additional two dimensions is (w, h) to
- form reference boxes.
- key_padding_mask (Tensor): ByteTensor for `query`, with
- shape [bs, num_key].
- spatial_shapes (Tensor): Spatial shape of features in
- different level. With shape (num_levels, 2),
- last dimension represent (h, w).
- level_start_index (Tensor): The start index of each level.
- A tensor has shape (num_levels) and can be represented
- as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
-
- Returns:
- Tensor: forwarded results with shape [num_query, bs, embed_dims].
- """
-
- if key is None:
- key = query
- if value is None:
- value = key
-
- if residual is None:
- inp_residual = query
- if query_pos is not None:
- query = query + query_pos
-
- # change to (bs, num_query ,embed_dims)
- query = query.permute(1, 0, 2)
- value = value.permute(1, 0, 2)
-
- bs, num_query, _ = query.shape
- bs, num_key, _ = value.shape
- assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_key
-
- value = self.value_proj(value)
- if key_padding_mask is not None:
- value = value.masked_fill(key_padding_mask[..., None], 0.0)
- value = value.view(bs, num_key, self.num_heads, -1)
- sampling_offsets = self.sampling_offsets(query).view(
- bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
- attention_weights = self.attention_weights(query).view(
- bs, num_query, self.num_heads, self.num_levels * self.num_points)
- attention_weights = attention_weights.softmax(-1)
-
- attention_weights = attention_weights.view(bs, num_query,
- self.num_heads,
- self.num_levels,
- self.num_points)
- if reference_points.shape[-1] == 2:
- offset_normalizer = torch.stack(
- [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
- sampling_locations = reference_points[:, :, None, :, None, :] \
- + sampling_offsets \
- / offset_normalizer[None, None, None, :, None, :]
- elif reference_points.shape[-1] == 4:
- sampling_locations = reference_points[:, :, None, :, None, :2] \
- + sampling_offsets / self.num_points \
- * reference_points[:, :, None, :, None, 2:] \
- * 0.5
- else:
- raise ValueError(
- f'Last dim of reference_points must be'
- f' 2 or 4, but get {reference_points.shape[-1]} instead.')
- if torch.cuda.is_available():
- output = MultiScaleDeformableAttnFunction.apply(
- value, spatial_shapes, level_start_index, sampling_locations,
- attention_weights, self.im2col_step)
- else:
- output = multi_scale_deformable_attn_pytorch(
- value, spatial_shapes, level_start_index, sampling_locations,
- attention_weights, self.im2col_step)
- output = self.output_proj(output).permute(1, 0, 2)
- # (num_query, bs ,embed_dims)
- return self.dropout(output) + inp_residual
+ return identity + self.dropout_layer(self.proj_drop(out))
+@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
- """Implements feed-forward networks (FFNs) with residual connection.
+ """Implements feed-forward networks (FFNs) with identity connection.
Args:
embed_dims (int): The feature dimension. Same as
- `MultiheadAttention`.
+ `MultiheadAttention`. Defaults: 256.
feedforward_channels (int): The hidden dimension of FFNs.
+ Defaults: 1024.
num_fcs (int, optional): The number of fully-connected layers in
FFNs. Default: 2.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
- dropout (float, optional): Probability of an element to be
- zeroed. Default 0..
- add_residual (bool, optional): Whether to add the
- residual connection. Default: `True`.
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ add_identity (bool, optional): Whether to add the
+ identity connection. Default: `True`.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
+ @deprecated_api_warning(
+ {
+ 'dropout': 'ffn_drop',
+ 'add_residual': 'add_identity'
+ },
+ cls_name='FFN')
def __init__(self,
- embed_dims,
- feedforward_channels,
+ embed_dims=256,
+ feedforward_channels=1024,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
- dropout=0.,
- add_residual=True,
- init_cfg=None):
+ ffn_drop=0.,
+ dropout_layer=None,
+ add_identity=True,
+ init_cfg=None,
+ **kwargs):
super(FFN, self).__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
@@ -368,33 +254,35 @@ def __init__(self,
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
- self.dropout = dropout
self.activate = build_activation_layer(act_cfg)
layers = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
- nn.Sequential(
+ Sequential(
Linear(in_channels, feedforward_channels), self.activate,
- nn.Dropout(dropout)))
+ nn.Dropout(ffn_drop)))
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
- self.layers = nn.Sequential(*layers)
- self.dropout = nn.Dropout(dropout)
- self.add_residual = add_residual
-
- def forward(self, x, residual=None):
+ layers.append(nn.Dropout(ffn_drop))
+ self.layers = Sequential(*layers)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else torch.nn.Identity()
+ self.add_identity = add_identity
+
+ @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+ def forward(self, x, identity=None):
"""Forward function for `FFN`.
The function would add x to the output tensor if residue is None.
"""
out = self.layers(x)
- if not self.add_residual:
- return self.dropout(out)
- if residual is None:
- residual = x
- return residual + self.dropout(out)
+ if not self.add_identity:
+ return self.dropout_layer(out)
+ if identity is None:
+ identity = x
+ return identity + self.dropout_layer(out)
@TRANSFORMER_LAYER.register_module()
@@ -416,85 +304,121 @@ class BaseTransformerLayer(BaseModule):
corresponding attentions in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config. Default: None.
- feedforward_channels (int): The hidden dimension for FFNs.
- Default: None.
- ffn_dropout (float): Probability of an element to be zeroed
- in ffn. Default 0..
+ ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+ Configs for FFN, The order of the configs in the list should be
+ consistent with corresponding ffn in operation_order.
+ If it is a dict, all of the attention modules in operation_order
+ will be built with this config.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Support `prenorm` when you specifying first element as `norm`.
Default:None.
- act_cfg (dict): The activation config for FFNs.
- Default: dict(type='ReLU')
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
- ffn_num_fcs (int): The number of fully-connected layers in FFNs.
- Default:2.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
+ batch_first (bool): Key, Query and Value are shape
+ of (batch, n, embed_dim)
+ or (n, batch, embed_dim). Default to False.
"""
def __init__(self,
attn_cfgs=None,
- feedforward_channels=None,
- ffn_dropout=0.,
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ),
operation_order=None,
- act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'),
- ffn_num_fcs=2,
- init_cfg=None):
+ init_cfg=None,
+ batch_first=False,
+ **kwargs):
+
+ deprecated_args = dict(
+ feedforward_channels='feedforward_channels',
+ ffn_dropout='ffn_drop',
+ ffn_num_fcs='num_fcs')
+ for ori_name, new_name in deprecated_args.items():
+ if ori_name in kwargs:
+ warnings.warn(
+ f'The arguments `{ori_name}` in BaseTransformerLayer '
+ f'has been deprecated, now you should set `{new_name}` '
+ f'and other FFN related arguments '
+ f'to a dict named `ffn_cfgs`. ')
+ ffn_cfgs[new_name] = kwargs[ori_name]
super(BaseTransformerLayer, self).__init__(init_cfg)
+
+ self.batch_first = batch_first
+
assert set(operation_order) & set(
['self_attn', 'norm', 'ffn', 'cross_attn']) == \
set(operation_order), f'The operation_order of' \
f' {self.__class__.__name__} should ' \
f'contains all four operation type ' \
f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
+
num_attn = operation_order.count('self_attn') + operation_order.count(
'cross_attn')
- if isinstance(attn_cfgs, ConfigDict):
+ if isinstance(attn_cfgs, dict):
attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
else:
assert num_attn == len(attn_cfgs), f'The length ' \
f'of attn_cfg {num_attn} is ' \
f'not consistent with the number of attention' \
f'in operation_order {operation_order}.'
- self.init_cfg = init_cfg
+
self.num_attn = num_attn
- self.feedforward_channels = feedforward_channels
- self.ffn_dropout = ffn_dropout
self.operation_order = operation_order
- self.act_cfg = act_cfg
self.norm_cfg = norm_cfg
- self.ffn_num_fcs = ffn_num_fcs
self.pre_norm = operation_order[0] == 'norm'
- self.attentions = nn.ModuleList()
+ self.attentions = ModuleList()
index = 0
- for operation in operation_order:
- if operation in ['self_attn', 'cross_attn']:
+ for operation_name in operation_order:
+ if operation_name in ['self_attn', 'cross_attn']:
+ if 'batch_first' in attn_cfgs[index]:
+ assert self.batch_first == attn_cfgs[index]['batch_first']
+ else:
+ attn_cfgs[index]['batch_first'] = self.batch_first
attention = build_attention(attn_cfgs[index])
+ # Some custom attentions used as `self_attn`
+ # or `cross_attn` can have different behavior.
+ attention.operation_name = operation_name
self.attentions.append(attention)
index += 1
self.embed_dims = self.attentions[0].embed_dims
- self.ffns = nn.ModuleList()
+
+ self.ffns = ModuleList()
num_ffns = operation_order.count('ffn')
- for _ in range(num_ffns):
+ if isinstance(ffn_cfgs, dict):
+ ffn_cfgs = ConfigDict(ffn_cfgs)
+ if isinstance(ffn_cfgs, dict):
+ ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+ assert len(ffn_cfgs) == num_ffns
+ for ffn_index in range(num_ffns):
+ if 'embed_dims' not in ffn_cfgs[ffn_index]:
+ ffn_cfgs['embed_dims'] = self.embed_dims
+ else:
+ assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
self.ffns.append(
- FFN(self.embed_dims, feedforward_channels, ffn_num_fcs,
- act_cfg, ffn_dropout))
+ build_feedforward_network(ffn_cfgs[ffn_index],
+ dict(type='FFN')))
- self.norms = nn.ModuleList()
+ self.norms = ModuleList()
num_norms = operation_order.count('norm')
for _ in range(num_norms):
self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
def forward(self,
query,
- key,
- value,
+ key=None,
+ value=None,
query_pos=None,
key_pos=None,
attn_masks=None,
@@ -506,12 +430,14 @@ def forward(self,
**kwargs contains some specific arguments of attentions.
Args:
- query (Tensor): Input query with the shape
- `(num_queries, bs, embed_dims)`.
- key (Tensor): The key tensor with shape
- `(num_keys, bs, embed_dims)`.
- value (Tensor): The value tensor with shape
- `(num_keys, bs, embed_dims)`.
+ query (Tensor): The input query with shape
+ [num_queries, bs, embed_dims] if
+ self.batch_first is False, else
+ [bs, num_queries embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
@@ -533,7 +459,7 @@ def forward(self,
norm_index = 0
attn_index = 0
ffn_index = 0
- inp_residual = query
+ identity = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
elif isinstance(attn_masks, torch.Tensor):
@@ -555,14 +481,14 @@ def forward(self,
query,
temp_key,
temp_value,
- inp_residual if self.pre_norm else None,
+ identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
**kwargs)
attn_index += 1
- inp_residual = query
+ identity = query
elif layer == 'norm':
query = self.norms[norm_index](query)
@@ -573,18 +499,18 @@ def forward(self,
query,
key,
value,
- inp_residual if self.pre_norm else None,
+ identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
**kwargs)
attn_index += 1
- inp_residual = query
+ identity = query
elif layer == 'ffn':
query = self.ffns[ffn_index](
- query, inp_residual if self.pre_norm else None)
+ query, identity if self.pre_norm else None)
ffn_index += 1
return query
@@ -612,7 +538,7 @@ class TransformerLayerSequence(BaseModule):
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
super(TransformerLayerSequence, self).__init__(init_cfg)
- if isinstance(transformerlayers, ConfigDict):
+ if isinstance(transformerlayers, dict):
transformerlayers = [
copy.deepcopy(transformerlayers) for _ in range(num_layers)
]
@@ -620,13 +546,11 @@ def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
assert isinstance(transformerlayers, list) and \
len(transformerlayers) == num_layers
self.num_layers = num_layers
- operation_order = transformerlayers[0]['operation_order']
- self.pre_norm = operation_order[0] == 'norm'
- self.layers = nn.ModuleList()
+ self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(build_transformer_layer(transformerlayers[i]))
self.embed_dims = self.layers[0].embed_dims
- self.pre_norm = self.layers[0].operation_order[0] == 'norm'
+ self.pre_norm = self.layers[0].pre_norm
def forward(self,
query,
@@ -661,7 +585,7 @@ def forward(self,
shape [bs, num_keys]. Default: None.
Returns:
- Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+ Tensor: results with shape [num_queries, bs, embed_dims].
"""
for layer in self.layers:
query = layer(
diff --git a/mmcv/cnn/bricks/wrappers.py b/mmcv/cnn/bricks/wrappers.py
index a464f86dc1..6e125b41ca 100644
--- a/mmcv/cnn/bricks/wrappers.py
+++ b/mmcv/cnn/bricks/wrappers.py
@@ -128,8 +128,8 @@ def forward(self, x):
class MaxPool2d(nn.MaxPool2d):
def forward(self, x):
- # PyTorch 1.7 does not support empty tensor inference yet
- if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride),
@@ -146,8 +146,8 @@ def forward(self, x):
class MaxPool3d(nn.MaxPool3d):
def forward(self, x):
- # PyTorch 1.7 does not support empty tensor inference yet
- if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
_triple(self.padding),
diff --git a/mmcv/cnn/utils/__init__.py b/mmcv/cnn/utils/__init__.py
index 18efa4135f..c8a4bd51f8 100644
--- a/mmcv/cnn/utils/__init__.py
+++ b/mmcv/cnn/utils/__init__.py
@@ -2,15 +2,17 @@
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
- KaimingInit, NormalInit, PretrainedInit, UniformInit,
- XavierInit, bias_init_with_prob, caffe2_xavier_init,
+ KaimingInit, NormalInit, PretrainedInit,
+ TruncNormalInit, UniformInit, XavierInit,
+ bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init,
- uniform_init, xavier_init)
+ trunc_normal_init, uniform_init, xavier_init)
__all__ = [
'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
- 'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
- 'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
- 'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
- 'PretrainedInit', 'Caffe2XavierInit'
+ 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+ 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit'
]
diff --git a/mmcv/cnn/utils/flops_counter.py b/mmcv/cnn/utils/flops_counter.py
index 27aec347a5..dceeb398bf 100644
--- a/mmcv/cnn/utils/flops_counter.py
+++ b/mmcv/cnn/utils/flops_counter.py
@@ -237,7 +237,7 @@ def print_model_with_flops(model,
>>> model = ExampleModel()
>>> x = (3, 16, 16)
- to print the complexity inforamtion state for each layer, you can use
+ to print the complexity information state for each layer, you can use
>>> get_model_complexity_info(model, x)
or directly use
>>> print_model_with_flops(model, 4579784.0, 37361)
diff --git a/mmcv/cnn/utils/weight_init.py b/mmcv/cnn/utils/weight_init.py
index 6de857e73f..36303a22c3 100644
--- a/mmcv/cnn/utils/weight_init.py
+++ b/mmcv/cnn/utils/weight_init.py
@@ -1,9 +1,12 @@
# Copyright (c) Open-MMLab. All rights reserved.
import copy
+import math
import warnings
import numpy as np
+import torch
import torch.nn as nn
+from torch import Tensor
from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
@@ -35,6 +38,18 @@ def normal_init(module, mean=0, std=1, bias=0):
nn.init.constant_(module.bias, bias)
+def trunc_normal_init(module: nn.Module,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ trunc_normal_(module.weight, mean, std, a, b) # type: ignore
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias) # type: ignore
+
+
def uniform_init(module, a=0, b=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.uniform_(module.weight, a, b)
@@ -78,12 +93,16 @@ def bias_init_with_prob(prior_prob):
return bias_init
+def _get_bases_name(m):
+ return [b.__name__ for b in m.__class__.__bases__]
+
+
class BaseInit(object):
def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.wholemodule = False
if not isinstance(bias, (int, float)):
- raise TypeError(f'bias must be a numbel, but got a {type(bias)}')
+ raise TypeError(f'bias must be a number, but got a {type(bias)}')
if bias_prob is not None:
if not isinstance(bias_prob, float):
@@ -96,9 +115,7 @@ def __init__(self, *, bias=0, bias_prob=None, layer=None):
but got a {type(layer)}')
else:
layer = []
- warnings.warn(
- 'init_cfg without layer key, if you do not define override'
- ' key either, this init_cfg will do nothing')
+
if bias_prob is not None:
self.bias = bias_init_with_prob(bias_prob)
else:
@@ -112,8 +129,7 @@ class ConstantInit(BaseInit):
Args:
val (int | float): the value to fill the weights in the module with
- bias (int | float): the value to fill the bias or
- define initialization type for bias. Defaults to 0.
+ bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
@@ -131,7 +147,8 @@ def init(m):
constant_init(m, self.val, self.bias)
else:
layername = m.__class__.__name__
- if layername in self.layer:
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
constant_init(m, self.val, self.bias)
module.apply(init)
@@ -146,8 +163,7 @@ class XavierInit(BaseInit):
Args:
gain (int | float): an optional scaling factor. Defaults to 1.
- bias (int | float): the value to fill the bias or define
- initialization type for bias. Defaults to 0.
+ bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'``
@@ -168,7 +184,8 @@ def init(m):
xavier_init(m, self.gain, self.bias, self.distribution)
else:
layername = m.__class__.__name__
- if layername in self.layer:
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
@@ -183,8 +200,7 @@ class NormalInit(BaseInit):
mean (int | float):the mean of the normal distribution. Defaults to 0.
std (int | float): the standard deviation of the normal distribution.
Defaults to 1.
- bias (int | float): the value to fill the bias or define
- initialization type for bias. Defaults to 0.
+ bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
@@ -204,9 +220,57 @@ def init(m):
normal_init(m, self.mean, self.std, self.bias)
else:
layername = m.__class__.__name__
- for layer_ in self.layer:
- if layername == layer_:
- normal_init(m, self.mean, self.std, self.bias)
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ normal_init(m, self.mean, self.std, self.bias)
+
+ module.apply(init)
+
+
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+ outside :math:`[a, b]`.
+
+ Args:
+ mean (float): the mean of the normal distribution. Defaults to 0.
+ std (float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ a (float): The minimum cutoff value.
+ b ( float): The maximum cutoff value.
+ bias (float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+ self.a = a
+ self.b = b
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
module.apply(init)
@@ -221,8 +285,7 @@ class UniformInit(BaseInit):
Defaults to 0.
b (int | float): the upper bound of the uniform distribution.
Defaults to 1.
- bias (int | float): the value to fill the bias or define
- initialization type for bias. Defaults to 0.
+ bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
@@ -241,7 +304,8 @@ def init(m):
uniform_init(m, self.a, self.b, self.bias)
else:
layername = m.__class__.__name__
- if layername in self.layer:
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
@@ -265,8 +329,7 @@ class KaimingInit(BaseInit):
nonlinearity (str): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
Defaults to 'relu'.
- bias (int | float): the value to fill the bias or define
- initialization type for bias. Defaults to 0.
+ bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'`` or
@@ -295,7 +358,8 @@ def init(m):
self.bias, self.distribution)
else:
layername = m.__class__.__name__
- if layername in self.layer:
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
@@ -468,3 +532,68 @@ def initialize(module, init_cfg):
else:
# All attributes in module have same initialization.
pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+ b: float) -> Tensor:
+ # Method based on
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ # Modified from
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower = norm_cdf((a - mean) / std)
+ upper = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [lower, upper], then translate
+ # to [2lower-1, 2upper-1].
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+ mean: float = 0.,
+ std: float = 1.,
+ a: float = -2.,
+ b: float = 2.) -> Tensor:
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Modified from
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+ Args:
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+ mean (float): the mean of the normal distribution.
+ std (float): the standard deviation of the normal distribution.
+ a (float): the minimum cutoff value.
+ b (float): the maximum cutoff value.
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py
index 556c4cfd71..5640029c17 100644
--- a/mmcv/fileio/parse.py
+++ b/mmcv/fileio/parse.py
@@ -1,5 +1,5 @@
# Copyright (c) Open-MMLab. All rights reserved.
-def list_from_file(filename, prefix='', offset=0, max_num=0):
+def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
"""Load a text file and parse the content as a list of strings.
Args:
@@ -8,19 +8,20 @@ def list_from_file(filename, prefix='', offset=0, max_num=0):
offset (int): The offset of lines.
max_num (int): The maximum number of lines to be read,
zeros and negatives mean no limitation.
+ encoding (str): Encoding used to open the file. Default utf-8.
Returns:
list[str]: A list of strings.
"""
cnt = 0
item_list = []
- with open(filename, 'r') as f:
+ with open(filename, 'r', encoding=encoding) as f:
for _ in range(offset):
f.readline()
for line in f:
- if max_num > 0 and cnt >= max_num:
+ if 0 < max_num <= cnt:
break
- item_list.append(prefix + line.rstrip('\n'))
+ item_list.append(prefix + line.rstrip('\n\r'))
cnt += 1
return item_list
@@ -28,13 +29,13 @@ def list_from_file(filename, prefix='', offset=0, max_num=0):
def dict_from_file(filename, key_type=str):
"""Load a text file and parse the content as a dict.
- Each line of the text file will be two or more columns splited by
+ Each line of the text file will be two or more columns split by
whitespaces or tabs. The first column will be parsed as dict keys, and
the following columns will be parsed as dict values.
Args:
filename(str): Filename.
- key_type(type): Type of the dict's keys. str is user by default and
+ key_type(type): Type of the dict keys. str is user by default and
type conversion will be performed if specified.
Returns:
diff --git a/mmcv/image/__init__.py b/mmcv/image/__init__.py
index 3f6a75f5cf..1a45f4e0c8 100644
--- a/mmcv/image/__init__.py
+++ b/mmcv/image/__init__.py
@@ -4,7 +4,8 @@
rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
from .geometric import (cutout, imcrop, imflip, imflip_, impad,
impad_to_multiple, imrescale, imresize, imresize_like,
- imrotate, imshear, imtranslate, rescale_size)
+ imresize_to_multiple, imrotate, imshear, imtranslate,
+ rescale_size)
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .misc import tensor2imgs
from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
@@ -16,12 +17,12 @@
__all__ = [
'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
- 'imresize', 'imresize_like', 'rescale_size', 'imcrop', 'imflip', 'imflip_',
- 'impad', 'impad_to_multiple', 'imrotate', 'imfrombytes', 'imread',
- 'imwrite', 'supported_backends', 'use_backend', 'imdenormalize',
- 'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize',
- 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', 'tensor2imgs',
- 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+ 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+ 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+ 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+ 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+ 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+ 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
]
diff --git a/mmcv/image/geometric.py b/mmcv/image/geometric.py
index 22c458745f..f81aa4599b 100644
--- a/mmcv/image/geometric.py
+++ b/mmcv/image/geometric.py
@@ -4,6 +4,7 @@
import cv2
import numpy as np
+from ..utils import to_2tuple
from .io import imread_backend
try:
@@ -17,13 +18,15 @@ def _scale_size(size, scale):
Args:
size (tuple[int]): (w, h).
- scale (float): Scaling factor.
+ scale (float | tuple(float)): Scaling factor.
Returns:
tuple[int]: scaled size.
"""
+ if isinstance(scale, (float, int)):
+ scale = (scale, scale)
w, h = size
- return int(w * float(scale) + 0.5), int(h * float(scale) + 0.5)
+ return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
cv2_interp_codes = {
@@ -92,6 +95,70 @@ def imresize(img,
return resized_img, w_scale, h_scale
+def imresize_to_multiple(img,
+ divisor,
+ size=None,
+ scale_factor=None,
+ keep_ratio=False,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+ """Resize image according to a given size or scale factor and then rounds
+ up the the resized or rescaled image size to the nearest value that can be
+ divided by the divisor.
+
+ Args:
+ img (ndarray): The input image.
+ divisor (int | tuple): Resized image size will be a multiple of
+ divisor. If divisor is a tuple, divisor should be
+ (w_divisor, h_divisor).
+ size (None | int | tuple[int]): Target size (w, h). Default: None.
+ scale_factor (None | float | tuple[float]): Multiplier for spatial
+ size. Should match input size if it is a tuple and the 2D style is
+ (w_scale_factor, h_scale_factor). Default: None.
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image. Default: False.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if size is not None and scale_factor is not None:
+ raise ValueError('only one of size or scale_factor should be defined')
+ elif size is None and scale_factor is None:
+ raise ValueError('one of size or scale_factor should be defined')
+ elif size is not None:
+ size = to_2tuple(size)
+ if keep_ratio:
+ size = rescale_size((w, h), size, return_scale=False)
+ else:
+ size = _scale_size((w, h), scale_factor)
+
+ divisor = to_2tuple(divisor)
+ size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
+ resized_img, w_scale, h_scale = imresize(
+ img,
+ size,
+ return_scale=True,
+ interpolation=interpolation,
+ out=out,
+ backend=backend)
+ if return_scale:
+ return resized_img, w_scale, h_scale
+ else:
+ return resized_img
+
+
def imresize_like(img,
dst_img,
return_scale=False,
@@ -528,7 +595,7 @@ def _get_shear_matrix(magnitude, direction='horizontal'):
Args:
magnitude (int | float): The magnitude used for shear.
- direction (str): Thie flip direction, either "horizontal"
+ direction (str): The flip direction, either "horizontal"
or "vertical".
Returns:
@@ -552,7 +619,7 @@ def imshear(img,
img (ndarray): Image to be sheared with format (h, w)
or (h, w, c).
magnitude (int | float): The magnitude used for shear.
- direction (str): Thie flip direction, either "horizontal"
+ direction (str): The flip direction, either "horizontal"
or "vertical".
border_value (int | tuple[int]): Value used in case of a
constant border.
diff --git a/mmcv/image/io.py b/mmcv/image/io.py
index 62fe266f3e..8c64e0eff6 100644
--- a/mmcv/image/io.py
+++ b/mmcv/image/io.py
@@ -5,7 +5,8 @@
import cv2
import numpy as np
-from cv2 import IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_UNCHANGED
+from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
+ IMREAD_UNCHANGED)
from mmcv.utils import check_file_exist, is_str, mkdir_or_exist
@@ -30,7 +31,10 @@
imread_flags = {
'color': IMREAD_COLOR,
'grayscale': IMREAD_GRAYSCALE,
- 'unchanged': IMREAD_UNCHANGED
+ 'unchanged': IMREAD_UNCHANGED,
+ 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
+ 'grayscale_ignore_orientation':
+ IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
}
imread_backend = 'cv2'
@@ -102,7 +106,8 @@ def _pillow2array(img, flag='color', channel_order='bgr'):
array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR
else:
# Handle exif orientation tag
- img = ImageOps.exif_transpose(img)
+ if flag in ['color', 'grayscale']:
+ img = ImageOps.exif_transpose(img)
# If the image mode is not 'RGB', convert it to 'RGB' first.
if img.mode != 'RGB':
if img.mode != 'LA':
@@ -117,17 +122,18 @@ def _pillow2array(img, flag='color', channel_order='bgr'):
img_rgba = img.convert('RGBA')
img = Image.new('RGB', img_rgba.size, (124, 117, 104))
img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha
- if flag == 'color':
+ if flag in ['color', 'color_ignore_orientation']:
array = np.array(img)
if channel_order != 'rgb':
array = array[:, :, ::-1] # RGB to BGR
- elif flag == 'grayscale':
+ elif flag in ['grayscale', 'grayscale_ignore_orientation']:
img = img.convert('L')
array = np.array(img)
else:
raise ValueError(
- 'flag must be "color", "grayscale" or "unchanged", '
- f'but got {flag}')
+ 'flag must be "color", "grayscale", "unchanged", '
+ f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+ f' but got {flag}')
return array
@@ -139,8 +145,13 @@ def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
pathlib.Path. If it is a numpy array (loaded image), then
it will be returned as is.
flag (str): Flags specifying the color type of a loaded image,
- candidates are `color`, `grayscale` and `unchanged`.
- Note that the `turbojpeg` backened does not support `unchanged`.
+ candidates are `color`, `grayscale`, `unchanged`,
+ `color_ignore_orientation` and `grayscale_ignore_orientation`.
+ By default, `cv2` and `pillow` backend would rotate the image
+ according to its EXIF info unless called with `unchanged` or
+ `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+ always ignore image's EXIF info regardless of the flag.
+ The `turbojpeg` backend only supports `color` and `grayscale`.
channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
backend (str | None): The image decoding backend type. Options are
`cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
@@ -234,7 +245,7 @@ def imwrite(img, file_path, params=None, auto_mkdir=True):
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
- params (None or list): Same as opencv's :func:`imwrite` interface.
+ params (None or list): Same as opencv :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
diff --git a/mmcv/image/photometric.py b/mmcv/image/photometric.py
index b81a15b344..c43c33dd99 100644
--- a/mmcv/image/photometric.py
+++ b/mmcv/image/photometric.py
@@ -119,7 +119,7 @@ def adjust_color(img, alpha=1, beta=None, gamma=0):
beta = 1 - alpha
colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
if not colored_img.dtype == np.uint8:
- # Note when the dtype of `img` is not defaultly `np.uint8`
+ # Note when the dtype of `img` is not the default `np.uint8`
# (e.g. np.float32), the value in `colored_img` got from cv2
# is not guaranteed to be in range [0, 255], so here clip
# is needed.
@@ -320,9 +320,9 @@ def adjust_sharpness(img, factor=1., kernel=None):
# adopted from PIL.ImageFilter.SMOOTH
kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
assert isinstance(kernel, np.ndarray), \
- f'kernel must be of type np.ndarrray, but got {type(kernel)} instead.'
+ f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
assert kernel.ndim == 2, \
- f'kernel must have a dimention of 2, but got {kernel.ndim} instead.'
+ f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
degenerated = cv2.filter2D(img, -1, kernel)
sharpened_img = cv2.addWeighted(
@@ -340,13 +340,13 @@ def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
`_.
Args:
- img (ndarray): Image to be ajusted lighting. BGR order.
+ img (ndarray): Image to be adjusted lighting. BGR order.
eigval (ndarray): the eigenvalue of the convariance matrix of pixel
values, respectively.
eigvec (ndarray): the eigenvector of the convariance matrix of pixel
values, respectively.
alphastd (float): The standard deviation for distribution of alpha.
- Dafaults to 0.1
+ Defaults to 0.1
to_rgb (bool): Whether to convert img to rgb.
Returns:
diff --git a/mmcv/model_zoo/mmcls.json b/mmcv/model_zoo/mmcls.json
index ce9852d447..51a2a07198 100644
--- a/mmcv/model_zoo/mmcls.json
+++ b/mmcv/model_zoo/mmcls.json
@@ -1,12 +1,12 @@
{
- "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_imagenet-01ecd97e.pth",
- "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_imagenet-9ad3945d.pth",
- "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_imagenet-91b6d117.pth",
- "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_imagenet-fee352a8.pth",
- "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_imagenet-6fbbbf3f.pth",
- "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_imagenet-4b5f9390.pth",
- "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_imagenet-3ac6d8fd.pth",
- "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_imagenet-7c058385.pth",
+ "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
+ "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
+ "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
+ "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
+ "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
+ "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
+ "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
+ "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
"resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth",
"resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth",
"resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth",
@@ -15,10 +15,10 @@
"resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_batch256_imagenet_20200708-1ad0ce94.pth",
"resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_batch256_imagenet_20200708-9cb302ef.pth",
"resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_batch256_imagenet_20200708-e79cb6a2.pth",
- "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_batch256_imagenet_20200708-c07adbb7.pth",
- "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_batch256_imagenet_20200708-87f2d1c9.pth",
- "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_batch256_imagenet_20200708-1ec34aa7.pth",
- "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_batch256_imagenet_20200708-aab5034c.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
+ "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
+ "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
"se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
"se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
"resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
diff --git a/mmcv/model_zoo/open_mmlab.json b/mmcv/model_zoo/open_mmlab.json
index 44c24f6bfe..8311db4fee 100644
--- a/mmcv/model_zoo/open_mmlab.json
+++ b/mmcv/model_zoo/open_mmlab.json
@@ -45,5 +45,6 @@
"resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
"resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
"resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
- "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth"
+ "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
+ "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
}
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index ed10286144..ac9987b160 100644
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -1,4 +1,5 @@
from .bbox import bbox_overlaps
+from .border_align import BorderAlign, border_align
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
@@ -20,6 +21,7 @@
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
@@ -48,5 +50,6 @@
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
- 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand'
+ 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
+ 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align'
]
diff --git a/mmcv/ops/bbox.py b/mmcv/ops/bbox.py
index 06bd10e24d..855009ad14 100644
--- a/mmcv/ops/bbox.py
+++ b/mmcv/ops/bbox.py
@@ -49,7 +49,7 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
mode_dict = {'iou': 0, 'iof': 1}
assert mode in mode_dict.keys()
mode_flag = mode_dict[mode]
- # Either the boxes are empty or the length of boxes's last dimenstion is 4
+ # Either the boxes are empty or the length of boxes' last dimension is 4
assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
assert offset == 1 or offset == 0
diff --git a/mmcv/ops/border_align.py b/mmcv/ops/border_align.py
new file mode 100644
index 0000000000..e111d69550
--- /dev/null
+++ b/mmcv/ops/border_align.py
@@ -0,0 +1,108 @@
+# modified from
+# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['border_align_forward', 'border_align_backward'])
+
+
+class BorderAlignFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, boxes, pool_size):
+ return g.op(
+ 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
+
+ @staticmethod
+ def forward(ctx, input, boxes, pool_size):
+ ctx.pool_size = pool_size
+ ctx.input_shape = input.size()
+
+ assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
+ assert boxes.size(2) == 4, \
+ 'the last dimension of boxes must be (x1, y1, x2, y2)'
+ assert input.size(1) % 4 == 0, \
+ 'the channel for input feature must be divisible by factor 4'
+
+ # [B, C//4, H*W, 4]
+ output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
+ output = input.new_zeros(output_shape)
+ # `argmax_idx` only used for backward
+ argmax_idx = input.new_zeros(output_shape).to(torch.int)
+
+ ext_module.border_align_forward(
+ input, boxes, output, argmax_idx, pool_size=ctx.pool_size)
+
+ ctx.save_for_backward(boxes, argmax_idx)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ boxes, argmax_idx = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+ # complex head architecture may cause grad_output uncontiguous
+ grad_output = grad_output.contiguous()
+ ext_module.border_align_backward(
+ grad_output,
+ boxes,
+ argmax_idx,
+ grad_input,
+ pool_size=ctx.pool_size)
+ return grad_input, None, None
+
+
+border_align = BorderAlignFunction.apply
+
+
+class BorderAlign(nn.Module):
+ r"""Border align pooling layer.
+
+ Applies border_align over the input feature based on predicted bboxes.
+ The details were described in the paper
+ `BorderDet: Border Feature for Dense Object Detection
+ `_.
+
+ For each border line (e.g. top, left, bottom or right) of each box,
+ border_align does the following:
+ 1. uniformly samples `pool_size`+1 positions on this line, involving \
+ the start and end points.
+ 2. the corresponding features on these points are computed by \
+ bilinear interpolation.
+ 3. max pooling over all the `pool_size`+1 positions are used for \
+ computing pooled feature.
+
+ Args:
+ pool_size (int): number of positions sampled over the boxes' borders
+ (e.g. top, bottom, left, right).
+
+ """
+
+ def __init__(self, pool_size):
+ super(BorderAlign, self).__init__()
+ self.pool_size = pool_size
+
+ def forward(self, input, boxes):
+ """
+ Args:
+ input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
+ [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
+ right features respectively.
+ boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).
+
+ Returns:
+ Tensor: Pooled features with shape [N,C,H*W,4]. The order is
+ (top,left,bottom,right) for the last dimension.
+ """
+ return border_align(input, boxes, self.pool_size)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(pool_size={self.pool_size})'
+ return s
diff --git a/mmcv/ops/csrc/border_align_cuda_kernel.cuh b/mmcv/ops/csrc/border_align_cuda_kernel.cuh
new file mode 100644
index 0000000000..143dce5ddc
--- /dev/null
+++ b/mmcv/ops/csrc/border_align_cuda_kernel.cuh
@@ -0,0 +1,199 @@
+// modified from
+// https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/csrc/border_align/border_align_kernel.cu.
+// the main difference: (1) use `argmax_idx` for fast computing of gradient
+// during the backward. (2) `wh` is directly computed by `boxes`, rather than
+// passing it as argument to forward or backward functions.
+
+#ifndef BORDER_ALIGN_CUDA_KERNEL_CUH
+#define BORDER_ALIGN_CUDA_KERNEL_CUH
+
+#include
+#ifdef MMCV_WITH_TRT
+#include "common_cuda_helper.hpp"
+#else // MMCV_WITH_TRT
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else // MMCV_USE_PARROTS
+#include "pytorch_cuda_helper.hpp"
+#endif // MMCV_USE_PARROTS
+#endif // MMCV_WITH_TRT
+
+enum BorderMode { Top = 0, Left = 1, Bottom = 2, Right = 3 };
+
+/*** Forward ***/
+template
+__global__ void border_align_forward_cuda_kernel(
+ const int nthreads, const T* input, const T* boxes, T* output,
+ int* argmax_idx, const int channels, const int box_size, const int height,
+ const int width, const int pool_size) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (batch_idx, c_idx, box_idx) is an element paralleled for computing
+ // output, and `extreme_idx` is in range [0,3]
+ int batch_idx, c_idx, box_idx, extreme_idx, maxidx, *offset_argmax_idx;
+ const T *offset_box, *offset_input, *offset_box_x;
+ T *offset_output, box_width, box_height, stride, x_stride, y_stride, x, y,
+ val, maxval;
+
+ extreme_idx = threadIdx.y;
+ // shape (N, C, box_size, 4) for output
+ batch_idx = index / channels / box_size;
+ // shape (N, box_size, 4) for boxes
+ box_idx = index % box_size + batch_idx * box_size;
+ c_idx = (index / box_size) % channels;
+
+ offset_box = boxes + box_idx * 4;
+ box_width = *(offset_box + 2) - *offset_box;
+ box_height = *(offset_box + 3) - *(offset_box + 1);
+ offset_output = output + index * 4 + extreme_idx;
+ offset_argmax_idx = argmax_idx + index * 4 + extreme_idx;
+ // shape (N, 4C, h, w) for input.
+ // [0,C) for top feature, [C,2C) for left feature,
+ // [2C,3C) for bottom feature, [3C,4C) for right feature
+ offset_input =
+ input + (batch_idx * channels * 4 + extreme_idx * channels + c_idx) *
+ height * width;
+
+ // extreme_idx in [0,1] -> offset_box_x indexed at x1
+ // extreme_idx in [2,3] -> offset_box_x indexed at x2
+ offset_box_x = offset_box + extreme_idx / 2 * 2;
+
+ // (x1,y1) or (x2,y2) for (x,y)
+ x = *offset_box_x;
+ y = *(offset_box_x + 1);
+
+ switch (extreme_idx) {
+ // top
+ case BorderMode::Top:
+ stride = box_width / pool_size;
+ x_stride = stride;
+ y_stride = 0;
+ break;
+ // left
+ case BorderMode::Left:
+ stride = box_height / pool_size;
+ x_stride = 0;
+ y_stride = stride;
+ break;
+ // bottom
+ case BorderMode::Bottom:
+ stride = box_width / pool_size;
+ x_stride = -stride;
+ y_stride = 0;
+ break;
+ // right
+ case BorderMode::Right:
+ stride = box_height / pool_size;
+ x_stride = 0;
+ y_stride = -stride;
+ break;
+ }
+
+ // initialize maxval and maxidx with the start position (e.g. (x1,y1) or
+ // (x2,y2))
+ maxval = bilinear_interpolate(offset_input, height, width, y, x, index);
+ maxidx = 0;
+
+ // do max_pool along the border
+ for (int i = 1; i <= pool_size; i++) {
+ x += x_stride;
+ y += y_stride;
+ val = bilinear_interpolate(offset_input, height, width, y, x, index);
+ if (val > maxval) {
+ maxval = val;
+ maxidx = i;
+ }
+ }
+
+ // update output and argmax_idx
+ *offset_output = maxval;
+ *offset_argmax_idx = maxidx;
+ }
+}
+
+/*** Backward ***/
+template
+__global__ void border_align_backward_cuda_kernel(
+ const int nthreads, const T* grad_output, const T* boxes,
+ const int* argmax_idx, T* grad_input, const int channels,
+ const int box_size, const int height, const int width,
+ const int pool_size) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (batch_idx, c_idx, box_idx) is an element paralleled for computing
+ // output, and `extreme_idx` is in range [0,3]
+ int batch_idx, c_idx, box_idx, extreme_idx;
+ const int* offset_argmax_idx;
+ const T *offset_grad_output, *offset_box, *offset_box_x;
+ T *offset_grad_input, box_width, box_height, stride, x_stride, y_stride, x,
+ y;
+
+ extreme_idx = threadIdx.y;
+ batch_idx = index / channels / box_size;
+ box_idx = index % box_size + batch_idx * box_size;
+ c_idx = (index / box_size) % channels;
+
+ offset_box = boxes + box_idx * 4;
+ box_width = *(offset_box + 2) - *offset_box;
+ box_height = *(offset_box + 3) - *(offset_box + 1);
+ offset_grad_output = grad_output + index * 4 + extreme_idx;
+ offset_argmax_idx = argmax_idx + index * 4 + extreme_idx;
+ // [0,C) for top feature grad, [C,2C) for left feature grad,
+ // [2C,3C) for bottom feature grad, [3C,4C) for right feature grad
+ offset_grad_input = grad_input + (batch_idx * channels * 4 +
+ extreme_idx * channels + c_idx) *
+ height * width;
+
+ // extreme_idx in [0,1] -> offset_box_x indexed at x1
+ // extreme_idx in [2,3] -> offset_box_x indexed at x2
+ offset_box_x = offset_box + extreme_idx / 2 * 2;
+
+ switch (extreme_idx) {
+ // top
+ case BorderMode::Top:
+ stride = box_width / pool_size;
+ x_stride = stride;
+ y_stride = 0;
+ break;
+ // left
+ case BorderMode::Left:
+ stride = box_height / pool_size;
+ x_stride = 0;
+ y_stride = stride;
+ break;
+ // bottom
+ case BorderMode::Bottom:
+ stride = box_width / pool_size;
+ x_stride = -stride;
+ y_stride = 0;
+ break;
+ // right
+ case BorderMode::Right:
+ stride = box_height / pool_size;
+ x_stride = 0;
+ y_stride = -stride;
+ break;
+ }
+
+ // get position (x,y) which has maximum value during forward
+ x = *offset_box_x;
+ y = *(offset_box_x + 1);
+ x += x_stride * (T)(*offset_argmax_idx);
+ y += y_stride * (T)(*offset_argmax_idx);
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+ bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, x_low,
+ x_high, y_low, y_high, index);
+
+ // update grad_output
+ atomicAdd(offset_grad_input + y_low * width + x_low,
+ *offset_grad_output * w1);
+ atomicAdd(offset_grad_input + y_low * width + x_high,
+ *offset_grad_output * w2);
+ atomicAdd(offset_grad_input + y_high * width + x_low,
+ *offset_grad_output * w3);
+ atomicAdd(offset_grad_input + y_high * width + x_high,
+ *offset_grad_output * w4);
+ }
+}
+
+#endif // BORDER_ALIGN_CUDA_KERNEL_CUH
diff --git a/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh b/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
index 0dd9c33c66..15e07d1970 100644
--- a/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
@@ -14,25 +14,17 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
- int z = blockIdx.z;
-
- if (x < width && y < height && z < height + width - 1) {
- for (int batch = 0; batch < num; ++batch) {
- for (int plane = 0; plane < chn; ++plane) {
- T _t = t[(batch * chn + plane) * sp + y * width + x];
-
- if (z < width) {
- int i = z;
- T _f = f[(batch * chn + plane) * sp + y * width + i];
- weight[(batch * len + i) * sp + y * width + x] += _t * _f;
- } else {
- int i = z - width;
- int j = i < y ? i : i + 1;
-
- T _f = f[(batch * chn + plane) * sp + j * width + x];
- weight[(batch * len + width + i) * sp + y * width + x] += _t * _f;
- }
- }
+ int z = blockIdx.z % len;
+ int batch = blockIdx.z / len;
+
+ if (x < width && y < height) {
+ T *weight_ptr = weight + (batch * len + z) * sp + y * width + x;
+ const int t_offset = y * width + x;
+ const int j = (z - width < y) ? z - width : z - width + 1;
+ const int f_offset = z < width ? y * width + z : j * width + x;
+ for (int plane = 0; plane < chn; ++plane) {
+ const int tf_base = (batch * chn + plane) * sp;
+ *weight_ptr += t[tf_base + t_offset] * f[tf_base + f_offset];
}
}
}
@@ -44,23 +36,22 @@ __global__ void ca_backward_kernel_t(const T *dw, const T *t, const T *f, T *dt,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
- int plane = blockIdx.z;
-
- if (x < width && y < height && plane < chn) {
- for (int batch = 0; batch < num; ++batch) {
- for (int i = 0; i < width; ++i) {
- T _dw = dw[(batch * len + i) * sp + y * width + x];
- T _f = f[(batch * chn + plane) * sp + y * width + i];
- dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
- }
- for (int i = 0; i < height; ++i) {
- if (i == y) continue;
- int j = i < y ? i : i - 1;
-
- T _dw = dw[(batch * len + width + j) * sp + y * width + x];
- T _f = f[(batch * chn + plane) * sp + i * width + x];
- dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
- }
+ int plane = blockIdx.z % chn;
+ int batch = blockIdx.z / chn;
+
+ if (x < width && y < height) {
+ for (int i = 0; i < width; ++i) {
+ T _dw = dw[(batch * len + i) * sp + y * width + x];
+ T _f = f[(batch * chn + plane) * sp + y * width + i];
+ dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
+ }
+ for (int i = 0; i < height; ++i) {
+ if (i == y) continue;
+ int j = i < y ? i : i - 1;
+
+ T _dw = dw[(batch * len + width + j) * sp + y * width + x];
+ T _f = f[(batch * chn + plane) * sp + i * width + x];
+ dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
}
}
@@ -72,23 +63,22 @@ __global__ void ca_backward_kernel_f(const T *dw, const T *t, const T *f, T *df,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
- int plane = blockIdx.z;
-
- if (x < width && y < height && plane < chn) {
- for (int batch = 0; batch < num; ++batch) {
- for (int i = 0; i < width; ++i) {
- T _dw = dw[(batch * len + x) * sp + y * width + i];
- T _t = t[(batch * chn + plane) * sp + y * width + i];
- df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
- }
- for (int i = 0; i < height; ++i) {
- if (i == y) continue;
- int j = i > y ? y : y - 1;
-
- T _dw = dw[(batch * len + width + j) * sp + i * width + x];
- T _t = t[(batch * chn + plane) * sp + i * width + x];
- df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
- }
+ int plane = blockIdx.z % chn;
+ int batch = blockIdx.z / chn;
+
+ if (x < width && y < height) {
+ for (int i = 0; i < width; ++i) {
+ T _dw = dw[(batch * len + x) * sp + y * width + i];
+ T _t = t[(batch * chn + plane) * sp + y * width + i];
+ df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
+ }
+ for (int i = 0; i < height; ++i) {
+ if (i == y) continue;
+ int j = i > y ? y : y - 1;
+
+ T _dw = dw[(batch * len + width + j) * sp + i * width + x];
+ T _t = t[(batch * chn + plane) * sp + i * width + x];
+ df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
}
}
@@ -100,24 +90,22 @@ __global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
- int plane = blockIdx.z;
-
- if (x < width && y < height && plane < chn) {
- for (int batch = 0; batch < num; ++batch) {
- for (int i = 0; i < width; ++i) {
- T _g = g[(batch * chn + plane) * sp + y * width + i];
- T _w = weight[(batch * len + i) * sp + y * width + x];
- out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
- }
- for (int i = 0; i < height; ++i) {
- if (i == y) continue;
-
- int j = i < y ? i : i - 1;
-
- T _g = g[(batch * chn + plane) * sp + i * width + x];
- T _w = weight[(batch * len + width + j) * sp + y * width + x];
- out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
- }
+ int plane = blockIdx.z % chn;
+ int batch = blockIdx.z / chn;
+ if (x < width && y < height) {
+ for (int i = 0; i < width; ++i) {
+ T _g = g[(batch * chn + plane) * sp + y * width + i];
+ T _w = weight[(batch * len + i) * sp + y * width + x];
+ out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
+ }
+ for (int i = 0; i < height; ++i) {
+ if (i == y) continue;
+
+ int j = i < y ? i : i - 1;
+
+ T _g = g[(batch * chn + plane) * sp + i * width + x];
+ T _w = weight[(batch * len + width + j) * sp + y * width + x];
+ out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
}
}
@@ -130,25 +118,23 @@ __global__ void ca_map_backward_kernel_w(const T *dout, const T *weight,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
- int z = blockIdx.z;
-
- if (x < width && y < height && z < height + width - 1) {
- for (int batch = 0; batch < num; ++batch) {
- for (int plane = 0; plane < chn; ++plane) {
- T _dout = dout[(batch * chn + plane) * sp + y * width + x];
-
- if (z < width) {
- int i = z;
- T _g = g[(batch * chn + plane) * sp + y * width + i];
- dw[(batch * len + i) * sp + y * width + x] += _dout * _g;
- } else {
- int i = z - width;
- int j = i < y ? i : i + 1;
-
- T _g = g[(batch * chn + plane) * sp + j * width + x];
- dw[(batch * len + width + i) * sp + y * width + x] += _dout * _g;
- }
- }
+
+ int z = blockIdx.z % len;
+ int batch = blockIdx.z / len;
+
+ if (x < width && y < height) {
+ int widx = (batch * len + z) * sp + y * width + x;
+ int dout_idx = batch * chn * sp + y * width + x;
+ int gidx = batch * chn * sp;
+ if (z < width) {
+ gidx += y * width + z;
+ } else {
+ int j = z - width;
+ j = j < y ? j : j + 1;
+ gidx += j * width + x;
+ }
+ for (int plane = 0; plane < chn; plane++) {
+ dw[widx] += dout[dout_idx + plane * sp] * g[gidx + plane * sp];
}
}
}
@@ -161,25 +147,21 @@ __global__ void ca_map_backward_kernel_g(const T *dout, const T *weight,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
- int plane = blockIdx.z;
-
- if (x < width && y < height && plane < chn) {
- for (int batch = 0; batch < num; ++batch) {
- for (int i = 0; i < width; ++i) {
- T _dout = dout[(batch * chn + plane) * sp + y * width + i];
- T _w = weight[(batch * len + x) * sp + y * width + i];
- dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
- }
- for (int i = 0; i < height; ++i) {
- if (i == y) continue;
- int j = i > y ? y : y - 1;
-
- T _dout = dout[(batch * chn + plane) * sp + i * width + x];
- T _w = weight[(batch * len + width + j) * sp + i * width + x];
- dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
- }
+ int plane = blockIdx.z % chn;
+ int batch = blockIdx.z / chn;
+ int index = (batch * chn + plane) * sp + y * width + x;
+
+ if (x < width && y < height) {
+ for (int i = 0; i < width; ++i) {
+ dg[index] += dout[(batch * chn + plane) * sp + y * width + i] *
+ weight[(batch * len + x) * sp + y * width + i];
+ }
+ for (int i = 0; i < height; ++i) {
+ if (i == y) continue;
+ int j = i > y ? y : y - 1;
+ dg[index] += dout[(batch * chn + plane) * sp + i * width + x] *
+ weight[(batch * len + width + j) * sp + i * width + x];
}
}
}
-
#endif // CC_ATTENTION_CUDA_KERNEL_CUH
diff --git a/mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh b/mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh
index 04bf5c308d..ca0e91a252 100644
--- a/mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh
@@ -66,11 +66,16 @@
#ifndef MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
#define MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
+#include
+#ifdef MMCV_WITH_TRT
+#include "common_cuda_helper.hpp"
+#else // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
-#else
+#else // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
-#endif
+#endif // MMCV_USE_PARROTS
+#endif // MMCV_WITH_TRT
template
__device__ T dmcn_im2col_bilinear(const T *input, const int data_width,
diff --git a/mmcv/ops/csrc/parrots/bbox_overlaps_parrots.cpp b/mmcv/ops/csrc/parrots/bbox_overlaps_parrots.cpp
index 3c678a818d..35bb5f5c87 100644
--- a/mmcv/ops/csrc/parrots/bbox_overlaps_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/bbox_overlaps_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
/*
* void bbox_overlaps_cuda(const Tensor bboxes1, const Tensor bboxes2, Tensor
* ious, const int mode, const bool aligned, const int offset);
@@ -35,3 +36,4 @@ PARROTS_EXTENSION_REGISTER(bbox_overlaps)
.output(1)
.apply(bbox_overlaps_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/border_align.cpp b/mmcv/ops/csrc/parrots/border_align.cpp
new file mode 100644
index 0000000000..78351e2a5f
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/border_align.cpp
@@ -0,0 +1,67 @@
+#include "pytorch_cpp_helper.hpp"
+
+#ifdef MMCV_WITH_CUDA
+void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
+ const Tensor &boxes, Tensor output,
+ Tensor argmax_idx,
+ const int pool_size);
+
+void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
+ const Tensor &boxes,
+ const Tensor &argmax_idx,
+ Tensor grad_input,
+ const int pool_size);
+
+void border_align_forward_cuda(const Tensor &input, const Tensor &boxes,
+ Tensor output, Tensor argmax_idx,
+ const int pool_size) {
+ BorderAlignForwardCUDAKernelLauncher(input, boxes, output, argmax_idx,
+ pool_size);
+}
+
+void border_align_backward_cuda(const Tensor &grad_output, const Tensor &boxes,
+ const Tensor &argmax_idx, Tensor grad_input,
+ const int pool_size) {
+ BorderAlignBackwardCUDAKernelLauncher(grad_output, boxes, argmax_idx,
+ grad_input, pool_size);
+}
+#endif
+
+void border_align_forward(const Tensor &input, const Tensor &boxes,
+ Tensor output, Tensor argmax_idx,
+ const int pool_size) {
+ if (input.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+ CHECK_CUDA_INPUT(input);
+ CHECK_CUDA_INPUT(boxes);
+ CHECK_CUDA_INPUT(output);
+ CHECK_CUDA_INPUT(argmax_idx);
+
+ border_align_forward_cuda(input, boxes, output, argmax_idx, pool_size);
+#else
+ AT_ERROR("BorderAlign is not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("BorderAlign is not implemented on CPU");
+ }
+}
+
+void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
+ const Tensor &argmax_idx, Tensor grad_input,
+ const int pool_size) {
+ if (grad_output.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+ CHECK_CUDA_INPUT(grad_output);
+ CHECK_CUDA_INPUT(boxes);
+ CHECK_CUDA_INPUT(argmax_idx);
+ CHECK_CUDA_INPUT(grad_input);
+
+ border_align_backward_cuda(grad_output, boxes, argmax_idx, grad_input,
+ pool_size);
+#else
+ AT_ERROR("BorderAlign is not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("BorderAlign is not implemented on CPU");
+ }
+}
diff --git a/mmcv/ops/csrc/parrots/border_align_cuda.cu b/mmcv/ops/csrc/parrots/border_align_cuda.cu
new file mode 100644
index 0000000000..06ba452f65
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/border_align_cuda.cu
@@ -0,0 +1,67 @@
+#include "border_align_cuda_kernel.cuh"
+#include "pytorch_cuda_helper.hpp"
+
+void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
+ const Tensor &boxes, Tensor output,
+ Tensor argmax_idx,
+ const int pool_size) {
+ // shape assertion
+ AT_ASSERTM(input.ndimension() == 4,
+ "non-empty 4D(batch mode) tensor expected for input feature");
+ AT_ASSERTM(boxes.ndimension() == 3,
+ "boxes must be 3D tensor with size of [B, H*W, 4]");
+
+ int batch_size = input.size(0);
+ int feat_channels = input.size(1);
+ int channels = feat_channels / 4;
+ int height = input.size(2);
+ int width = input.size(3);
+ // shape [N, box_size, 4] for boxes. (x1, y1, x2, y2) format
+ int box_size = boxes.size(1);
+ // shape [N, channels, box_size, 4] for output
+ int nthreads = batch_size * channels * box_size;
+
+ at::cuda::CUDAGuard device_guard(input.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ dim3 block(128, 4);
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ input.scalar_type(), "border_align_forward_cuda_kernel", [&] {
+ border_align_forward_cuda_kernel
+ <<>>(
+ nthreads, input.data_ptr(),
+ boxes.data_ptr(), output.data_ptr(),
+ argmax_idx.data_ptr(), channels, box_size, height, width,
+ pool_size);
+ });
+
+ AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
+ const Tensor &boxes,
+ const Tensor &argmax_idx,
+ Tensor grad_input,
+ const int pool_size) {
+ int batch_size = grad_input.size(0);
+ int feat_channels = grad_input.size(1);
+ int channels = feat_channels / 4;
+ int height = grad_input.size(2);
+ int width = grad_input.size(3);
+ int box_size = boxes.size(1);
+ int nthreads = batch_size * channels * box_size;
+
+ at::cuda::CUDAGuard device_guard(grad_output.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ dim3 block(128, 4);
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad_output.scalar_type(), "border_align_backward_cuda_kernel", [&] {
+ border_align_backward_cuda_kernel
+ <<>>(
+ nthreads, grad_output.data_ptr(),
+ boxes.data_ptr(), argmax_idx.data_ptr(),
+ grad_input.data_ptr(), channels, box_size, height,
+ width, pool_size);
+ });
+
+ AT_CUDA_CHECK(cudaGetLastError());
+}
diff --git a/mmcv/ops/csrc/parrots/border_align_parrots.cpp b/mmcv/ops/csrc/parrots/border_align_parrots.cpp
new file mode 100644
index 0000000000..a4564b09e1
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/border_align_parrots.cpp
@@ -0,0 +1,50 @@
+#include
+#include
+#include
+
+#include "border_align_pytorch.h"
+
+using namespace parrots;
+
+void border_align_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
+ const OperatorBase::in_list_t& ins,
+ OperatorBase::out_list_t& outs) {
+ int pool_size;
+ SSAttrs(attr).get("pool_size", pool_size).done();
+
+ const auto& input = buildATensor(ctx, ins[0]);
+ const auto& boxes = buildATensor(ctx, ins[1]);
+
+ auto output = buildATensor(ctx, outs[0]);
+ auto argmax_idx = buildATensor(ctx, outs[1]);
+ border_align_forward_cuda(input, boxes, output, argmax_idx, pool_size);
+}
+
+void border_align_backward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
+ const OperatorBase::in_list_t& ins,
+ OperatorBase::out_list_t& outs) {
+ int pool_size;
+ SSAttrs(attr).get("pool_size", pool_size).done();
+
+ const auto& top_grad = buildATensor(ctx, ins[0]);
+ const auto& boxes = buildATensor(ctx, ins[1]);
+ const auto& argmax_idx = buildATensor(ctx, ins[2]);
+
+ auto bottom_grad = buildATensor(ctx, outs[0]);
+ border_align_backward_cuda(top_grad, boxes, argmax_idx, bottom_grad,
+ pool_size);
+}
+
+PARROTS_EXTENSION_REGISTER(border_align_forward)
+ .attr("pool_size")
+ .input(2)
+ .output(2)
+ .apply(border_align_forward_cuda_parrots)
+ .done();
+
+PARROTS_EXTENSION_REGISTER(border_align_backward)
+ .attr("pool_size")
+ .input(3)
+ .output(1)
+ .apply(border_align_backward_cuda_parrots)
+ .done();
diff --git a/mmcv/ops/csrc/parrots/border_align_pytorch.h b/mmcv/ops/csrc/parrots/border_align_pytorch.h
new file mode 100644
index 0000000000..54ff54c34b
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/border_align_pytorch.h
@@ -0,0 +1,16 @@
+#ifndef BORDER_ALIGN_PYTORCH_H
+#define BORDER_ALIGN_PYTORCH_H
+#include
+using namespace at;
+
+#ifdef MMCV_WITH_CUDA
+void border_align_forward_cuda(const Tensor &input, const Tensor &boxes,
+ Tensor output, Tensor argmax_idx,
+ const int pool_size);
+
+void border_align_backward_cuda(const Tensor &grad_output, const Tensor &boxes,
+ const Tensor &argmax_idx, Tensor grad_input,
+ const int pool_size);
+#endif
+
+#endif // BORDER_ALIGN_PYTORCH_H
diff --git a/mmcv/ops/csrc/parrots/carafe_naive_parrots.cpp b/mmcv/ops/csrc/parrots/carafe_naive_parrots.cpp
index 34aadf26de..78dfe09d42 100644
--- a/mmcv/ops/csrc/parrots/carafe_naive_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/carafe_naive_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
/*void carafe_naive_forward_cuda(Tensor features, Tensor masks, Tensor output,
* int kernel_size, int group_size,
* int scale_factor)
@@ -69,3 +70,4 @@ PARROTS_EXTENSION_REGISTER(carafe_naive_backward)
.output(2)
.apply(carafe_naive_backward_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/carafe_parrots.cpp b/mmcv/ops/csrc/parrots/carafe_parrots.cpp
index 8fb32573fa..413778b55a 100644
--- a/mmcv/ops/csrc/parrots/carafe_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/carafe_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
/*
* void carafe_forward_cuda(Tensor features, Tensor masks, Tensor rfeatures,
* Tensor routput, Tensor rmasks, Tensor output,
@@ -83,3 +84,4 @@ PARROTS_EXTENSION_REGISTER(carafe_backward)
.output(6)
.apply(carafe_backward_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/cc_attention_cuda.cu b/mmcv/ops/csrc/parrots/cc_attention_cuda.cu
index b948d5406a..fd4e7fd128 100644
--- a/mmcv/ops/csrc/parrots/cc_attention_cuda.cu
+++ b/mmcv/ops/csrc/parrots/cc_attention_cuda.cu
@@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
- int d3 = h + w;
- dim3 blocks(d1, d2, d3);
+ int d3 = h + w - 1;
+ dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
ca_forward_kernel<<>>(
@@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
- int d3 = c;
+ int d3 = c * n;
dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
@@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
- int d3 = c;
+ int d3 = c * n;
dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
@@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
- int d3 = h + w;
- dim3 blocks(d1, d2, d3);
+ int d3 = h + w - 1;
+ dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
@@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
g.contiguous().data_ptr(),
dw.contiguous().data_ptr(), n, c, h, w);
});
-
+ d3 = c * n;
+ blocks = dim3(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
ca_map_backward_kernel_g<<>>(
dout.contiguous().data_ptr(),
diff --git a/mmcv/ops/csrc/parrots/cc_attention_parrots.cpp b/mmcv/ops/csrc/parrots/cc_attention_parrots.cpp
index 150d3ec370..a51e46c389 100644
--- a/mmcv/ops/csrc/parrots/cc_attention_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/cc_attention_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
/*void ca_forward_cuda(const Tensor t, const Tensor f, Tensor weight);*/
void ca_forward_cuda_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
@@ -77,3 +78,4 @@ PARROTS_EXTENSION_REGISTER(ca_map_backward)
.output(2)
.apply(ca_map_backward_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/deform_conv_parrots.cpp b/mmcv/ops/csrc/parrots/deform_conv_parrots.cpp
index 3347882f83..949f6b4279 100644
--- a/mmcv/ops/csrc/parrots/deform_conv_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/deform_conv_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
/*void deform_conv_forward_cuda(Tensor input, Tensor weight, Tensor offset,
* Tensor output, Tensor columns, Tensor ones,
* int kW, int kH, int dW, int dH, int padW,
@@ -177,3 +178,4 @@ PARROTS_EXTENSION_REGISTER(deform_conv_backward_parameters)
.output(3)
.apply(deform_conv_backward_parameters_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/deform_roi_pool_parrots.cpp b/mmcv/ops/csrc/parrots/deform_roi_pool_parrots.cpp
index 275a7661b2..2fb8b371bb 100644
--- a/mmcv/ops/csrc/parrots/deform_roi_pool_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/deform_roi_pool_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
/*void deform_roi_pool_forward_cuda(Tensor input, Tensor rois, Tensor offset,
* Tensor output, int pooled_height,
* int pooled_width, float spatial_scale,
@@ -97,3 +98,4 @@ PARROTS_EXTENSION_REGISTER(deform_roi_pool_backward)
.output(2)
.apply(deform_roi_pool_backward_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/focal_loss_parrots.cpp b/mmcv/ops/csrc/parrots/focal_loss_parrots.cpp
index 46eea40561..3511d89a99 100644
--- a/mmcv/ops/csrc/parrots/focal_loss_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/focal_loss_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
void sigmoid_focal_loss_forward_cuda_parrots(CudaContext& ctx,
const SSElement& attr,
const OperatorBase::in_list_t& ins,
@@ -108,3 +109,4 @@ PARROTS_EXTENSION_REGISTER(softmax_focal_loss_backward)
.output(2)
.apply(softmax_focal_loss_backward_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/masked_conv2d_parrots.cpp b/mmcv/ops/csrc/parrots/masked_conv2d_parrots.cpp
index e01452e80d..5a9ff64f75 100644
--- a/mmcv/ops/csrc/parrots/masked_conv2d_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/masked_conv2d_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
void masked_im2col_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
@@ -67,3 +68,4 @@ PARROTS_EXTENSION_REGISTER(masked_col2im_forward)
.output(1)
.apply(masked_col2im_forward_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/modulated_deform_conv_parrots.cpp b/mmcv/ops/csrc/parrots/modulated_deform_conv_parrots.cpp
index 837a9db306..de5ff63e0a 100644
--- a/mmcv/ops/csrc/parrots/modulated_deform_conv_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/modulated_deform_conv_parrots.cpp
@@ -6,6 +6,7 @@
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
void modulated_deform_conv_forward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
@@ -114,3 +115,4 @@ PARROTS_EXTENSION_REGISTER(modulated_deform_conv_backward)
.output(7)
.apply(modulated_deform_conv_backward_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/ms_deform_attn.cpp b/mmcv/ops/csrc/parrots/ms_deform_attn.cpp
new file mode 100644
index 0000000000..9bfabdda58
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/ms_deform_attn.cpp
@@ -0,0 +1,79 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from
+*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "pytorch_cpp_helper.hpp"
+
+#ifdef MMCV_WITH_CUDA
+Tensor ms_deform_attn_cuda_forward(const Tensor &value,
+ const Tensor &spatial_shapes,
+ const Tensor &level_start_index,
+ const Tensor &sampling_loc,
+ const Tensor &attn_weight,
+ const int im2col_step);
+
+void ms_deform_attn_cuda_backward(
+ const Tensor &value, const Tensor &spatial_shapes,
+ const Tensor &level_start_index, const Tensor &sampling_loc,
+ const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
+ Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);
+
+#endif
+
+Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
+ const Tensor &level_start_index,
+ const Tensor &sampling_loc,
+ const Tensor &attn_weight,
+ const int im2col_step) {
+ if (value.type().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+ CHECK_CUDA_INPUT(value)
+ CHECK_CUDA_INPUT(spatial_shapes)
+ CHECK_CUDA_INPUT(level_start_index)
+ CHECK_CUDA_INPUT(sampling_loc)
+ CHECK_CUDA_INPUT(attn_weight)
+ return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
+ sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
+ const Tensor &level_start_index,
+ const Tensor &sampling_loc,
+ const Tensor &attn_weight,
+ const Tensor &grad_output, Tensor &grad_value,
+ Tensor &grad_sampling_loc,
+ Tensor &grad_attn_weight, const int im2col_step) {
+ if (value.type().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+ CHECK_CUDA_INPUT(value)
+ CHECK_CUDA_INPUT(spatial_shapes)
+ CHECK_CUDA_INPUT(level_start_index)
+ CHECK_CUDA_INPUT(sampling_loc)
+ CHECK_CUDA_INPUT(attn_weight)
+ CHECK_CUDA_INPUT(grad_output)
+ CHECK_CUDA_INPUT(grad_value)
+ CHECK_CUDA_INPUT(grad_sampling_loc)
+ CHECK_CUDA_INPUT(grad_attn_weight)
+ ms_deform_attn_cuda_backward(value, spatial_shapes, level_start_index,
+ sampling_loc, attn_weight, grad_output,
+ grad_value, grad_sampling_loc,
+ grad_attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("Not implemented on the CPU");
+ }
+}
diff --git a/mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu b/mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000..693131b382
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu
@@ -0,0 +1,360 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from
+*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+template
+void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size, const int spatial_size,
+ const int num_heads, const int channels,
+ const int num_levels, const int num_query,
+ const int num_point, scalar_t *data_col) {
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel
+ <<>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index,
+ data_sampling_loc, data_attn_weight, batch_size, spatial_size,
+ num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess) {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template
+void ms_deformable_col2im_cuda(
+ cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value,
+ const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+ const int batch_size, const int spatial_size, const int num_heads,
+ const int channels, const int num_levels, const int num_query,
+ const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight) {
+ const int num_threads =
+ (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024) {
+ if ((channels & 1023) == 0) {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
+ <<>>(
+ num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels,
+ num_query, num_point, grad_value, grad_sampling_loc,
+ grad_attn_weight);
+ } else {
+ ms_deformable_col2im_gpu_kernel_gm
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ } else {
+ switch (channels) {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc,
+ data_attn_weight, batch_size, spatial_size, num_heads,
+ channels, num_levels, num_query, num_point, grad_value,
+ grad_sampling_loc, grad_attn_weight);
+ break;
+ default:
+ if (channels < 64) {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1
+ <<>>(
+ num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels,
+ num_query, num_point, grad_value, grad_sampling_loc,
+ grad_attn_weight);
+ } else {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2
+ <<>>(
+ num_kernels, grad_col, data_value, data_spatial_shapes,
+ data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels,
+ num_query, num_point, grad_value, grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess) {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step) {
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(),
+ "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(),
+ "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(),
+ "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(),
+ "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(),
+ "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(),
+ "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(),
+ "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
+ batch, im2col_step_);
+
+ auto output =
+ at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view(
+ {batch / im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch / im2col_step_; ++n) {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(
+ value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(
+ at::cuda::getCurrentCUDAStream(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(), level_start_index.data(),
+ sampling_loc.data() +
+ n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() +
+ n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query,
+ num_point, columns.data());
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads * channels});
+
+ return output;
+}
+
+void ms_deform_attn_cuda_backward(
+ const at::Tensor &value, const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight, const at::Tensor &grad_output,
+ at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
+ at::Tensor &grad_attn_weight, const int im2col_step) {
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(),
+ "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(),
+ "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(),
+ "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(),
+ "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(),
+ "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(),
+ "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(),
+ "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(),
+ "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
+ batch, im2col_step_);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view(
+ {batch / im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch / im2col_step_; ++n) {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(
+ value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(
+ at::cuda::getCurrentCUDAStream(), grad_output_g.data(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(), level_start_index.data(),
+ sampling_loc.data() +
+ n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() +
+ n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query,
+ num_point,
+ grad_value.data() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data() +
+ n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data() +
+ n * im2col_step_ * per_attn_weight_size);
+ }));
+ }
+}
diff --git a/mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp b/mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp
new file mode 100644
index 0000000000..8b236cc822
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp
@@ -0,0 +1,68 @@
+#include
+
+#include
+#include
+#include
+using namespace at;
+using namespace parrots;
+
+Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
+ const Tensor &level_start_index,
+ const Tensor &sampling_loc,
+ const Tensor &attn_weight, const int im2col_step);
+
+void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
+ const Tensor &level_start_index,
+ const Tensor &sampling_loc,
+ const Tensor &attn_weight,
+ const Tensor &grad_output, Tensor &grad_value,
+ Tensor &grad_sampling_loc,
+ Tensor &grad_attn_weight, const int im2col_step);
+
+void ms_deform_attn_forward_parrots(CudaContext &ctx, const SSElement &attr,
+ const OperatorBase::in_list_t &ins,
+ OperatorBase::out_list_t &outs) {
+ int im2col_step;
+ SSAttrs(attr).get("im2col_step", im2col_step).done();
+ const auto &value = buildATensor(ctx, ins[0]);
+ const auto &spatial_shapes = buildATensor(ctx, ins[1]);
+ const auto &level_start_index = buildATensor(ctx, ins[2]);
+ const auto &sampling_loc = buildATensor(ctx, ins[3]);
+ const auto &attn_weight = buildATensor(ctx, ins[4]);
+ auto out = ms_deform_attn_forward(value, spatial_shapes, level_start_index,
+ sampling_loc, attn_weight, im2col_step);
+ updateDArray(ctx, out, outs[0]);
+}
+
+void ms_deform_attn_backward_parrots(CudaContext &ctx, const SSElement &attr,
+ const OperatorBase::in_list_t &ins,
+ OperatorBase::out_list_t &outs) {
+ int im2col_step;
+ SSAttrs(attr).get("im2col_step", im2col_step).done();
+ const auto &value = buildATensor(ctx, ins[0]);
+ const auto &spatial_shapes = buildATensor(ctx, ins[1]);
+ const auto &level_start_index = buildATensor(ctx, ins[2]);
+ const auto &sampling_loc = buildATensor(ctx, ins[3]);
+ const auto &attn_weight = buildATensor(ctx, ins[4]);
+ const auto &grad_output = buildATensor(ctx, ins[5]);
+ auto grad_value = buildATensor(ctx, outs[0]);
+ auto grad_sampling_loc = buildATensor(ctx, outs[1]);
+ auto grad_attn_weight = buildATensor(ctx, outs[2]);
+ ms_deform_attn_backward(value, spatial_shapes, level_start_index,
+ sampling_loc, attn_weight, grad_output, grad_value,
+ grad_sampling_loc, grad_attn_weight, im2col_step);
+}
+
+PARROTS_EXTENSION_REGISTER(ms_deform_attn_forward)
+ .attr("im2col_step")
+ .input(5)
+ .output(1)
+ .apply(ms_deform_attn_forward_parrots)
+ .done();
+
+PARROTS_EXTENSION_REGISTER(ms_deform_attn_backward)
+ .attr("im2col_step")
+ .input(6)
+ .output(3)
+ .apply(ms_deform_attn_backward_parrots)
+ .done();
diff --git a/mmcv/ops/csrc/parrots/sync_bn_parrots.cpp b/mmcv/ops/csrc/parrots/sync_bn_parrots.cpp
index 8a6a577cb1..8cdbdbbbd7 100644
--- a/mmcv/ops/csrc/parrots/sync_bn_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/sync_bn_parrots.cpp
@@ -5,6 +5,7 @@
#include "sync_bn_pytorch.h"
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
void sync_bn_forward_mean_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
@@ -106,3 +107,4 @@ PARROTS_EXTENSION_REGISTER(sync_bn_backward_data)
.output(1)
.apply(sync_bn_backward_data_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/parrots/tin_shift_parrots.cpp b/mmcv/ops/csrc/parrots/tin_shift_parrots.cpp
index 48c7df4f2b..e2f7cc0472 100644
--- a/mmcv/ops/csrc/parrots/tin_shift_parrots.cpp
+++ b/mmcv/ops/csrc/parrots/tin_shift_parrots.cpp
@@ -5,6 +5,7 @@
#include "tin_shift_pytorch.h"
using namespace parrots;
+#ifdef MMCV_WITH_CUDA
void tin_shift_forward_cuda_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
@@ -34,3 +35,4 @@ PARROTS_EXTENSION_REGISTER(tin_shift_backward)
.output(1)
.apply(tin_shift_backward_cuda_parrots)
.done();
+#endif
diff --git a/mmcv/ops/csrc/pytorch/border_align.cpp b/mmcv/ops/csrc/pytorch/border_align.cpp
new file mode 100644
index 0000000000..78351e2a5f
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/border_align.cpp
@@ -0,0 +1,67 @@
+#include "pytorch_cpp_helper.hpp"
+
+#ifdef MMCV_WITH_CUDA
+void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
+ const Tensor &boxes, Tensor output,
+ Tensor argmax_idx,
+ const int pool_size);
+
+void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
+ const Tensor &boxes,
+ const Tensor &argmax_idx,
+ Tensor grad_input,
+ const int pool_size);
+
+void border_align_forward_cuda(const Tensor &input, const Tensor &boxes,
+ Tensor output, Tensor argmax_idx,
+ const int pool_size) {
+ BorderAlignForwardCUDAKernelLauncher(input, boxes, output, argmax_idx,
+ pool_size);
+}
+
+void border_align_backward_cuda(const Tensor &grad_output, const Tensor &boxes,
+ const Tensor &argmax_idx, Tensor grad_input,
+ const int pool_size) {
+ BorderAlignBackwardCUDAKernelLauncher(grad_output, boxes, argmax_idx,
+ grad_input, pool_size);
+}
+#endif
+
+void border_align_forward(const Tensor &input, const Tensor &boxes,
+ Tensor output, Tensor argmax_idx,
+ const int pool_size) {
+ if (input.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+ CHECK_CUDA_INPUT(input);
+ CHECK_CUDA_INPUT(boxes);
+ CHECK_CUDA_INPUT(output);
+ CHECK_CUDA_INPUT(argmax_idx);
+
+ border_align_forward_cuda(input, boxes, output, argmax_idx, pool_size);
+#else
+ AT_ERROR("BorderAlign is not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("BorderAlign is not implemented on CPU");
+ }
+}
+
+void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
+ const Tensor &argmax_idx, Tensor grad_input,
+ const int pool_size) {
+ if (grad_output.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+ CHECK_CUDA_INPUT(grad_output);
+ CHECK_CUDA_INPUT(boxes);
+ CHECK_CUDA_INPUT(argmax_idx);
+ CHECK_CUDA_INPUT(grad_input);
+
+ border_align_backward_cuda(grad_output, boxes, argmax_idx, grad_input,
+ pool_size);
+#else
+ AT_ERROR("BorderAlign is not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("BorderAlign is not implemented on CPU");
+ }
+}
diff --git a/mmcv/ops/csrc/pytorch/border_align_cuda.cu b/mmcv/ops/csrc/pytorch/border_align_cuda.cu
new file mode 100644
index 0000000000..06ba452f65
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/border_align_cuda.cu
@@ -0,0 +1,67 @@
+#include "border_align_cuda_kernel.cuh"
+#include "pytorch_cuda_helper.hpp"
+
+void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
+ const Tensor &boxes, Tensor output,
+ Tensor argmax_idx,
+ const int pool_size) {
+ // shape assertion
+ AT_ASSERTM(input.ndimension() == 4,
+ "non-empty 4D(batch mode) tensor expected for input feature");
+ AT_ASSERTM(boxes.ndimension() == 3,
+ "boxes must be 3D tensor with size of [B, H*W, 4]");
+
+ int batch_size = input.size(0);
+ int feat_channels = input.size(1);
+ int channels = feat_channels / 4;
+ int height = input.size(2);
+ int width = input.size(3);
+ // shape [N, box_size, 4] for boxes. (x1, y1, x2, y2) format
+ int box_size = boxes.size(1);
+ // shape [N, channels, box_size, 4] for output
+ int nthreads = batch_size * channels * box_size;
+
+ at::cuda::CUDAGuard device_guard(input.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ dim3 block(128, 4);
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ input.scalar_type(), "border_align_forward_cuda_kernel", [&] {
+ border_align_forward_cuda_kernel
+ <<>>(
+ nthreads, input.data_ptr(),
+ boxes.data_ptr(), output.data_ptr(),
+ argmax_idx.data_ptr(), channels, box_size, height, width,
+ pool_size);
+ });
+
+ AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
+ const Tensor &boxes,
+ const Tensor &argmax_idx,
+ Tensor grad_input,
+ const int pool_size) {
+ int batch_size = grad_input.size(0);
+ int feat_channels = grad_input.size(1);
+ int channels = feat_channels / 4;
+ int height = grad_input.size(2);
+ int width = grad_input.size(3);
+ int box_size = boxes.size(1);
+ int nthreads = batch_size * channels * box_size;
+
+ at::cuda::CUDAGuard device_guard(grad_output.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ dim3 block(128, 4);
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad_output.scalar_type(), "border_align_backward_cuda_kernel", [&] {
+ border_align_backward_cuda_kernel
+ <<>>(
+ nthreads, grad_output.data_ptr(),
+ boxes.data_ptr(), argmax_idx.data_ptr(),
+ grad_input.data_ptr(), channels, box_size, height,
+ width, pool_size);
+ });
+
+ AT_CUDA_CHECK(cudaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu b/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
index b948d5406a..fd4e7fd128 100644
--- a/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
@@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
- int d3 = h + w;
- dim3 blocks(d1, d2, d3);
+ int d3 = h + w - 1;
+ dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
ca_forward_kernel<<>>(
@@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
- int d3 = c;
+ int d3 = c * n;
dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
@@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
- int d3 = c;
+ int d3 = c * n;
dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
@@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
- int d3 = h + w;
- dim3 blocks(d1, d2, d3);
+ int d3 = h + w - 1;
+ dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
@@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
g.contiguous().data_ptr(),
dw.contiguous().data_ptr(), n, c, h, w);
});
-
+ d3 = c * n;
+ blocks = dim3(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
ca_map_backward_kernel_g<<>>(
dout.contiguous().data_ptr(),
diff --git a/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp b/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp
index 9bcee5c243..9bfabdda58 100644
--- a/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp
+++ b/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp
@@ -19,11 +19,11 @@ Tensor ms_deform_attn_cuda_forward(const Tensor &value,
const Tensor &attn_weight,
const int im2col_step);
-std::vector ms_deform_attn_cuda_backward(
+void ms_deform_attn_cuda_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
- const Tensor &attn_weight, const Tensor &grad_output,
- const int im2col_step);
+ const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
+ Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);
#endif
@@ -48,13 +48,13 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
AT_ERROR("Not implemented on the CPU");
}
-std::vector ms_deform_attn_backward(const Tensor &value,
- const Tensor &spatial_shapes,
- const Tensor &level_start_index,
- const Tensor &sampling_loc,
- const Tensor &attn_weight,
- const Tensor &grad_output,
- const int im2col_step) {
+void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
+ const Tensor &level_start_index,
+ const Tensor &sampling_loc,
+ const Tensor &attn_weight,
+ const Tensor &grad_output, Tensor &grad_value,
+ Tensor &grad_sampling_loc,
+ Tensor &grad_attn_weight, const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
@@ -63,12 +63,17 @@ std::vector ms_deform_attn_backward(const Tensor &value,
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
CHECK_CUDA_INPUT(grad_output)
- return ms_deform_attn_cuda_backward(value, spatial_shapes,
- level_start_index, sampling_loc,
- attn_weight, grad_output, im2col_step);
+ CHECK_CUDA_INPUT(grad_value)
+ CHECK_CUDA_INPUT(grad_sampling_loc)
+ CHECK_CUDA_INPUT(grad_attn_weight)
+ ms_deform_attn_cuda_backward(value, spatial_shapes, level_start_index,
+ sampling_loc, attn_weight, grad_output,
+ grad_value, grad_sampling_loc,
+ grad_attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
+ } else {
+ AT_ERROR("Not implemented on the CPU");
}
- AT_ERROR("Not implemented on the CPU");
}
diff --git a/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu b/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu
index 1cd67403f0..693131b382 100644
--- a/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu
@@ -286,11 +286,12 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
return output;
}
-std::vector ms_deform_attn_cuda_backward(
+void ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
const at::Tensor &attn_weight, const at::Tensor &grad_output,
- const int im2col_step) {
+ at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
+ at::Tensor &grad_attn_weight, const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
@@ -328,10 +329,6 @@ std::vector ms_deform_attn_cuda_backward(
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);
- auto grad_value = at::zeros_like(value);
- auto grad_sampling_loc = at::zeros_like(sampling_loc);
- auto grad_attn_weight = at::zeros_like(attn_weight);
-
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
@@ -360,6 +357,4 @@ std::vector ms_deform_attn_cuda_backward(
n * im2col_step_ * per_attn_weight_size);
}));
}
-
- return {grad_value, grad_sampling_loc, grad_attn_weight};
}
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index 6e1096b6fb..0b88e55658 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -97,13 +97,13 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &sampling_loc,
const Tensor &attn_weight, const int im2col_step);
-std::vector ms_deform_attn_backward(const Tensor &value,
- const Tensor &spatial_shapes,
- const Tensor &level_start_index,
- const Tensor &sampling_loc,
- const Tensor &attn_weight,
- const Tensor &grad_output,
- const int im2col_step);
+void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
+ const Tensor &level_start_index,
+ const Tensor &sampling_loc,
+ const Tensor &attn_weight,
+ const Tensor &grad_output, Tensor &grad_value,
+ Tensor &grad_sampling_loc,
+ Tensor &grad_attn_weight, const int im2col_step);
Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);
@@ -222,6 +222,14 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
int pooled_width, float spatial_scale,
int sample_num, bool aligned, bool clockwise);
+void border_align_forward(const Tensor &input, const Tensor &boxes,
+ Tensor output, Tensor argmax_idx,
+ const int pool_size);
+
+void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
+ const Tensor &argmax_idx, Tensor grad_input,
+ const int pool_size);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -445,5 +453,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("value"), py::arg("value_spatial_shapes"),
py::arg("value_level_start_index"), py::arg("sampling_locations"),
py::arg("attention_weights"), py::arg("grad_output"),
- py::arg("im2col_step"));
+ py::arg("grad_value"), py::arg("grad_sampling_loc"),
+ py::arg("grad_attn_weight"), py::arg("im2col_step"));
+ m.def("border_align_forward", &border_align_forward,
+ "forward function of border_align", py::arg("input"), py::arg("boxes"),
+ py::arg("output"), py::arg("argmax_idx"), py::arg("pool_size"));
+ m.def("border_align_backward", &border_align_backward,
+ "backward function of border_align", py::arg("grad_output"),
+ py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"),
+ py::arg("pool_size"));
}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu
index 5b85a4e567..8ddcca9703 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu
@@ -1,3 +1,5 @@
+#include
+
#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
@@ -64,3 +66,25 @@ void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
template void memcpyPermute(float *dst, const float *src, int *src_size,
int *permute, int src_dim,
cudaStream_t stream);
+
+template <>
+cublasStatus_t cublasGemmWrap(cublasHandle_t handle,
+ cublasOperation_t transa,
+ cublasOperation_t transb, int m, int n,
+ int k, const float *alpha, const float *A,
+ int lda, const float *B, int ldb,
+ const float *beta, float *C, int ldc) {
+ return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
+ beta, C, ldc);
+}
+
+template <>
+cublasStatus_t cublasGemmWrap(cublasHandle_t handle,
+ cublasOperation_t transa,
+ cublasOperation_t transb, int m, int n,
+ int k, const half *alpha, const half *A,
+ int lda, const half *B, int ldb,
+ const half *beta, half *C, int ldc) {
+ return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
+ beta, C, ldc);
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
new file mode 100644
index 0000000000..2e920cfed0
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
@@ -0,0 +1,241 @@
+#include "trt_cummaxmin.hpp"
+
+#include
+
+#include "trt_serialize.hpp"
+
+void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
+ int *output_index, const int *dims,
+ int nbDims, int cum_dim, int cum_type,
+ cudaStream_t stream);
+
+void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
+ int *output_index, const int *dims,
+ int nbDims, int cum_dim, int cum_type,
+ cudaStream_t stream);
+
+namespace {
+static const char *PLUGIN_VERSION{"1"};
+static const char *CUMMAXMIN_PLUGIN_NAME{"cummaxmin"};
+static const char *CUMMAX_PLUGIN_NAME{"cummax"};
+static const char *CUMMIN_PLUGIN_NAME{"cummin"};
+} // namespace
+
+CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string &name, int dim,
+ TRT_CUMCMPTYPE cumType)
+ : mLayerName(name), mDim(dim), mCumType(cumType) {}
+
+CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string name,
+ const void *data, size_t length)
+ : mLayerName(name) {
+ deserialize_value(&data, &length, &mDim);
+ deserialize_value(&data, &length, &mCumType);
+}
+
+CumMaxMinPluginDynamic::~CumMaxMinPluginDynamic() {}
+
+nvinfer1::IPluginV2DynamicExt *CumMaxMinPluginDynamic::clone() const {
+ CumMaxMinPluginDynamic *plugin =
+ new CumMaxMinPluginDynamic(mLayerName, mDim, mCumType);
+ plugin->setPluginNamespace(getPluginNamespace());
+
+ return plugin;
+}
+
+nvinfer1::DimsExprs CumMaxMinPluginDynamic::getOutputDimensions(
+ int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+ nvinfer1::IExprBuilder &exprBuilder) {
+ return inputs[0];
+}
+
+bool CumMaxMinPluginDynamic::supportsFormatCombination(
+ int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
+ int nbOutputs) {
+ switch (pos) {
+ // input[0]
+ case 0:
+ return (inOut[pos].type == nvinfer1::DataType::kFLOAT ||
+ inOut[pos].type == nvinfer1::DataType::kINT32) &&
+ inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
+ // output[0]
+ case 1:
+ return inOut[pos].type == inOut[0].type &&
+ inOut[pos].format == inOut[0].format;
+ // output[1]
+ case 2:
+ return inOut[pos].type == nvinfer1::DataType::kINT32 &&
+ inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
+ default:
+ return false;
+ }
+}
+
+void CumMaxMinPluginDynamic::configurePlugin(
+ const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
+
+size_t CumMaxMinPluginDynamic::getWorkspaceSize(
+ const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
+ const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
+ int sizeof_dtype = mmcv::getElementSize(outputs[0].type);
+}
+
+int CumMaxMinPluginDynamic::enqueue(
+ const nvinfer1::PluginTensorDesc *inputDesc,
+ const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
+ void *const *outputs, void *workSpace, cudaStream_t stream) {
+ const void *input = inputs[0];
+ void *output_value = outputs[0];
+ int *output_index = (int *)outputs[1];
+
+ const int *dims = &(inputDesc[0].dims.d[0]);
+ int nbDims = inputDesc[0].dims.nbDims;
+
+ switch (inputDesc[0].type) {
+ case nvinfer1::DataType::kFLOAT:
+ CumMaxMinForwardLauncher_float((float *)input, (float *)output_value,
+ output_index, dims, nbDims, mDim,
+ int(mCumType), stream);
+ break;
+ case nvinfer1::DataType::kINT32:
+ CumMaxMinForwardLauncher_int32((int *)input, (int *)output_value,
+ output_index, dims, nbDims, mDim,
+ int(mCumType), stream);
+ break;
+ default:
+ break;
+ }
+
+ return 0;
+}
+
+nvinfer1::DataType CumMaxMinPluginDynamic::getOutputDataType(
+ int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
+ switch (index) {
+ case 0:
+ return inputTypes[0];
+ case 1:
+ return nvinfer1::DataType::kINT32;
+ default:
+ break;
+ }
+}
+
+// IPluginV2 Methods
+const char *CumMaxMinPluginDynamic::getPluginType() const {
+ switch (mCumType) {
+ case TRT_CUMCMPTYPE::TRT_CUMMAX:
+ return CUMMAX_PLUGIN_NAME;
+ case TRT_CUMCMPTYPE::TRT_CUMMIN:
+ return CUMMIN_PLUGIN_NAME;
+ default:
+ return "UnknownCumType";
+ }
+}
+
+const char *CumMaxMinPluginDynamic::getPluginVersion() const {
+ return PLUGIN_VERSION;
+}
+
+int CumMaxMinPluginDynamic::getNbOutputs() const { return 2; }
+
+int CumMaxMinPluginDynamic::initialize() { return 0; }
+
+void CumMaxMinPluginDynamic::terminate() {}
+
+size_t CumMaxMinPluginDynamic::getSerializationSize() const {
+ return sizeof(mDim) + sizeof(mCumType);
+}
+
+void CumMaxMinPluginDynamic::serialize(void *buffer) const {
+ serialize_value(&buffer, mDim);
+ serialize_value(&buffer, mCumType);
+}
+
+void CumMaxMinPluginDynamic::destroy() {
+ // This gets called when the network containing plugin is destroyed
+ delete this;
+}
+
+void CumMaxMinPluginDynamic::setPluginNamespace(const char *libNamespace) {
+ mNamespace = libNamespace;
+}
+
+const char *CumMaxMinPluginDynamic::getPluginNamespace() const {
+ return mNamespace.c_str();
+}
+
+CumMaxMinPluginDynamicCreator::CumMaxMinPluginDynamicCreator(
+ TRT_CUMCMPTYPE cumType)
+ : mCumType(cumType) {
+ mPluginAttributes.clear();
+ mPluginAttributes.emplace_back(nvinfer1::PluginField("dim"));
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginName() const {
+ return CUMMAXMIN_PLUGIN_NAME;
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginVersion() const {
+ return PLUGIN_VERSION;
+}
+
+const nvinfer1::PluginFieldCollection *
+CumMaxMinPluginDynamicCreator::getFieldNames() {
+ return &mFC;
+}
+
+nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::createPlugin(
+ const char *name, const nvinfer1::PluginFieldCollection *fc) {
+ int dim = 0;
+
+ for (int i = 0; i < fc->nbFields; i++) {
+ if (fc->fields[i].data == nullptr) {
+ continue;
+ }
+ std::string field_name(fc->fields[i].name);
+
+ if (field_name.compare("dim") == 0) {
+ dim = static_cast(fc->fields[i].data)[0];
+ }
+ }
+
+ CumMaxMinPluginDynamic *plugin =
+ new CumMaxMinPluginDynamic(name, dim, mCumType);
+ plugin->setPluginNamespace(getPluginNamespace());
+ return plugin;
+}
+
+nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::deserializePlugin(
+ const char *name, const void *serialData, size_t serialLength) {
+ // This object will be deleted when the network is destroyed, which will
+ // call FCPluginDynamic::destroy()
+ auto plugin = new CumMaxMinPluginDynamic(name, serialData, serialLength);
+ plugin->setPluginNamespace(getPluginNamespace());
+ return plugin;
+}
+
+void CumMaxMinPluginDynamicCreator::setPluginNamespace(
+ const char *libNamespace) {
+ mNamespace = libNamespace;
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginNamespace() const {
+ return mNamespace.c_str();
+}
+
+CumMaxPluginDynamicCreator::CumMaxPluginDynamicCreator()
+ : CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMAX) {}
+
+const char *CumMaxPluginDynamicCreator::getPluginName() const {
+ return CUMMAX_PLUGIN_NAME;
+}
+
+CumMinPluginDynamicCreator::CumMinPluginDynamicCreator()
+ : CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMIN) {}
+
+const char *CumMinPluginDynamicCreator::getPluginName() const {
+ return CUMMIN_PLUGIN_NAME;
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
new file mode 100644
index 0000000000..753104071f
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
@@ -0,0 +1,89 @@
+
+#include "common_cuda_helper.hpp"
+#include "trt_cuda_helper.cuh"
+#include "trt_plugin_helper.hpp"
+
+using mmcv::TensorDesc;
+
+template
+__global__ void cummaxmin_kernel(const scalar_t *input, scalar_t *output_value,
+ int *output_index, TensorDesc tensor_desc,
+ int cum_dim, int cum_type) {
+ const size_t cum_size = tensor_desc.shape[cum_dim];
+ const size_t cum_stride = tensor_desc.stride[cum_dim];
+ const size_t data_size =
+ tensor_desc.stride[0] * tensor_desc.shape[0] / cum_size;
+ CUDA_1D_KERNEL_LOOP(index, data_size) {
+ size_t cum_offset =
+ index / cum_stride * (cum_size * cum_stride) + index % cum_stride;
+ int cum_index = 0;
+ auto cum_value = input[cum_offset];
+ output_value[cum_offset] = cum_value;
+ output_index[cum_offset] = cum_index;
+
+ for (size_t cum_index_current = 1; cum_index_current < cum_size;
+ ++cum_index_current) {
+ cum_offset += cum_stride;
+ const auto cum_value_current = input[cum_offset];
+ switch (cum_type) {
+ case 0: // max
+ if (cum_value_current > cum_value) {
+ cum_value = cum_value_current;
+ cum_index = cum_index_current;
+ }
+ break;
+ case 1: // min
+ if (cum_value_current < cum_value) {
+ cum_value = cum_value_current;
+ cum_index = cum_index_current;
+ }
+ break;
+ }
+ output_value[cum_offset] = cum_value;
+ output_index[cum_offset] = cum_index;
+ }
+ }
+}
+
+template
+void CumMaxMinForwardLauncher(const scalar_t *input, scalar_t *output_value,
+ int *output_index, const int *dims, int nbDims,
+ int cum_dim, int cum_type, cudaStream_t stream) {
+ // fill tensordesc and initial
+ TensorDesc tensor_desc;
+ memset((void *)&tensor_desc, 0, sizeof(TensorDesc));
+ tensor_desc.dim = nbDims;
+ tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
+ tensor_desc.stride[nbDims - 1] = 1;
+ for (int i = nbDims - 2; i >= 0; --i) {
+ tensor_desc.shape[i] = dims[i];
+ tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
+ }
+
+ // cum dim should be larger than 0
+ cum_dim = cum_dim >= 0 ? cum_dim : (nbDims + cum_dim);
+
+ const int data_size =
+ tensor_desc.stride[0] * tensor_desc.shape[0] / tensor_desc.shape[cum_dim];
+
+ const int col_block = DIVUP(data_size, THREADS_PER_BLOCK);
+
+ cummaxmin_kernel<<>>(
+ input, output_value, output_index, tensor_desc, cum_dim, cum_type);
+}
+
+void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
+ int *output_index, const int *dims,
+ int nbDims, int cum_dim, int cum_type,
+ cudaStream_t stream) {
+ CumMaxMinForwardLauncher(input, output_value, output_index, dims,
+ nbDims, cum_dim, cum_type, stream);
+}
+
+void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
+ int *output_index, const int *dims,
+ int nbDims, int cum_dim, int cum_type,
+ cudaStream_t stream) {
+ CumMaxMinForwardLauncher(input, output_value, output_index, dims, nbDims,
+ cum_dim, cum_type, stream);
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp
index 988e9bc46e..fa008e4190 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp
@@ -32,9 +32,7 @@ DeformableConvPluginDynamic::DeformableConvPluginDynamic(
mDilation(dilation),
mDeformableGroup(deformableGroup),
mGroup(group),
- mIm2colStep(im2colStep) {
- cublasCreate(&m_cublas_handle);
-}
+ mIm2colStep(im2colStep) {}
DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name,
const void *data,
@@ -46,12 +44,8 @@ DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name,
deserialize_value(&data, &length, &mDeformableGroup);
deserialize_value(&data, &length, &mGroup);
deserialize_value(&data, &length, &mIm2colStep);
- cublasCreate(&m_cublas_handle);
-}
-DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {
- // destroy cublas handle
- cublasDestroy(m_cublas_handle);
}
+DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const {
DeformableConvPluginDynamic *plugin =
@@ -127,11 +121,6 @@ int DeformableConvPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) {
- if (m_cuda_stream != stream) {
- cublasSetStream(m_cublas_handle, stream);
- m_cuda_stream = stream;
- }
-
int batch_size = inputDesc[0].dims.d[0];
int inputChannel = inputDesc[0].dims.d[1];
int inputHeight = inputDesc[0].dims.d[2];
@@ -204,6 +193,14 @@ void DeformableConvPluginDynamic::destroy() {
delete this;
}
+void DeformableConvPluginDynamic::attachToContext(
+ cudnnContext *cudnnContext, cublasContext *cublasContext,
+ nvinfer1::IGpuAllocator *gpuAllocator) {
+ m_cublas_handle = cublasContext;
+}
+
+void DeformableConvPluginDynamic::detachFromContext() {}
+
void DeformableConvPluginDynamic::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu
index 36a63dea9d..b5eefa6e71 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu
@@ -1,4 +1,3 @@
-#include
#include
#include "common_cuda_helper.hpp"
@@ -32,38 +31,6 @@ void trt_deformable_im2col(const T* data_input, const T* data_offset,
cudaCheckError();
}
-// used to switch gemm between fp32 and fp16
-template
-cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
- cublasOperation_t transb, int m, int n, int k,
- const scalar_t* alpha, const scalar_t* A, int lda,
- const scalar_t* B, int ldb, const scalar_t* beta,
- scalar_t* C, int ldc) {
- return CUBLAS_STATUS_INTERNAL_ERROR;
-}
-
-template <>
-cublasStatus_t cublasGemmWrap(cublasHandle_t handle,
- cublasOperation_t transa,
- cublasOperation_t transb, int m, int n,
- int k, const float* alpha, const float* A,
- int lda, const float* B, int ldb,
- const float* beta, float* C, int ldc) {
- cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C,
- ldc);
-}
-
-template <>
-cublasStatus_t cublasGemmWrap(cublasHandle_t handle,
- cublasOperation_t transa,
- cublasOperation_t transb, int m, int n,
- int k, const half* alpha, const half* A,
- int lda, const half* B, int ldb,
- const half* beta, half* C, int ldc) {
- cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C,
- ldc);
-}
-
template
void DeformConvForwardCUDAKernelLauncher(
const scalar_t* input, const scalar_t* weight, const scalar_t* offset,
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_instance_norm.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_instance_norm.cpp
new file mode 100644
index 0000000000..1efdcb3a8d
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_instance_norm.cpp
@@ -0,0 +1,245 @@
+// Modified from:
+// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.cpp
+
+#include "trt_instance_norm.hpp"
+
+#include
+
+#include
+
+#include "trt_serialize.hpp"
+
+using namespace nvinfer1;
+
+cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
+ cudnnDataType_t* cudnn_dtype) {
+ switch (trt_dtype) {
+ case nvinfer1::DataType::kFLOAT:
+ *cudnn_dtype = CUDNN_DATA_FLOAT;
+ break;
+ case nvinfer1::DataType::kHALF:
+ *cudnn_dtype = CUDNN_DATA_HALF;
+ break;
+ default:
+ return CUDNN_STATUS_BAD_PARAM;
+ }
+ return CUDNN_STATUS_SUCCESS;
+}
+
+namespace {
+constexpr const char* PLUGIN_VERSION{"1"};
+constexpr const char* PLUGIN_NAME{"MMCVInstanceNormalization"};
+} // namespace
+
+PluginFieldCollection InstanceNormalizationDynamicCreator::mFC{};
+std::vector InstanceNormalizationDynamicCreator::mPluginAttributes;
+
+InstanceNormalizationDynamic::InstanceNormalizationDynamic(
+ const std::string& name, float epsilon)
+ : mLayerName(name), mEpsilon(epsilon) {}
+
+InstanceNormalizationDynamic::InstanceNormalizationDynamic(
+ const std::string& name, void const* serialData, size_t serialLength)
+ : mLayerName(name) {
+ deserialize_value(&serialData, &serialLength, &mEpsilon);
+}
+
+InstanceNormalizationDynamic::~InstanceNormalizationDynamic() {}
+
+// InstanceNormalizationDynamic returns one output.
+int InstanceNormalizationDynamic::getNbOutputs() const { return 1; }
+
+DimsExprs InstanceNormalizationDynamic::getOutputDimensions(
+ int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
+ nvinfer1::IExprBuilder& exprBuilder) {
+ nvinfer1::DimsExprs output(inputs[0]);
+ return output;
+}
+
+int InstanceNormalizationDynamic::initialize() { return 0; }
+
+void InstanceNormalizationDynamic::terminate() {}
+
+size_t InstanceNormalizationDynamic::getWorkspaceSize(
+ const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
+ const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
+ int n = inputs[0].dims.d[0];
+ int c = inputs[0].dims.d[1];
+ int elem_size = mmcv::getElementSize(inputs[1].type);
+ return mmcv::getAlignedSize(n * c * elem_size) * 2;
+}
+
+int InstanceNormalizationDynamic::enqueue(
+ const nvinfer1::PluginTensorDesc* inputDesc,
+ const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
+ void* const* outputs, void* workspace, cudaStream_t stream) {
+ nvinfer1::Dims input_dims = inputDesc[0].dims;
+ int n = input_dims.d[0];
+ int c = input_dims.d[1];
+ int h = input_dims.d[2];
+ int w = input_dims.nbDims > 3 ? input_dims.d[3] : 1;
+ int elem_size = mmcv::getElementSize(inputDesc[1].type);
+
+ void* n_scales = (void*)workspace;
+ void* n_bias = (void*)(workspace + mmcv::getAlignedSize(n * c * elem_size));
+
+ const void* scales = (const void*)inputs[1];
+ const void* bias = (const void*)inputs[2];
+
+ for (int i = 0; i < n; ++i) {
+ cudaMemcpyAsync(n_scales + i * c * elem_size, scales, c * elem_size,
+ cudaMemcpyDeviceToDevice, stream);
+ cudaMemcpyAsync(n_bias + i * c * elem_size, bias, c * elem_size,
+ cudaMemcpyDeviceToDevice, stream);
+ }
+
+ cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1,
+ n * c, 1, 1);
+ cudnnDataType_t cudnn_dtype{};
+ convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype);
+ cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c,
+ h, w);
+ cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c,
+ h, w);
+ float alpha = 1;
+ float beta = 0;
+ void const* x_ptr = inputs[0];
+ void* y_ptr = outputs[0];
+ cudnnSetStream(_cudnn_handle, stream);
+ // Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical
+ // overflows (NaNs) for fp32 data in some circumstances. The lower-
+ // performance CUDNN_BATCHNORM_SPATIAL should be used if this is not
+ // acceptable.
+ cudnnBatchNormalizationForwardTraining(
+ _cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, _x_desc,
+ x_ptr, _y_desc, y_ptr, _b_desc, n_scales, n_bias, 1., nullptr, nullptr,
+ mEpsilon, nullptr, nullptr);
+ return 0;
+}
+
+size_t InstanceNormalizationDynamic::getSerializationSize() const {
+ return serialized_size(mEpsilon);
+}
+
+void InstanceNormalizationDynamic::serialize(void* buffer) const {
+ serialize_value(&buffer, mEpsilon);
+}
+
+bool InstanceNormalizationDynamic::supportsFormatCombination(
+ int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
+ int nbOutputs) {
+ return ((inOut[pos].type == nvinfer1::DataType::kFLOAT ||
+ inOut[pos].type == nvinfer1::DataType::kHALF) &&
+ inOut[pos].format == nvinfer1::PluginFormat::kLINEAR &&
+ inOut[pos].type == inOut[0].type);
+}
+
+const char* InstanceNormalizationDynamic::getPluginType() const {
+ return PLUGIN_NAME;
+}
+
+const char* InstanceNormalizationDynamic::getPluginVersion() const {
+ return PLUGIN_VERSION;
+}
+
+void InstanceNormalizationDynamic::destroy() { delete this; }
+
+IPluginV2DynamicExt* InstanceNormalizationDynamic::clone() const {
+ auto* plugin = new InstanceNormalizationDynamic{mLayerName, mEpsilon};
+ plugin->setPluginNamespace(mPluginNamespace.c_str());
+ return plugin;
+}
+
+// Set plugin namespace
+void InstanceNormalizationDynamic::setPluginNamespace(
+ const char* pluginNamespace) {
+ mPluginNamespace = pluginNamespace;
+}
+
+const char* InstanceNormalizationDynamic::getPluginNamespace() const {
+ return mPluginNamespace.c_str();
+}
+
+nvinfer1::DataType InstanceNormalizationDynamic::getOutputDataType(
+ int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
+ return inputTypes[0];
+}
+
+// Attach the plugin object to an execution context and grant the plugin the
+// access to some context resource.
+void InstanceNormalizationDynamic::attachToContext(
+ cudnnContext* cudnnContext, cublasContext* cublasContext,
+ IGpuAllocator* gpuAllocator) {
+ _cudnn_handle = cudnnContext;
+ cudnnCreateTensorDescriptor(&_b_desc);
+ cudnnCreateTensorDescriptor(&_x_desc);
+ cudnnCreateTensorDescriptor(&_y_desc);
+}
+
+// Detach the plugin object from its execution context.
+void InstanceNormalizationDynamic::detachFromContext() {
+ cudnnDestroyTensorDescriptor(_y_desc);
+ cudnnDestroyTensorDescriptor(_x_desc);
+ cudnnDestroyTensorDescriptor(_b_desc);
+}
+
+void InstanceNormalizationDynamic::configurePlugin(
+ const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
+
+// InstanceNormalizationDynamicCreator methods
+InstanceNormalizationDynamicCreator::InstanceNormalizationDynamicCreator() {
+ mPluginAttributes.clear();
+ mPluginAttributes.emplace_back(
+ PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1));
+
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+const char* InstanceNormalizationDynamicCreator::getPluginName() const {
+ return PLUGIN_NAME;
+}
+
+const char* InstanceNormalizationDynamicCreator::getPluginVersion() const {
+ return PLUGIN_VERSION;
+}
+
+const PluginFieldCollection*
+InstanceNormalizationDynamicCreator::getFieldNames() {
+ return &mFC;
+}
+
+IPluginV2DynamicExt* InstanceNormalizationDynamicCreator::createPlugin(
+ const char* name, const nvinfer1::PluginFieldCollection* fc) {
+ float epsilon = 1e-5;
+ const PluginField* fields = fc->fields;
+ for (int i = 0; i < fc->nbFields; ++i) {
+ const char* attrName = fields[i].name;
+ if (!strcmp(attrName, "epsilon")) {
+ epsilon = *(static_cast(fields[i].data));
+ }
+ }
+
+ InstanceNormalizationDynamic* obj =
+ new InstanceNormalizationDynamic(name, epsilon);
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+}
+
+IPluginV2DynamicExt* InstanceNormalizationDynamicCreator::deserializePlugin(
+ const char* name, const void* serialData, size_t serialLength) {
+ InstanceNormalizationDynamic* obj =
+ new InstanceNormalizationDynamic{name, serialData, serialLength};
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+}
+
+void InstanceNormalizationDynamicCreator::setPluginNamespace(
+ const char* libNamespace) {
+ mNamespace = libNamespace;
+}
+
+const char* InstanceNormalizationDynamicCreator::getPluginNamespace() const {
+ return mNamespace.c_str();
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp
new file mode 100644
index 0000000000..88ab2cf67e
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp
@@ -0,0 +1,307 @@
+#include "trt_modulated_deform_conv.hpp"
+
+#include
+
+#include
+
+#include "trt_serialize.hpp"
+
+void ModulatedDeformConvForwardCUDAKernelLauncher_float(
+ const float *input, const float *weight, const float *bias,
+ const float *offset, const float *mask, float *output, void *workspace,
+ int batch, int channels, int height, int width, int channels_out,
+ int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w,
+ int pad_h, int dilation_w, int dilation_h, int group, int deformable_group,
+ int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream);
+
+namespace {
+static const char *PLUGIN_VERSION{"1"};
+static const char *PLUGIN_NAME{"MMCVModulatedDeformConv2d"};
+} // namespace
+
+nvinfer1::PluginFieldCollection
+ ModulatedDeformableConvPluginDynamicCreator::mFC{};
+std::vector
+ ModulatedDeformableConvPluginDynamicCreator::mPluginAttributes;
+
+ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(
+ const std::string &name, const nvinfer1::Dims stride,
+ const nvinfer1::Dims padding, const nvinfer1::Dims dilation,
+ const int deformableGroup, const int group)
+ : mLayerName(name),
+ mStride(stride),
+ mPadding(padding),
+ mDilation(dilation),
+ mDeformableGroup(deformableGroup),
+ mGroup(group) {
+ mWithBias = false;
+}
+
+ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(
+ const std::string name, const void *data, size_t length)
+ : mLayerName(name) {
+ deserialize_value(&data, &length, &mStride);
+ deserialize_value(&data, &length, &mPadding);
+ deserialize_value(&data, &length, &mDilation);
+ deserialize_value(&data, &length, &mDeformableGroup);
+ deserialize_value(&data, &length, &mGroup);
+ mWithBias = false;
+}
+ModulatedDeformableConvPluginDynamic::~ModulatedDeformableConvPluginDynamic() {}
+
+nvinfer1::IPluginV2DynamicExt *ModulatedDeformableConvPluginDynamic::clone()
+ const {
+ ModulatedDeformableConvPluginDynamic *plugin =
+ new ModulatedDeformableConvPluginDynamic(
+ mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup);
+ plugin->setPluginNamespace(getPluginNamespace());
+
+ return plugin;
+}
+
+nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions(
+ int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+ nvinfer1::IExprBuilder &exprBuilder) {
+ nvinfer1::DimsExprs ret;
+ ret.nbDims = 4;
+ ret.d[0] = inputs[0].d[0];
+ ret.d[1] = inputs[3].d[0];
+
+ ret.d[2] = inputs[1].d[2];
+ ret.d[3] = inputs[1].d[3];
+
+ return ret;
+}
+
+bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination(
+ int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
+ int nbOutputs) {
+ if (pos == 0) {
+ return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
+ inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
+
+ } else {
+ return inOut[pos].type == inOut[0].type &&
+ inOut[pos].format == inOut[0].format;
+ }
+}
+
+void ModulatedDeformableConvPluginDynamic::configurePlugin(
+ const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {
+ if (nbInputs == 5) {
+ mWithBias = true;
+ }
+}
+
+size_t ModulatedDeformableConvPluginDynamic::getWorkspaceSize(
+ const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
+ const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
+ int sizeof_dtype = mmcv::getElementSize(outputs[0].type);
+
+ int batch_size = inputs[0].dims.d[0];
+ int nInputPlane = inputs[0].dims.d[1];
+ int inputHeight = inputs[0].dims.d[2];
+ int inputWidth = inputs[0].dims.d[3];
+
+ int nOutputPlane = outputs[0].dims.d[1];
+ int outputHeight = outputs[0].dims.d[2];
+ int outputWidth = outputs[0].dims.d[3];
+
+ int kW = inputs[3].dims.d[2];
+ int kH = inputs[3].dims.d[3];
+ int im2col_step = std::min(32, batch_size);
+
+ size_t col_size = mmcv::getAlignedSize(nInputPlane * kW * kH * outputHeight *
+ outputWidth * sizeof_dtype);
+
+ return col_size;
+}
+
+int ModulatedDeformableConvPluginDynamic::enqueue(
+ const nvinfer1::PluginTensorDesc *inputDesc,
+ const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
+ void *const *outputs, void *workSpace, cudaStream_t stream) {
+ int batch = inputDesc[0].dims.d[0];
+ int channels = inputDesc[0].dims.d[1];
+ int height = inputDesc[0].dims.d[2];
+ int width = inputDesc[0].dims.d[3];
+ int channels_out = outputDesc[0].dims.d[1];
+ int kernel_h = inputDesc[3].dims.d[2];
+ int kernel_w = inputDesc[3].dims.d[3];
+
+ const void *x = inputs[0];
+ const void *offset = inputs[1];
+ const void *mask = inputs[2];
+ const void *weight = inputs[3];
+ const void *bias = mWithBias ? inputs[4] : nullptr;
+ void *output = outputs[0];
+ int im2col_step = std::min(batch, 32);
+
+ // TODO: add fp16 support
+ auto data_type = inputDesc[0].type;
+ switch (data_type) {
+ case nvinfer1::DataType::kFLOAT:
+ ModulatedDeformConvForwardCUDAKernelLauncher_float(
+ (float *)x, (float *)weight, (float *)bias, (float *)offset,
+ (float *)mask, (float *)output, workSpace, batch, channels, height,
+ width, channels_out, kernel_w, kernel_h, mStride.d[0], mStride.d[1],
+ mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup,
+ mDeformableGroup, im2col_step, m_cublas_handle, stream);
+ break;
+ default:
+ return 1;
+ break;
+ }
+
+ return 0;
+}
+
+nvinfer1::DataType ModulatedDeformableConvPluginDynamic::getOutputDataType(
+ int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
+ return inputTypes[0];
+}
+
+// IPluginV2 Methods
+const char *ModulatedDeformableConvPluginDynamic::getPluginType() const {
+ return PLUGIN_NAME;
+}
+
+const char *ModulatedDeformableConvPluginDynamic::getPluginVersion() const {
+ return PLUGIN_VERSION;
+}
+
+int ModulatedDeformableConvPluginDynamic::getNbOutputs() const { return 1; }
+
+int ModulatedDeformableConvPluginDynamic::initialize() { return 0; }
+
+void ModulatedDeformableConvPluginDynamic::terminate() {}
+
+size_t ModulatedDeformableConvPluginDynamic::getSerializationSize() const {
+ return sizeof(mStride) + sizeof(mPadding) + sizeof(mDilation) +
+ sizeof(mDeformableGroup) + sizeof(mGroup);
+}
+
+void ModulatedDeformableConvPluginDynamic::serialize(void *buffer) const {
+ serialize_value(&buffer, mStride);
+ serialize_value(&buffer, mPadding);
+ serialize_value(&buffer, mDilation);
+ serialize_value(&buffer, mDeformableGroup);
+ serialize_value(&buffer, mGroup);
+}
+
+void ModulatedDeformableConvPluginDynamic::destroy() {
+ // This gets called when the network containing plugin is destroyed
+ delete this;
+}
+
+void ModulatedDeformableConvPluginDynamic::attachToContext(
+ cudnnContext *cudnnContext, cublasContext *cublasContext,
+ nvinfer1::IGpuAllocator *gpuAllocator) {
+ m_cublas_handle = cublasContext;
+}
+
+void ModulatedDeformableConvPluginDynamic::detachFromContext() {}
+
+void ModulatedDeformableConvPluginDynamic::setPluginNamespace(
+ const char *libNamespace) {
+ mNamespace = libNamespace;
+}
+
+const char *ModulatedDeformableConvPluginDynamic::getPluginNamespace() const {
+ return mNamespace.c_str();
+}
+
+////////////////////// creator /////////////////////////////
+
+ModulatedDeformableConvPluginDynamicCreator::
+ ModulatedDeformableConvPluginDynamicCreator() {
+ mPluginAttributes.emplace_back(nvinfer1::PluginField("stride"));
+ mPluginAttributes.emplace_back(nvinfer1::PluginField("padding"));
+ mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation"));
+ mPluginAttributes.emplace_back(nvinfer1::PluginField("groups"));
+ mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups"));
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+const char *ModulatedDeformableConvPluginDynamicCreator::getPluginName() const {
+ return PLUGIN_NAME;
+}
+
+const char *ModulatedDeformableConvPluginDynamicCreator::getPluginVersion()
+ const {
+ return PLUGIN_VERSION;
+}
+
+const nvinfer1::PluginFieldCollection *
+ModulatedDeformableConvPluginDynamicCreator::getFieldNames() {
+ return &mFC;
+}
+
+nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::createPlugin(
+ const char *name, const nvinfer1::PluginFieldCollection *fc) {
+ nvinfer1::Dims stride{2, {1, 1}};
+ nvinfer1::Dims padding{2, {0, 0}};
+ nvinfer1::Dims dilation{2, {1, 1}};
+ int deformableGroup = 1;
+ int group = 1;
+
+ for (int i = 0; i < fc->nbFields; i++) {
+ if (fc->fields[i].data == nullptr) {
+ continue;
+ }
+ std::string field_name(fc->fields[i].name);
+
+ if (field_name.compare("deformable_group") == 0) {
+ deformableGroup = static_cast(fc->fields[i].data)[0];
+ }
+
+ if (field_name.compare("group") == 0) {
+ group = static_cast(fc->fields[i].data)[0];
+ }
+
+ if (field_name.compare("stride") == 0) {
+ stride.nbDims = 2;
+ stride.d[0] = static_cast(fc->fields[i].data)[0];
+ stride.d[1] = static_cast(fc->fields[i].data)[1];
+ }
+
+ if (field_name.compare("padding") == 0) {
+ padding.nbDims = 2;
+ padding.d[0] = static_cast(fc->fields[i].data)[0];
+ padding.d[1] = static_cast(fc->fields[i].data)[1];
+ }
+
+ if (field_name.compare("dilation") == 0) {
+ dilation.nbDims = 2;
+ dilation.d[0] = static_cast(fc->fields[i].data)[0];
+ dilation.d[1] = static_cast(fc->fields[i].data)[1];
+ }
+ }
+
+ ModulatedDeformableConvPluginDynamic *plugin =
+ new ModulatedDeformableConvPluginDynamic(name, stride, padding, dilation,
+ deformableGroup, group);
+ plugin->setPluginNamespace(getPluginNamespace());
+ return plugin;
+}
+
+nvinfer1::IPluginV2 *
+ModulatedDeformableConvPluginDynamicCreator::deserializePlugin(
+ const char *name, const void *serialData, size_t serialLength) {
+ auto plugin =
+ new ModulatedDeformableConvPluginDynamic(name, serialData, serialLength);
+ plugin->setPluginNamespace(getPluginNamespace());
+ return plugin;
+}
+
+void ModulatedDeformableConvPluginDynamicCreator::setPluginNamespace(
+ const char *libNamespace) {
+ mNamespace = libNamespace;
+}
+
+const char *ModulatedDeformableConvPluginDynamicCreator::getPluginNamespace()
+ const {
+ return mNamespace.c_str();
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu
new file mode 100644
index 0000000000..258ae783f6
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu
@@ -0,0 +1,133 @@
+#include
+#include
+
+#include "common_cuda_helper.hpp"
+#include "modulated_deform_conv_cuda_kernel.cuh"
+#include "trt_cuda_helper.cuh"
+#include "trt_plugin_helper.hpp"
+
+template
+void trt_modulated_deformable_im2col(
+ const T* data_im_, const T* data_offset_, const T* data_mask_,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+ const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, T* data_col_,
+ cudaStream_t stream) {
+ // num_axes should be smaller than block size
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * batch_size * height_col * width_col;
+
+ modulated_deformable_im2col_gpu_kernel
+ <<>>(
+ num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im,
+ kernel_h, kenerl_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
+ dilation_w, channel_per_deformable_group, batch_size, channels,
+ deformable_group, height_col, width_col, data_col_);
+
+ cudaCheckError();
+}
+
+template
+__global__ void output_add_bias_kernel(scalar_t* output, const scalar_t* bias,
+ size_t step_batch, size_t step_channel,
+ size_t n) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ output[index] += bias[(index % step_batch) / step_channel];
+ }
+}
+
+template
+static void output_add_bias(scalar_t* output, const scalar_t* bias,
+ size_t batch, size_t channel, size_t height,
+ size_t width, cudaStream_t stream) {
+ size_t step_channel = height * width;
+ size_t step_batch = step_channel * channel;
+ size_t n = step_batch * batch;
+ output_add_bias_kernel<<>>(
+ output, bias, step_batch, step_channel, n);
+}
+
+template
+void ModulatedDeformConvForwardCUDAKernelLauncher(
+ const scalar_t* input, const scalar_t* weight, const scalar_t* bias,
+ const scalar_t* offset, const scalar_t* mask, scalar_t* output,
+ void* workspace, int batch, int channels, int height, int width,
+ int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h,
+ int pad_w, int pad_h, int dilation_w, int dilation_h, int group,
+ int deformable_group, int im2col_step, cublasHandle_t cublas_handle,
+ cudaStream_t stream) {
+ size_t sizeof_dtype = sizeof(scalar_t);
+ bool with_bias = (bias != nullptr);
+
+ im2col_step = std::min(int(batch), im2col_step);
+ assert(batch % im2col_step == 0);
+ const int channels_kernel = channels / group;
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ scalar_t* columns = (scalar_t*)workspace;
+
+ const size_t input_step = channels * height * width;
+ const size_t offset_step =
+ deformable_group * kernel_h * kernel_w * 2 * height * width;
+ const size_t mask_step =
+ deformable_group * kernel_h * kernel_w * height * width;
+ const size_t out_step = channels_out * height_out * width_out;
+ const size_t out_group_step = out_step / group;
+ const size_t col_g_step =
+ channels * kernel_w * kernel_h / group * height_out * width_out;
+ const size_t weight_g_step =
+ channels_out / group * channels / group * kernel_h * kernel_w;
+
+ const int m = channels_out / group;
+ const int n = height_out * width_out;
+ const int k = channels / group * kernel_h * kernel_w;
+ scalar_t alpha = 1.;
+ scalar_t beta = 0.;
+
+ for (int b = 0; b < batch; b++) {
+ const scalar_t* input_start = input + b * input_step;
+ const scalar_t* offset_start = offset + b * offset_step;
+ const scalar_t* mask_start = mask + b * mask_step;
+ trt_modulated_deformable_im2col(
+ input_start, offset_start, mask_start, 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, columns, stream);
+
+ for (int g = 0; g < group; g++) {
+ const scalar_t* weight_start = weight + g * weight_g_step;
+ scalar_t* col_start = columns + g * col_g_step;
+ scalar_t* out_buffer_start = output + b * out_step + g * out_group_step;
+
+ // cudaMemsetAsync(out_buffer_start, 0, 1, stream);
+ cublasGemmWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
+ &alpha, col_start, n, weight_start, k, &beta,
+ out_buffer_start, n);
+ cudaCheckError();
+ }
+ }
+
+ if (with_bias) {
+ output_add_bias(output, bias, batch, channels_out, height_out,
+ width_out, stream);
+ }
+}
+
+void ModulatedDeformConvForwardCUDAKernelLauncher_float(
+ const float* input, const float* weight, const float* bias,
+ const float* offset, const float* mask, float* output, void* workspace,
+ int batch, int channels, int height, int width, int channels_out,
+ int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w,
+ int pad_h, int dilation_w, int dilation_h, int group, int deformable_group,
+ int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) {
+ ModulatedDeformConvForwardCUDAKernelLauncher(
+ input, weight, bias, offset, mask, output, workspace, batch, channels,
+ height, width, channels_out, kernel_w, kernel_h, stride_w, stride_h,
+ pad_w, pad_h, dilation_w, dilation_h, group, deformable_group,
+ im2col_step, cublas_handle, stream);
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
index 06d034c365..c7b946b5dd 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
@@ -1,16 +1,23 @@
#include "trt_plugin.hpp"
+#include "trt_cummaxmin.hpp"
#include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp"
+#include "trt_instance_norm.hpp"
+#include "trt_modulated_deform_conv.hpp"
#include "trt_nms.hpp"
#include "trt_roi_align.hpp"
#include "trt_scatternd.hpp"
+REGISTER_TENSORRT_PLUGIN(CumMaxPluginDynamicCreator);
+REGISTER_TENSORRT_PLUGIN(CumMinPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(GridSamplerDynamicCreator);
REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
+REGISTER_TENSORRT_PLUGIN(ModulatedDeformableConvPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
+REGISTER_TENSORRT_PLUGIN(InstanceNormalizationDynamicCreator);
extern "C" {
bool initLibMMCVInferPlugins() { return true; }
diff --git a/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh b/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh
index a4635dcdd5..db42dae9e1 100644
--- a/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh
+++ b/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh
@@ -1,5 +1,6 @@
#ifndef TRT_CUDA_HELPER_HPP
#define TRT_CUDA_HELPER_HPP
+#include
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
@@ -24,7 +25,16 @@
* @param[in] stream cuda stream handle
*/
template
-void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
- int *permute, int src_dim, cudaStream_t stream = 0);
+void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size,
+ int* permute, int src_dim, cudaStream_t stream = 0);
+
+template
+cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
+ cublasOperation_t transb, int m, int n, int k,
+ const scalar_t* alpha, const scalar_t* A, int lda,
+ const scalar_t* B, int ldb, const scalar_t* beta,
+ scalar_t* C, int ldc) {
+ return CUBLAS_STATUS_INTERNAL_ERROR;
+}
#endif // TRT_CUDA_HELPER_HPP
diff --git a/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp b/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
new file mode 100644
index 0000000000..5b856b02fb
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
@@ -0,0 +1,122 @@
+#ifndef TRT_CUMMAXMIN_HPP
+#define TRT_CUMMAXMIN_HPP
+#include
+#include
+
+#include "trt_plugin_helper.hpp"
+
+enum TRT_CUMCMPTYPE { TRT_CUMMAX = 0, TRT_CUMMIN = 1 };
+
+// implement of cummax and cummin
+class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
+ public:
+ CumMaxMinPluginDynamic(const std::string &name, int dim,
+ TRT_CUMCMPTYPE cumType);
+
+ CumMaxMinPluginDynamic(const std::string name, const void *data,
+ size_t length);
+
+ CumMaxMinPluginDynamic() = delete;
+
+ ~CumMaxMinPluginDynamic();
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt *clone() const override;
+ nvinfer1::DimsExprs getOutputDimensions(
+ int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+ nvinfer1::IExprBuilder &exprBuilder) override;
+ bool supportsFormatCombination(int pos,
+ const nvinfer1::PluginTensorDesc *inOut,
+ int nbInputs, int nbOutputs) override;
+ void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
+ int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc *out,
+ int nbOutputs) override;
+ size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
+ int nbInputs,
+ const nvinfer1::PluginTensorDesc *outputs,
+ int nbOutputs) const override;
+ int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
+ const nvinfer1::PluginTensorDesc *outputDesc,
+ const void *const *inputs, void *const *outputs, void *workspace,
+ cudaStream_t stream) override;
+
+ // IPluginV2Ext Methods
+ nvinfer1::DataType getOutputDataType(int index,
+ const nvinfer1::DataType *inputTypes,
+ int nbInputs) const override;
+
+ // IPluginV2 Methods
+ const char *getPluginType() const override;
+ const char *getPluginVersion() const override;
+ int getNbOutputs() const override;
+ int initialize() override;
+ void terminate() override;
+ size_t getSerializationSize() const override;
+ void serialize(void *buffer) const override;
+ void destroy() override;
+ void setPluginNamespace(const char *pluginNamespace) override;
+ const char *getPluginNamespace() const override;
+
+ protected:
+ const std::string mLayerName;
+ std::string mNamespace;
+
+ int mDim;
+ TRT_CUMCMPTYPE mCumType;
+
+ protected:
+ // To prevent compiler warnings.
+ using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
+ using nvinfer1::IPluginV2DynamicExt::configurePlugin;
+ using nvinfer1::IPluginV2DynamicExt::enqueue;
+ using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
+ using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
+ using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
+ using nvinfer1::IPluginV2DynamicExt::supportsFormat;
+};
+
+// cummax and cummin creator
+class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+ CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE cumType);
+
+ const char *getPluginName() const override;
+
+ const char *getPluginVersion() const override;
+
+ const nvinfer1::PluginFieldCollection *getFieldNames() override;
+
+ nvinfer1::IPluginV2 *createPlugin(
+ const char *name, const nvinfer1::PluginFieldCollection *fc) override;
+
+ nvinfer1::IPluginV2 *deserializePlugin(const char *name,
+ const void *serialData,
+ size_t serialLength) override;
+
+ void setPluginNamespace(const char *pluginNamespace) override;
+
+ const char *getPluginNamespace() const override;
+
+ protected:
+ TRT_CUMCMPTYPE mCumType;
+ nvinfer1::PluginFieldCollection mFC;
+ std::vector mPluginAttributes;
+ std::string mNamespace;
+};
+
+// cummax creator
+class CumMaxPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
+ public:
+ CumMaxPluginDynamicCreator();
+ const char *getPluginName() const override;
+};
+
+// cummin creator
+class CumMinPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
+ public:
+ CumMinPluginDynamicCreator();
+ const char *getPluginName() const override;
+};
+
+#endif TRT_CUMMAXMIN_HPP // TRT_CUMMAXMIN_HPP
diff --git a/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp b/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp
index b8762f7868..fc48ac5dd9 100644
--- a/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp
+++ b/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp
@@ -44,6 +44,9 @@ class DeformableConvPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override;
+ void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
+ nvinfer1::IGpuAllocator *gpuAllocator) override;
+ void detachFromContext() override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index,
@@ -74,7 +77,6 @@ class DeformableConvPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
int mIm2colStep;
cublasHandle_t m_cublas_handle;
- cudaStream_t m_cuda_stream;
protected:
// To prevent compiler warnings.
diff --git a/mmcv/ops/csrc/tensorrt/trt_instance_norm.hpp b/mmcv/ops/csrc/tensorrt/trt_instance_norm.hpp
new file mode 100644
index 0000000000..78060c3901
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/trt_instance_norm.hpp
@@ -0,0 +1,120 @@
+// Modified from:
+// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.h
+
+#ifndef TRT_INSTANCE_NORMALIZATION_PLUGIN_H
+#define TRT_INSTANCE_NORMALIZATION_PLUGIN_H
+#include
+
+#include
+#include
+#include
+
+#include "trt_plugin_helper.hpp"
+
+typedef unsigned short half_type;
+
+class InstanceNormalizationDynamic final
+ : public nvinfer1::IPluginV2DynamicExt {
+ public:
+ InstanceNormalizationDynamic(const std::string& name, float epsilon);
+
+ InstanceNormalizationDynamic(const std::string& name, void const* serialData,
+ size_t serialLength);
+
+ InstanceNormalizationDynamic() = delete;
+
+ ~InstanceNormalizationDynamic() override;
+
+ int getNbOutputs() const override;
+
+ // DynamicExt plugins returns DimsExprs class instead of Dims
+ nvinfer1::DimsExprs getOutputDimensions(
+ int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
+ nvinfer1::IExprBuilder& exprBuilder) override;
+
+ int initialize() override;
+
+ void terminate() override;
+
+ size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+ int nbInputs,
+ const nvinfer1::PluginTensorDesc* outputs,
+ int nbOutputs) const override;
+
+ int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+ const nvinfer1::PluginTensorDesc* outputDesc,
+ const void* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) override;
+
+ size_t getSerializationSize() const override;
+
+ void serialize(void* buffer) const override;
+
+ // DynamicExt plugin supportsFormat update.
+ bool supportsFormatCombination(int pos,
+ const nvinfer1::PluginTensorDesc* inOut,
+ int nbInputs, int nbOutputs) override;
+
+ const char* getPluginType() const override;
+
+ const char* getPluginVersion() const override;
+
+ void destroy() override;
+
+ nvinfer1::IPluginV2DynamicExt* clone() const override;
+
+ void setPluginNamespace(const char* pluginNamespace) override;
+
+ const char* getPluginNamespace() const override;
+
+ nvinfer1::DataType getOutputDataType(int index,
+ const nvinfer1::DataType* inputTypes,
+ int nbInputs) const override;
+
+ void attachToContext(cudnnContext* cudnn, cublasContext* cublas,
+ nvinfer1::IGpuAllocator* allocator) override;
+
+ void detachFromContext() override;
+
+ void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
+ int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc* out,
+ int nbOutputs) override;
+
+ private:
+ const std::string mLayerName;
+ float mEpsilon{};
+ cudnnHandle_t _cudnn_handle{};
+ cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{};
+ std::string mPluginNamespace{};
+};
+
+class InstanceNormalizationDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+ InstanceNormalizationDynamicCreator();
+
+ ~InstanceNormalizationDynamicCreator() override = default;
+
+ const char* getPluginName() const override;
+
+ const char* getPluginVersion() const override;
+
+ const nvinfer1::PluginFieldCollection* getFieldNames() override;
+
+ nvinfer1::IPluginV2DynamicExt* createPlugin(
+ const char* name, const nvinfer1::PluginFieldCollection* fc) override;
+
+ nvinfer1::IPluginV2DynamicExt* deserializePlugin(
+ const char* name, const void* serialData, size_t serialLength) override;
+
+ void setPluginNamespace(const char* pluginNamespace) override;
+
+ const char* getPluginNamespace() const override;
+
+ private:
+ static nvinfer1::PluginFieldCollection mFC;
+ static std::vector mPluginAttributes;
+ std::string mNamespace;
+};
+
+#endif // TRT_INSTANCE_NORMALIZATION_PLUGIN_H
diff --git a/mmcv/ops/csrc/tensorrt/trt_modulated_deform_conv.hpp b/mmcv/ops/csrc/tensorrt/trt_modulated_deform_conv.hpp
new file mode 100644
index 0000000000..0907e7ea85
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/trt_modulated_deform_conv.hpp
@@ -0,0 +1,120 @@
+#ifndef TRT_MODULATED_DEFORM_CONV_HPP
+#define TRT_MODULATED_DEFORM_CONV_HPP
+#include
+
+#include
+#include
+#include
+
+#include "trt_plugin_helper.hpp"
+
+class ModulatedDeformableConvPluginDynamic
+ : public nvinfer1::IPluginV2DynamicExt {
+ public:
+ ModulatedDeformableConvPluginDynamic(const std::string &name,
+ const nvinfer1::Dims stride,
+ const nvinfer1::Dims padding,
+ const nvinfer1::Dims dilation,
+ const int deformableGroup,
+ const int group);
+
+ ModulatedDeformableConvPluginDynamic(const std::string name, const void *data,
+ size_t length);
+
+ ModulatedDeformableConvPluginDynamic() = delete;
+
+ ~ModulatedDeformableConvPluginDynamic();
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt *clone() const override;
+ nvinfer1::DimsExprs getOutputDimensions(
+ int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+ nvinfer1::IExprBuilder &exprBuilder) override;
+ bool supportsFormatCombination(int pos,
+ const nvinfer1::PluginTensorDesc *inOut,
+ int nbInputs, int nbOutputs) override;
+ void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
+ int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc *out,
+ int nbOutputs) override;
+ size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
+ int nbInputs,
+ const nvinfer1::PluginTensorDesc *outputs,
+ int nbOutputs) const override;
+ int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
+ const nvinfer1::PluginTensorDesc *outputDesc,
+ const void *const *inputs, void *const *outputs, void *workspace,
+ cudaStream_t stream) override;
+ void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
+ nvinfer1::IGpuAllocator *gpuAllocator) override;
+ void detachFromContext() override;
+
+ // IPluginV2Ext Methods
+ nvinfer1::DataType getOutputDataType(int index,
+ const nvinfer1::DataType *inputTypes,
+ int nbInputs) const override;
+
+ // IPluginV2 Methods
+ const char *getPluginType() const override;
+ const char *getPluginVersion() const override;
+ int getNbOutputs() const override;
+ int initialize() override;
+ void terminate() override;
+ size_t getSerializationSize() const override;
+ void serialize(void *buffer) const override;
+ void destroy() override;
+ void setPluginNamespace(const char *pluginNamespace) override;
+ const char *getPluginNamespace() const override;
+
+ private:
+ const std::string mLayerName;
+ std::string mNamespace;
+
+ nvinfer1::Dims mStride;
+ nvinfer1::Dims mPadding;
+ nvinfer1::Dims mDilation;
+ int mDeformableGroup;
+ int mGroup;
+ bool mWithBias;
+
+ cublasHandle_t m_cublas_handle;
+
+ protected:
+ // To prevent compiler warnings.
+ using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
+ using nvinfer1::IPluginV2DynamicExt::configurePlugin;
+ using nvinfer1::IPluginV2DynamicExt::enqueue;
+ using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
+ using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
+ using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
+ using nvinfer1::IPluginV2DynamicExt::supportsFormat;
+};
+
+class ModulatedDeformableConvPluginDynamicCreator
+ : public nvinfer1::IPluginCreator {
+ public:
+ ModulatedDeformableConvPluginDynamicCreator();
+
+ const char *getPluginName() const override;
+
+ const char *getPluginVersion() const override;
+
+ const nvinfer1::PluginFieldCollection *getFieldNames() override;
+
+ nvinfer1::IPluginV2 *createPlugin(
+ const char *name, const nvinfer1::PluginFieldCollection *fc) override;
+
+ nvinfer1::IPluginV2 *deserializePlugin(const char *name,
+ const void *serialData,
+ size_t serialLength) override;
+
+ void setPluginNamespace(const char *pluginNamespace) override;
+
+ const char *getPluginNamespace() const override;
+
+ private:
+ static nvinfer1::PluginFieldCollection mFC;
+ static std::vector mPluginAttributes;
+ std::string mNamespace;
+};
+#endif // TRT_MODULATED_DEFORM_CONV_HPP
diff --git a/mmcv/ops/csrc/tensorrt/trt_serialize.hpp b/mmcv/ops/csrc/tensorrt/trt_serialize.hpp
index c9e75cbbe7..1f0899fdfe 100644
--- a/mmcv/ops/csrc/tensorrt/trt_serialize.hpp
+++ b/mmcv/ops/csrc/tensorrt/trt_serialize.hpp
@@ -1,18 +1,6 @@
-/*
- * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+// Modified from:
+// https://github.com/NVIDIA/TensorRT/blob/master/plugin/common/serialize.hpp
+
#ifndef TRT_SERIALIZE_HPP
#define TRT_SERIALIZE_HPP
#include
diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index 5282e26193..04666f58db 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -70,8 +70,14 @@ def forward(ctx,
ctx.deform_groups = deform_groups
ctx.im2col_step = im2col_step
- # until the code is modified for torch.cuda.amp.autocast,
- # we need to cast weight to avoid type mismatch in fp16 training
+ # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+ # amp won't cast the type of model (float32), but "offset" is cast
+ # to float16 by nn.Conv2d automatically, leading to the type
+ # mismatch with input (when it is float32) or weight.
+ # The flag for whether to use fp16 or amp is the type of "offset",
+ # we cast weight and input to temporarily support fp16 and amp
+ # whatever the pytorch version is.
+ input = input.type_as(offset)
weight = weight.type_as(input)
ctx.save_for_backward(input, offset, weight)
diff --git a/mmcv/ops/fused_bias_leakyrelu.py b/mmcv/ops/fused_bias_leakyrelu.py
index c2bf7b4f00..52c392dc7e 100644
--- a/mmcv/ops/fused_bias_leakyrelu.py
+++ b/mmcv/ops/fused_bias_leakyrelu.py
@@ -195,15 +195,15 @@ class FusedBiasLeakyReLU(nn.Module):
The bias term comes from the convolution operation. In addition, to keep
the variance of the feature map or gradients unchanged, they also adopt a
- scale similarly with Kaiming initalization. However, since the
+ scale similarly with Kaiming initialization. However, since the
:math:`1 + \alpha^2` : is too small, we can just ignore it. Therefore, the
- final sacle is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
+ final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
your own scale.
TODO: Implement the CPU version.
Args:
- channel (int): The channnel number of the feature map.
+ channel (int): The channel number of the feature map.
negative_slope (float, optional): Same as nn.LeakyRelu.
Defaults to 0.2.
scale (float, optional): A scalar to adjust the variance of the feature
@@ -230,9 +230,9 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
The bias term comes from the convolution operation. In addition, to keep
the variance of the feature map or gradients unchanged, they also adopt a
- scale similarly with Kaiming initalization. However, since the
+ scale similarly with Kaiming initialization. However, since the
:math:`1 + \alpha^2` : is too small, we can just ignore it. Therefore, the
- final sacle is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
+ final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
your own scale.
Args:
diff --git a/mmcv/ops/merge_cells.py b/mmcv/ops/merge_cells.py
index b881026c45..e3b1775099 100644
--- a/mmcv/ops/merge_cells.py
+++ b/mmcv/ops/merge_cells.py
@@ -10,7 +10,7 @@
class BaseMergeCell(nn.Module):
"""The basic class for cells used in NAS-FPN and NAS-FCOS.
- BaseMergeCell takes 2 inputs. After applying concolution
+ BaseMergeCell takes 2 inputs. After applying convolution
on them, they are resized to the target size. Then,
they go through binary_op, which depends on the type of cell.
If with_out_conv is True, the result of output will go through
diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index b8ff1adeb2..d26f61a0a1 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -20,13 +20,12 @@ class ModulatedDeformConv2dFunction(Function):
@staticmethod
def symbolic(g, input, offset, mask, weight, bias, stride, padding,
dilation, groups, deform_groups):
+ input_tensors = [input, offset, mask, weight]
+ if bias is not None:
+ input_tensors.append(bias)
return g.op(
- 'MMCVModulatedDeformConv2d',
- input,
- offset,
- mask,
- weight,
- bias,
+ 'mmcv::MMCVModulatedDeformConv2d',
+ *input_tensors,
stride_i=stride,
padding_i=padding,
dilation_i=dilation,
@@ -57,6 +56,15 @@ def forward(ctx,
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(0) # fake tensor
+ # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+ # amp won't cast the type of model (float32), but "offset" is cast
+ # to float16 by nn.Conv2d automatically, leading to the type
+ # mismatch with input (when it is float32) or weight.
+ # The flag for whether to use fp16 or amp is the type of "offset",
+ # we cast weight and input to temporarily support fp16 and amp
+ # whatever the pytorch version is.
+ input = input.type_as(offset)
+ weight = weight.type_as(input)
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py
index 77919e47ec..45b22468a4 100644
--- a/mmcv/ops/multi_scale_deform_attn.py
+++ b/mmcv/ops/multi_scale_deform_attn.py
@@ -1,7 +1,15 @@
+import math
+import warnings
+
import torch
+import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
+from mmcv import deprecated_api_warning
+from mmcv.cnn import constant_init, xavier_init
+from mmcv.cnn.bricks.registry import ATTENTION
+from mmcv.runner import BaseModule
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
@@ -35,11 +43,13 @@ def forward(ctx, value, value_spatial_shapes, value_level_start_index,
"""
ctx.im2col_step = im2col_step
- output = ext_module.ms_deform_attn_forward(value, value_spatial_shapes,
- value_level_start_index,
- sampling_locations,
- attention_weights,
- ctx.im2col_step)
+ output = ext_module.ms_deform_attn_forward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ im2col_step=ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
@@ -60,15 +70,21 @@ def backward(ctx, grad_output):
"""
value, value_spatial_shapes, value_level_start_index,\
sampling_locations, attention_weights = ctx.saved_tensors
- grad_value, grad_sampling_loc, grad_attn_weight = \
- ext_module.ms_deform_attn_backward(
- value,
- value_spatial_shapes,
- value_level_start_index,
- sampling_locations,
- attention_weights,
- grad_output,
- ctx.im2col_step)
+ grad_value = torch.zeros_like(value)
+ grad_sampling_loc = torch.zeros_like(sampling_locations)
+ grad_attn_weight = torch.zeros_like(attention_weights)
+
+ ext_module.ms_deform_attn_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ grad_output.contiguous(),
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight,
+ im2col_step=ctx.im2col_step)
return grad_value, None, None, \
grad_sampling_loc, grad_attn_weight, None
@@ -132,3 +148,211 @@ def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
attention_weights).sum(-1).view(bs, num_heads * embed_dims,
num_queries)
return output.transpose(1, 2).contiguous()
+
+
+@ATTENTION.register_module()
+class MultiScaleDeformableAttention(BaseModule):
+ """An attention module used in Deformable-Detr. `Deformable DETR:
+ Deformable Transformers for End-to-End Object Detection.
+
+ `_.
+
+ Args:
+ embed_dims (int): The embedding dimension of Attention.
+ Default: 256.
+ num_heads (int): Parallel attention heads. Default: 64.
+ num_levels (int): The number of feature map used in
+ Attention. Default: 4.
+ num_points (int): The number of sampling points for
+ each query in each head. Default: 4.
+ im2col_step (int): The step used in image_to_column.
+ Default: 64.
+ dropout (float): A Dropout layer on `inp_identity`.
+ Default: 0.1.
+ batch_first (bool): Key, Query and Value are shape of
+ (batch, n, embed_dim)
+ or (n, batch, embed_dim). Default to False.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: None.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims=256,
+ num_heads=8,
+ num_levels=4,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.1,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None):
+ super().__init__(init_cfg)
+ if embed_dims % num_heads != 0:
+ raise ValueError(f'embed_dims must be divisible by num_heads, '
+ f'but got {embed_dims} and {num_heads}')
+ dim_per_head = embed_dims // num_heads
+ self.norm_cfg = norm_cfg
+ self.dropout = nn.Dropout(dropout)
+ self.batch_first = batch_first
+
+ # you'd better set dim_per_head to a power of 2
+ # which is more efficient in the CUDA implementation
+ def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError(
+ 'invalid input for _is_power_of_2: {} (type: {})'.format(
+ n, type(n)))
+ return (n & (n - 1) == 0) and n != 0
+
+ if not _is_power_of_2(dim_per_head):
+ warnings.warn(
+ "You'd better set embed_dims in "
+ 'MultiScaleDeformAttention to make '
+ 'the dimension of each attention head a power of 2 '
+ 'which is more efficient in our CUDA implementation.')
+
+ self.im2col_step = im2col_step
+ self.embed_dims = embed_dims
+ self.num_levels = num_levels
+ self.num_heads = num_heads
+ self.num_points = num_points
+ self.sampling_offsets = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points * 2)
+ self.attention_weights = nn.Linear(embed_dims,
+ num_heads * num_levels * num_points)
+ self.value_proj = nn.Linear(embed_dims, embed_dims)
+ self.output_proj = nn.Linear(embed_dims, embed_dims)
+ self.init_weights()
+
+ def init_weights(self):
+ """Default initialization for Parameters of Module."""
+ constant_init(self.sampling_offsets, 0.)
+ thetas = torch.arange(
+ self.num_heads,
+ dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (grid_init /
+ grid_init.abs().max(-1, keepdim=True)[0]).view(
+ self.num_heads, 1, 1,
+ 2).repeat(1, self.num_levels, self.num_points, 1)
+ for i in range(self.num_points):
+ grid_init[:, :, i, :] *= i + 1
+
+ self.sampling_offsets.bias.data = grid_init.view(-1)
+ constant_init(self.attention_weights, val=0., bias=0.)
+ xavier_init(self.value_proj, distribution='uniform', bias=0.)
+ xavier_init(self.output_proj, distribution='uniform', bias=0.)
+ self._is_init = True
+
+ @deprecated_api_warning({'residual': 'identity'},
+ cls_name='MultiScaleDeformableAttention')
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ identity=None,
+ query_pos=None,
+ key_padding_mask=None,
+ reference_points=None,
+ spatial_shapes=None,
+ level_start_index=None,
+ **kwargs):
+ """Forward Function of MultiScaleDeformAttention.
+
+ Args:
+ query (Tensor): Query of Transformer with shape
+ (num_query, bs, embed_dims).
+ key (Tensor): The key tensor with shape
+ `(num_key, bs, embed_dims)`.
+ value (Tensor): The value tensor with shape
+ `(num_key, bs, embed_dims)`.
+ identity (Tensor): The tensor used for addition, with the
+ same shape as `query`. Default None. If None,
+ `query` will be used.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`. Default
+ None.
+ reference_points (Tensor): The normalized reference
+ points with shape (bs, num_query, num_levels, 2),
+ all elements is range in [0, 1], top-left (0,0),
+ bottom-right (1, 1), including padding area.
+ or (N, Length_{query}, num_levels, 4), add
+ additional two dimensions is (w, h) to
+ form reference boxes.
+ key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_key].
+ spatial_shapes (Tensor): Spatial shape of features in
+ different levels. With shape (num_levels, 2),
+ last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape ``(num_levels, )`` and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+
+ if value is None:
+ value = query
+
+ if identity is None:
+ identity = query
+ if query_pos is not None:
+ query = query + query_pos
+ if not self.batch_first:
+ # change to (bs, num_query ,embed_dims)
+ query = query.permute(1, 0, 2)
+ value = value.permute(1, 0, 2)
+
+ bs, num_query, _ = query.shape
+ bs, num_value, _ = value.shape
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+ value = self.value_proj(value)
+ if key_padding_mask is not None:
+ value = value.masked_fill(key_padding_mask[..., None], 0.0)
+ value = value.view(bs, num_value, self.num_heads, -1)
+ sampling_offsets = self.sampling_offsets(query).view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
+ attention_weights = self.attention_weights(query).view(
+ bs, num_query, self.num_heads, self.num_levels * self.num_points)
+ attention_weights = attention_weights.softmax(-1)
+
+ attention_weights = attention_weights.view(bs, num_query,
+ self.num_heads,
+ self.num_levels,
+ self.num_points)
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+ sampling_locations = reference_points[:, :, None, :, None, :] \
+ + sampling_offsets \
+ / offset_normalizer[None, None, None, :, None, :]
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
+ + sampling_offsets / self.num_points \
+ * reference_points[:, :, None, :, None, 2:] \
+ * 0.5
+ else:
+ raise ValueError(
+ f'Last dim of reference_points must be'
+ f' 2 or 4, but get {reference_points.shape[-1]} instead.')
+ if torch.cuda.is_available():
+ output = MultiScaleDeformableAttnFunction.apply(
+ value, spatial_shapes, level_start_index, sampling_locations,
+ attention_weights, self.im2col_step)
+ else:
+ output = multi_scale_deformable_attn_pytorch(
+ value, spatial_shapes, level_start_index, sampling_locations,
+ attention_weights, self.im2col_step)
+
+ output = self.output_proj(output)
+
+ if not self.batch_first:
+ # (num_query, bs ,embed_dims)
+ output = output.permute(1, 0, 2)
+
+ return self.dropout(output) + identity
diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py
index ef16a55425..0d2467a0d9 100644
--- a/mmcv/ops/nms.py
+++ b/mmcv/ops/nms.py
@@ -1,5 +1,4 @@
import os
-import sys
import numpy as np
import torch
@@ -15,13 +14,27 @@
class NMSop(torch.autograd.Function):
@staticmethod
- def forward(ctx, bboxes, scores, iou_threshold, offset):
+ def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
+ max_num):
+ is_filtering_by_score = score_threshold > 0
+ if is_filtering_by_score:
+ valid_mask = scores > score_threshold
+ bboxes, scores = bboxes[valid_mask], scores[valid_mask]
+ valid_inds = torch.nonzero(
+ valid_mask, as_tuple=False).squeeze(dim=1)
+
inds = ext_module.nms(
bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
+
+ if max_num > 0:
+ inds = inds[:max_num]
+ if is_filtering_by_score:
+ inds = valid_inds[inds]
return inds
@staticmethod
- def symbolic(g, bboxes, scores, iou_threshold, offset):
+ def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
+ max_num):
from ..onnx import is_custom_op_loaded
has_custom_op = is_custom_op_loaded()
# TensorRT nms plugin is aligned with original nms in ONNXRuntime
@@ -35,16 +48,28 @@ def symbolic(g, bboxes, scores, iou_threshold, offset):
offset_i=int(offset))
else:
from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
+ from ..onnx.onnx_utils.symbolic_helper import _size_helper
+
boxes = unsqueeze(g, bboxes, 0)
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
- max_output_per_class = g.op(
- 'Constant',
- value_t=torch.tensor([sys.maxsize], dtype=torch.long))
+
+ if max_num > 0:
+ max_num = g.op(
+ 'Constant',
+ value_t=torch.tensor(max_num, dtype=torch.long))
+ else:
+ dim = g.op('Constant', value_t=torch.tensor(0))
+ max_num = _size_helper(g, bboxes, dim)
+ max_output_per_class = max_num
iou_threshold = g.op(
'Constant',
value_t=torch.tensor([iou_threshold], dtype=torch.float))
+ score_threshold = g.op(
+ 'Constant',
+ value_t=torch.tensor([score_threshold], dtype=torch.float))
nms_out = g.op('NonMaxSuppression', boxes, scores,
- max_output_per_class, iou_threshold)
+ max_output_per_class, iou_threshold,
+ score_threshold)
return squeeze(
g,
select(
@@ -90,7 +115,7 @@ def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,
@deprecated_api_warning({'iou_thr': 'iou_threshold'})
-def nms(boxes, scores, iou_threshold, offset=0):
+def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
"""Dispatch to either CPU or GPU NMS implementations.
The input can be either torch tensor or numpy array. GPU NMS will be used
@@ -102,6 +127,8 @@ def nms(boxes, scores, iou_threshold, offset=0):
scores (torch.Tensor or np.ndarray): scores in shape (N, ).
iou_threshold (float): IoU threshold for NMS.
offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+ score_threshold (float): score threshold for NMS.
+ max_num (int): maximum number of boxes after NMS.
Returns:
tuple: kept dets(boxes and scores) and indice, which is always the \
@@ -141,7 +168,8 @@ def nms(boxes, scores, iou_threshold, offset=0):
}
inds = ext_module.nms(*indata_list, **indata_dict)
else:
- inds = NMSop.apply(boxes, scores, iou_threshold, offset)
+ inds = NMSop.apply(boxes, scores, iou_threshold, offset,
+ score_threshold, max_num)
dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
if is_numpy:
dets = dets.cpu().numpy()
@@ -285,6 +313,7 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
# Some type of nms would reweight the score, such as SoftNMS
scores = dets[:, 4]
else:
+ max_num = nms_cfg_.pop('max_num', -1)
total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
# Some type of nms would reweight the score, such as SoftNMS
scores_after_nms = scores.new_zeros(scores.size())
@@ -294,10 +323,16 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
total_mask[mask[keep]] = True
scores_after_nms[mask[keep]] = dets[:, -1]
keep = total_mask.nonzero(as_tuple=False).view(-1)
+
scores, inds = scores_after_nms[keep].sort(descending=True)
keep = keep[inds]
boxes = boxes[keep]
+ if max_num > 0:
+ keep = keep[:max_num]
+ boxes = boxes[:max_num]
+ scores = scores[:max_num]
+
return torch.cat([boxes, scores[:, None]], -1), keep
@@ -350,7 +385,7 @@ def nms_rotated(dets, scores, iou_threshold, labels=None):
be in (x_ctr, y_ctr, width, height, angle_radian) format.
scores (Tensor): scores in shape (N, ).
iou_threshold (float): IoU thresh for NMS.
- labels (Tensor): boxes's label in shape (N,).
+ labels (Tensor): boxes' label in shape (N,).
Returns:
tuple: kept dets(boxes and scores) and indice, which is always the \
diff --git a/mmcv/ops/pixel_group.py b/mmcv/ops/pixel_group.py
index 8361fa1e25..5aa5e0d7b2 100644
--- a/mmcv/ops/pixel_group.py
+++ b/mmcv/ops/pixel_group.py
@@ -14,7 +14,7 @@ def pixel_group(score, mask, embedding, kernel_label, kernel_contour,
Arguments:
score (np.array or Tensor): The foreground score with size hxw.
mask (np.array or Tensor): The foreground mask with size hxw.
- embedding (np.array or Tensor): The emdedding with size hxwxc to
+ embedding (np.array or Tensor): The embedding with size hxwxc to
distinguish instances.
kernel_label (np.array or Tensor): The instance kernel index with
size hxw.
diff --git a/mmcv/ops/point_sample.py b/mmcv/ops/point_sample.py
index c5f59d3f18..c084a8c220 100644
--- a/mmcv/ops/point_sample.py
+++ b/mmcv/ops/point_sample.py
@@ -1,9 +1,94 @@
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
+from os import path as osp
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
+from torch.onnx.operators import shape_as_tensor
+
+
+def bilinear_grid_sample(im, grid, align_corners=False):
+ """Given an input and a flow-field grid, computes the output using input
+ values and pixel locations from grid. Supported only bilinear interpolation
+ method to sample the input pixels.
+
+ Args:
+ im (torch.Tensor): Input feature map, shape (N, C, H, W)
+ grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+ align_corners {bool}: If set to True, the extrema (-1 and 1) are
+ considered as referring to the center points of the input’s
+ corner pixels. If set to False, they are instead considered as
+ referring to the corner points of the input’s corner pixels,
+ making the sampling more resolution agnostic.
+ Returns:
+ torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+ """
+ n, c, h, w = im.shape
+ gn, gh, gw, _ = grid.shape
+ assert n == gn
+
+ x = grid[:, :, :, 0]
+ y = grid[:, :, :, 1]
+
+ if align_corners:
+ x = ((x + 1) / 2) * (w - 1)
+ y = ((y + 1) / 2) * (h - 1)
+ else:
+ x = ((x + 1) * w - 1) / 2
+ y = ((y + 1) * h - 1) / 2
+
+ x = x.view(n, -1)
+ y = y.view(n, -1)
+
+ x0 = torch.floor(x).long()
+ y0 = torch.floor(y).long()
+ x1 = x0 + 1
+ y1 = y0 + 1
+
+ wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+ wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+ wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+ wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+ # Apply default for grid_sample function zero padding
+ im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
+ padded_h = h + 2
+ padded_w = w + 2
+ # save points positions after padding
+ x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+ # Clip coordinates to padded image size
+ x0 = torch.where(x0 < 0, torch.tensor(0), x0)
+ x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
+ x1 = torch.where(x1 < 0, torch.tensor(0), x1)
+ x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
+ y0 = torch.where(y0 < 0, torch.tensor(0), y0)
+ y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
+ y1 = torch.where(y1 < 0, torch.tensor(0), y1)
+ y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+
+ im_padded = im_padded.view(n, c, -1)
+
+ x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+
+ Ia = torch.gather(im_padded, 2, x0_y0)
+ Ib = torch.gather(im_padded, 2, x0_y1)
+ Ic = torch.gather(im_padded, 2, x1_y0)
+ Id = torch.gather(im_padded, 2, x1_y1)
+
+ return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
+
+
+def is_in_onnx_export_without_custom_ops():
+ from mmcv.ops import get_onnxruntime_op_path
+ ort_custom_op_path = get_onnxruntime_op_path()
+ return torch.onnx.is_in_onnx_export(
+ ) and not osp.exists(ort_custom_op_path)
def normalize(grid):
@@ -70,25 +155,42 @@ def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
if rois.size(1) == 5:
rois = rois[:, 1:]
abs_img_points = rel_roi_points.clone()
- abs_img_points[:, :, 0] = abs_img_points[:, :, 0] * (
- rois[:, None, 2] - rois[:, None, 0])
- abs_img_points[:, :, 1] = abs_img_points[:, :, 1] * (
- rois[:, None, 3] - rois[:, None, 1])
- abs_img_points[:, :, 0] += rois[:, None, 0]
- abs_img_points[:, :, 1] += rois[:, None, 1]
+ # To avoid an error during exporting to onnx use independent
+ # variables instead inplace computation
+ xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0])
+ ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1])
+ xs += rois[:, None, 0]
+ ys += rois[:, None, 1]
+ abs_img_points = torch.stack([xs, ys], dim=2)
return abs_img_points
-def abs_img_point_to_rel_img_point(abs_img_points,
- img_shape,
- spatial_scale=1.):
+def get_shape_from_feature_map(x):
+ """Get spatial resolution of input feature map considering exporting to
+ onnx mode.
+
+ Args:
+ x (torch.Tensor): Input tensor, shape (N, C, H, W)
+ Returns:
+ torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
+ """
+ if torch.onnx.is_in_onnx_export():
+ img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to(
+ x.device).float()
+ else:
+ img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to(
+ x.device).float()
+ return img_shape
+
+
+def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
"""Convert image based absolute point coordinates to image based relative
coordinates for sampling.
Args:
abs_img_points (Tensor): Image based absolute point coordinates,
shape (N, P, 2)
- img_shape (tuple): (height, width) of image or feature map.
+ img (tuple/Tensor): (height, width) of image or feature map.
spatial_scale (float): Scale points by this factor. Default: 1.
Returns:
@@ -96,20 +198,24 @@ def abs_img_point_to_rel_img_point(abs_img_points,
shape (N, P, 2)
"""
- assert isinstance(img_shape, tuple) and len(img_shape) == 2
- h, w = img_shape
- scale = torch.tensor([w, h],
- dtype=torch.float,
- device=abs_img_points.device)
- scale = scale.view(1, 1, 2)
- rel_img_points = abs_img_points / scale * spatial_scale
+ assert (isinstance(img, tuple) and len(img) == 2) or \
+ (isinstance(img, torch.Tensor) and len(img.shape) == 4)
- return rel_img_points
+ if isinstance(img, tuple):
+ h, w = img
+ scale = torch.tensor([w, h],
+ dtype=torch.float,
+ device=abs_img_points.device)
+ scale = scale.view(1, 1, 2)
+ else:
+ scale = get_shape_from_feature_map(img)
+
+ return abs_img_points / scale * spatial_scale
def rel_roi_point_to_rel_img_point(rois,
rel_roi_points,
- img_shape,
+ img,
spatial_scale=1.):
"""Convert roi based relative point coordinates to image based absolute
point coordinates.
@@ -118,7 +224,7 @@ def rel_roi_point_to_rel_img_point(rois,
rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
rel_roi_points (Tensor): Point coordinates inside RoI, relative to
RoI, location, range (0, 1), shape (N, P, 2)
- img_shape (tuple): (height, width) of image or feature map.
+ img (tuple/Tensor): (height, width) of image or feature map.
spatial_scale (float): Scale points by this factor. Default: 1.
Returns:
@@ -127,7 +233,7 @@ def rel_roi_point_to_rel_img_point(rois,
"""
abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
- rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img_shape,
+ rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img,
spatial_scale)
return rel_img_point
@@ -153,8 +259,15 @@ def point_sample(input, points, align_corners=False, **kwargs):
if points.dim() == 3:
add_dim = True
points = points.unsqueeze(2)
- output = F.grid_sample(
- input, denormalize(points), align_corners=align_corners, **kwargs)
+ if is_in_onnx_export_without_custom_ops():
+ # If custom ops for onnx runtime not compiled use python
+ # implementation of grid_sample function to make onnx graph
+ # with supported nodes
+ output = bilinear_grid_sample(
+ input, denormalize(points), align_corners=align_corners)
+ else:
+ output = F.grid_sample(
+ input, denormalize(points), align_corners=align_corners, **kwargs)
if add_dim:
output = output.squeeze(3)
return output
@@ -181,29 +294,38 @@ def __init__(self, output_size, spatial_scale, aligned=True):
self.aligned = aligned
def forward(self, features, rois):
-
num_imgs = features.size(0)
num_rois = rois.size(0)
rel_roi_points = generate_grid(
num_rois, self.output_size, device=rois.device)
- point_feats = []
- for batch_ind in range(num_imgs):
- # unravel batch dim
- feat = features[batch_ind].unsqueeze(0)
- inds = (rois[:, 0].long() == batch_ind)
- if inds.any():
- rel_img_points = rel_roi_point_to_rel_img_point(
- rois[inds], rel_roi_points[inds], feat.shape[2:],
- self.spatial_scale).unsqueeze(0)
- point_feat = point_sample(
- feat, rel_img_points, align_corners=not self.aligned)
- point_feat = point_feat.squeeze(0).transpose(0, 1)
- point_feats.append(point_feat)
+ if torch.onnx.is_in_onnx_export():
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois, rel_roi_points, features, self.spatial_scale)
+ rel_img_points = rel_img_points.reshape(num_imgs, -1,
+ *rel_img_points.shape[1:])
+ point_feats = point_sample(
+ features, rel_img_points, align_corners=not self.aligned)
+ point_feats = point_feats.transpose(1, 2)
+ else:
+ point_feats = []
+ for batch_ind in range(num_imgs):
+ # unravel batch dim
+ feat = features[batch_ind].unsqueeze(0)
+ inds = (rois[:, 0].long() == batch_ind)
+ if inds.any():
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois[inds], rel_roi_points[inds], feat,
+ self.spatial_scale).unsqueeze(0)
+ point_feat = point_sample(
+ feat, rel_img_points, align_corners=not self.aligned)
+ point_feat = point_feat.squeeze(0).transpose(0, 1)
+ point_feats.append(point_feat)
+
+ point_feats = torch.cat(point_feats, dim=0)
channels = features.size(1)
- roi_feats = torch.cat(point_feats, dim=0)
- roi_feats = roi_feats.reshape(num_rois, channels, *self.output_size)
+ roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)
return roi_feats
diff --git a/mmcv/ops/saconv.py b/mmcv/ops/saconv.py
index cd7eea122f..6b19ce5719 100644
--- a/mmcv/ops/saconv.py
+++ b/mmcv/ops/saconv.py
@@ -1,3 +1,5 @@
+from distutils.version import LooseVersion
+
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -98,13 +100,20 @@ def forward(self, x):
switch = self.switch(avg_x)
# sac
weight = self._get_weight(self.weight)
+ zero_bias = torch.zeros(
+ self.out_channels, device=weight.device, dtype=weight.dtype)
+
if self.use_deform:
offset = self.offset_s(avg_x)
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
- if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
+ if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
+ or TORCH_VERSION == 'parrots'):
out_s = super().conv2d_forward(x, weight)
+ elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'):
+ # bias is a required argument of _conv_forward in torch 1.8.0
+ out_s = super()._conv_forward(x, weight, zero_bias)
else:
out_s = super()._conv_forward(x, weight)
ori_p = self.padding
@@ -117,10 +126,15 @@ def forward(self, x):
out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
- if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
+ if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
+ or TORCH_VERSION == 'parrots'):
out_l = super().conv2d_forward(x, weight)
+ elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'):
+ # bias is a required argument of _conv_forward in torch 1.8.0
+ out_l = super()._conv_forward(x, weight, zero_bias)
else:
out_l = super()._conv_forward(x, weight)
+
out = switch * out_s + (1 - switch) * out_l
self.padding = ori_p
self.dilation = ori_d
diff --git a/mmcv/parallel/_functions.py b/mmcv/parallel/_functions.py
index 4cd02fbe67..ad19415f37 100644
--- a/mmcv/parallel/_functions.py
+++ b/mmcv/parallel/_functions.py
@@ -23,7 +23,7 @@ def scatter(input, devices, streams=None):
with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
output = output.cuda(devices[0], non_blocking=True)
else:
- # unsquzee the first dimension thus the tensor's shape is the
+ # unsqueeze the first dimension thus the tensor's shape is the
# same as those scattered with GPU.
output = output.unsqueeze(0)
return output
diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py
index 767c4f9dd2..2882cf35d4 100644
--- a/mmcv/parallel/distributed.py
+++ b/mmcv/parallel/distributed.py
@@ -1,4 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
+from distutils.version import LooseVersion
+
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
@@ -37,7 +39,7 @@ def train_step(self, *inputs, **kwargs):
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
- if (TORCH_VERSION >= '1.7' and 'parrots'
+ if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log(
'Reducer buckets have been rebuilt in this iteration.',
@@ -63,7 +65,7 @@ def train_step(self, *inputs, **kwargs):
else:
self.reducer.prepare_for_backward([])
else:
- if TORCH_VERSION > '1.2':
+ if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
self.require_forward_param_sync = False
return output
@@ -77,7 +79,7 @@ def val_step(self, *inputs, **kwargs):
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
- if (TORCH_VERSION >= '1.7' and 'parrots'
+ if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log(
'Reducer buckets have been rebuilt in this iteration.',
@@ -103,6 +105,6 @@ def val_step(self, *inputs, **kwargs):
else:
self.reducer.prepare_for_backward([])
else:
- if TORCH_VERSION > '1.2':
+ if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
self.require_forward_param_sync = False
return output
diff --git a/mmcv/parallel/distributed_deprecated.py b/mmcv/parallel/distributed_deprecated.py
index 2a49fa9e3f..45443db995 100644
--- a/mmcv/parallel/distributed_deprecated.py
+++ b/mmcv/parallel/distributed_deprecated.py
@@ -1,4 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
+from distutils.version import LooseVersion
+
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -40,7 +42,7 @@ def _sync_params(self):
self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size)
if self.broadcast_buffers:
- if TORCH_VERSION < '1.0':
+ if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
buffers = [b.data for b in self.module._all_buffers()]
else:
buffers = [b.data for b in self.module.buffers()]
diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py
index 81dc4f0845..61d7b14d27 100644
--- a/mmcv/runner/__init__.py
+++ b/mmcv/runner/__init__.py
@@ -10,11 +10,11 @@
from .epoch_based_runner import EpochBasedRunner, Runner
from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
- DistSamplerSeedHook, EMAHook, EvalHook, Fp16OptimizerHook,
- Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
- MlflowLoggerHook, OptimizerHook, PaviLoggerHook,
- SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
- WandbLoggerHook)
+ DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
+ Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
+ LrUpdaterHook, MlflowLoggerHook, NeptuneLoggerHook,
+ OptimizerHook, PaviLoggerHook, SyncBuffersHook,
+ TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
@@ -28,15 +28,16 @@
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
- 'WandbLoggerHook', 'MlflowLoggerHook', '_load_checkpoint',
- 'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
- 'Priority', 'get_priority', 'get_host_info', 'get_time_str',
- 'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only',
- 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
- 'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
- 'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
- 'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
- 'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler',
- 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix',
- 'EvalHook', 'DistEvalHook', 'Sequential', 'ModuleList'
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
+ 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
+ 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
+ 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
+ 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
+ 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
+ 'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
+ 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
+ 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
+ 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
+ '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
+ 'ModuleList'
]
diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py
index 38bf7dd61c..076316c0a1 100644
--- a/mmcv/runner/base_module.py
+++ b/mmcv/runner/base_module.py
@@ -22,7 +22,7 @@ def __init__(self, init_cfg=None):
super(BaseModule, self).__init__()
# define default value of init_cfg instead of hard code
- # in init_weigt() function
+ # in init_weight() function
self._is_init = False
self.init_cfg = init_cfg
diff --git a/mmcv/runner/base_runner.py b/mmcv/runner/base_runner.py
index 6e8b299b41..1f1fa01845 100644
--- a/mmcv/runner/base_runner.py
+++ b/mmcv/runner/base_runner.py
@@ -14,7 +14,7 @@
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
-from .priority import get_priority
+from .priority import Priority, get_priority
from .utils import get_time_str
@@ -306,6 +306,29 @@ def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
+ def get_hook_info(self):
+ # Get hooks info in each stage
+ stage_hook_map = {stage: [] for stage in Hook.stages}
+ for hook in self.hooks:
+ try:
+ priority = Priority(hook.priority).name
+ except ValueError:
+ priority = hook.priority
+ classname = hook.__class__.__name__
+ hook_info = f'({priority:<12}) {classname:<35}'
+ for trigger_stage in hook.get_triggered_stages():
+ stage_hook_map[trigger_stage].append(hook_info)
+
+ stage_hook_infos = []
+ for stage in Hook.stages:
+ hook_infos = stage_hook_map[stage]
+ if len(hook_infos) > 0:
+ info = f'{stage}:\n'
+ info += '\n'.join(hook_infos)
+ info += '\n -------------------- '
+ stage_hook_infos.append(info)
+ return '\n'.join(stage_hook_infos)
+
def load_checkpoint(self,
filename,
map_location='cpu',
@@ -358,6 +381,9 @@ def resume(self,
self.logger.info('the iteration number is changed due to '
'change of GPU number')
+ # resume meta information meta
+ self.meta = checkpoint['meta']
+
if 'optimizer' in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer'])
@@ -391,7 +417,7 @@ def register_lr_hook(self, lr_config):
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
- self.register_hook(hook, priority=10)
+ self.register_hook(hook, priority='VERY_HIGH')
def register_momentum_hook(self, momentum_config):
if momentum_config is None:
@@ -412,7 +438,7 @@ def register_momentum_hook(self, momentum_config):
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
- self.register_hook(hook, priority=30)
+ self.register_hook(hook, priority='HIGH')
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
@@ -422,7 +448,7 @@ def register_optimizer_hook(self, optimizer_config):
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
- self.register_hook(hook, priority=50)
+ self.register_hook(hook, priority='ABOVE_NORMAL')
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
@@ -432,7 +458,7 @@ def register_checkpoint_hook(self, checkpoint_config):
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
- self.register_hook(hook, priority=70)
+ self.register_hook(hook, priority='NORMAL')
def register_logger_hooks(self, log_config):
if log_config is None:
@@ -441,7 +467,7 @@ def register_logger_hooks(self, log_config):
for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
- self.register_hook(logger_hook, priority=90)
+ self.register_hook(logger_hook, priority='VERY_LOW')
def register_timer_hook(self, timer_config):
if timer_config is None:
@@ -451,7 +477,7 @@ def register_timer_hook(self, timer_config):
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
- self.register_hook(hook, priority=80)
+ self.register_hook(hook, priority='LOW')
def register_custom_hooks(self, custom_config):
if custom_config is None:
@@ -488,14 +514,26 @@ def register_training_hooks(self,
Default and custom hooks include:
- Hooks Priority
- - LrUpdaterHook 10
- - MomentumUpdaterHook 30
- - OptimizerStepperHook 50
- - CheckpointSaverHook 70
- - IterTimerHook 80
- - LoggerHook(s) 90
- - CustomHook(s) 50 (default)
+ +----------------------+-------------------------+
+ | Hooks | Priority |
+ +======================+=========================+
+ | LrUpdaterHook | VERY_HIGH (10) |
+ +----------------------+-------------------------+
+ | MomentumUpdaterHook | HIGH (30) |
+ +----------------------+-------------------------+
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
+ +----------------------+-------------------------+
+ | CheckpointSaverHook | NORMAL (50) |
+ +----------------------+-------------------------+
+ | IterTimerHook | LOW (70) |
+ +----------------------+-------------------------+
+ | LoggerHook(s) | VERY_LOW (90) |
+ +----------------------+-------------------------+
+ | CustomHook(s) | defaults to NORMAL (50) |
+ +----------------------+-------------------------+
+
+ If custom hooks have same priority with default hooks, custom hooks
+ will be triggered after default hooks.
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py
index 0a9ccf35af..6221554b62 100644
--- a/mmcv/runner/dist_utils.py
+++ b/mmcv/runner/dist_utils.py
@@ -3,6 +3,7 @@
import os
import subprocess
from collections import OrderedDict
+from distutils.version import LooseVersion
import torch
import torch.multiprocessing as mp
@@ -78,7 +79,7 @@ def _init_dist_slurm(backend, port=None):
def get_dist_info():
- if TORCH_VERSION < '1.0':
+ if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
initialized = dist._initialized
else:
if dist.is_available():
diff --git a/mmcv/runner/epoch_based_runner.py b/mmcv/runner/epoch_based_runner.py
index 1e1de295ed..baf072f18f 100644
--- a/mmcv/runner/epoch_based_runner.py
+++ b/mmcv/runner/epoch_based_runner.py
@@ -101,6 +101,8 @@ def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
+ self.logger.info('Hooks will be executed in the following order:\n%s',
+ self.get_hook_info())
self.logger.info('workflow: %s, max: %d epochs', workflow,
self._max_epochs)
self.call_hook('before_run')
@@ -149,14 +151,17 @@ def save_checkpoint(self,
Defaults to True.
"""
if meta is None:
- meta = dict(epoch=self.epoch + 1, iter=self.iter)
- elif isinstance(meta, dict):
- meta.update(epoch=self.epoch + 1, iter=self.iter)
- else:
+ meta = {}
+ elif not isinstance(meta, dict):
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
if self.meta is not None:
meta.update(self.meta)
+ # Note: meta.update(self.meta) should be done before
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+ # there will be problems with resumed checkpoints.
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
diff --git a/mmcv/runner/fp16_utils.py b/mmcv/runner/fp16_utils.py
index 2f958fae1e..c5d562512e 100644
--- a/mmcv/runner/fp16_utils.py
+++ b/mmcv/runner/fp16_utils.py
@@ -1,6 +1,7 @@
import functools
import warnings
from collections import abc
+from distutils.version import LooseVersion
from inspect import getfullargspec
import numpy as np
@@ -31,7 +32,9 @@ def cast_tensor_type(inputs, src_type, dst_type):
Returns:
The same type with inputs, but all contained Tensors have been cast.
"""
- if isinstance(inputs, torch.Tensor):
+ if isinstance(inputs, nn.Module):
+ return inputs
+ elif isinstance(inputs, torch.Tensor):
return inputs.to(dst_type)
elif isinstance(inputs, str):
return inputs
@@ -119,7 +122,8 @@ def new_func(*args, **kwargs):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
- if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+ if (TORCH_VERSION != 'parrots'
+ and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=True):
output = old_func(*new_args, **new_kwargs)
else:
@@ -204,7 +208,8 @@ def new_func(*args, **kwargs):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
- if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+ if (TORCH_VERSION != 'parrots'
+ and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=False):
output = old_func(*new_args, **new_kwargs)
else:
@@ -243,7 +248,8 @@ def wrap_fp16_model(model):
Args:
model (nn.Module): Model in FP32.
"""
- if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0':
+ if (TORCH_VERSION == 'parrots'
+ or LooseVersion(TORCH_VERSION) < LooseVersion('1.6.0')):
# convert model to fp16
model.half()
# patch the normalization layers to make it work in fp32 mode
@@ -376,6 +382,29 @@ def update_scale(self, overflow):
self.cur_scale *= self.scale_factor
self.cur_iter += 1
+ def state_dict(self):
+ """Returns the state of the scaler as a :class:`dict`."""
+ return dict(
+ cur_scale=self.cur_scale,
+ cur_iter=self.cur_iter,
+ mode=self.mode,
+ last_overflow_iter=self.last_overflow_iter,
+ scale_factor=self.scale_factor,
+ scale_window=self.scale_window)
+
+ def load_state_dict(self, state_dict):
+ """Loads the loss_scaler state dict.
+
+ Args:
+ state_dict (dict): scaler state.
+ """
+ self.cur_scale = state_dict['cur_scale']
+ self.cur_iter = state_dict['cur_iter']
+ self.mode = state_dict['mode']
+ self.last_overflow_iter = state_dict['last_overflow_iter']
+ self.scale_factor = state_dict['scale_factor']
+ self.scale_window = state_dict['scale_window']
+
@property
def loss_scale(self):
return self.cur_scale
diff --git a/mmcv/runner/hooks/__init__.py b/mmcv/runner/hooks/__init__.py
index caa4df6b8f..4f108ad4c3 100644
--- a/mmcv/runner/hooks/__init__.py
+++ b/mmcv/runner/hooks/__init__.py
@@ -5,8 +5,9 @@
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
-from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
- TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
+from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
+ NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
+ TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook
@@ -20,6 +21,7 @@
'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
- 'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook',
- 'EvalHook', 'DistEvalHook', 'ProfilerHook'
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
+ 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
+ 'DistEvalHook', 'ProfilerHook'
]
diff --git a/mmcv/runner/hooks/evaluation.py b/mmcv/runner/hooks/evaluation.py
index 151708de0e..5b8ab63f81 100644
--- a/mmcv/runner/hooks/evaluation.py
+++ b/mmcv/runner/hooks/evaluation.py
@@ -7,6 +7,7 @@
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader
+from mmcv.utils import is_seq_of
from .hook import Hook
@@ -41,6 +42,16 @@ class EvalHook(Hook):
.etc will be inferred by 'greater' rule. Keys contain 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None.
+ test_fn (callable, optional): test a model with samples from a
+ dataloader, and return the test results. If ``None``, the default
+ test function ``mmcv.engine.single_gpu_test`` will be used.
+ (default: ``None``)
+ greater_keys (List[str] | None, optional): Metric keys that will be
+ inferred by 'greater' comparison rule rule. If ``None``,
+ _default_greater_keys will be used. (default: ``None``)
+ less_keys (List[str] | None, optional): Metric keys that will be
+ inferred by 'less' comparison rule. If ``None``, _default_less_keys
+ will be used. (default: ``None``)
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
@@ -55,8 +66,11 @@ class EvalHook(Hook):
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -inf, 'less': inf}
- greater_keys = ['acc', 'top', 'AR@', 'auc', 'precision', 'mAP']
- less_keys = ['loss']
+ _default_greater_keys = [
+ 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+ 'mAcc', 'aAcc'
+ ]
+ _default_less_keys = ['loss']
def __init__(self,
dataloader,
@@ -65,6 +79,9 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
+ test_fn=None,
+ greater_keys=None,
+ less_keys=None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
@@ -92,6 +109,28 @@ def __init__(self,
self.eval_kwargs = eval_kwargs
self.initial_flag = True
+ if test_fn is None:
+ from mmcv.engine import single_gpu_test
+ self.test_fn = single_gpu_test
+ else:
+ self.test_fn = test_fn
+
+ if greater_keys is None:
+ self.greater_keys = self._default_greater_keys
+ else:
+ if not isinstance(greater_keys, (list, tuple)):
+ greater_keys = (greater_keys, )
+ assert is_seq_of(greater_keys, str)
+ self.greater_keys = greater_keys
+
+ if less_keys is None:
+ self.less_keys = self._default_less_keys
+ else:
+ if not isinstance(less_keys, (list, tuple)):
+ less_keys = (less_keys, )
+ assert is_seq_of(less_keys, str)
+ self.less_keys = less_keys
+
if self.save_best is not None:
self.best_ckpt_path = None
self._init_rule(rule, self.save_best)
@@ -100,7 +139,8 @@ def _init_rule(self, rule, key_indicator):
"""Initialize rule, key_indicator, comparison_func, and best score.
Here is the rule to determine which rule is used for key indicator
- when the rule is not specific:
+ when the rule is not specific (note that the key indicator matching
+ is case-insensitive):
1. If the key indicator is in ``self.greater_keys``, the rule will be
specified as 'greater'.
2. Or if the key indicator is in ``self.less_keys``, the rule will be
@@ -121,13 +161,19 @@ def _init_rule(self, rule, key_indicator):
if rule is None:
if key_indicator != 'auto':
- if key_indicator in self.greater_keys:
+ # `_lc` here means we use the lower case of keys for
+ # case-insensitive matching
+ key_indicator_lc = key_indicator.lower()
+ greater_keys = [key.lower() for key in self.greater_keys]
+ less_keys = [key.lower() for key in self.less_keys]
+
+ if key_indicator_lc in greater_keys:
rule = 'greater'
- elif key_indicator in self.less_keys:
+ elif key_indicator_lc in less_keys:
rule = 'less'
- elif any(key in key_indicator for key in self.greater_keys):
+ elif any(key in key_indicator_lc for key in greater_keys):
rule = 'greater'
- elif any(key in key_indicator for key in self.less_keys):
+ elif any(key in key_indicator_lc for key in less_keys):
rule = 'less'
else:
raise ValueError(f'Cannot infer the rule for key '
@@ -178,8 +224,7 @@ def _do_evaluate(self, runner):
if not self._should_evaluate(runner):
return
- from mmcv.engine import single_gpu_test
- results = single_gpu_test(runner.model, self.dataloader)
+ results = self.test_fn(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
@@ -308,6 +353,10 @@ class DistEvalHook(EvalHook):
.etc will be inferred by 'greater' rule. Keys contain 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None.
+ test_fn (callable, optional): test a model with samples from a
+ dataloader in a multi-gpu manner, and return the test results. If
+ ``None``, the default test function ``mmcv.engine.multi_gpu_test``
+ will be used. (default: ``None``)
tmpdir (str | None): Temporary directory to save the results of all
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
@@ -326,10 +375,18 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
+ test_fn=None,
+ greater_keys=None,
+ less_keys=None,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
**eval_kwargs):
+
+ if test_fn is None:
+ from mmcv.engine import multi_gpu_test
+ test_fn = multi_gpu_test
+
super().__init__(
dataloader,
start=start,
@@ -337,7 +394,11 @@ def __init__(self,
by_epoch=by_epoch,
save_best=save_best,
rule=rule,
+ test_fn=test_fn,
+ greater_keys=greater_keys,
+ less_keys=less_keys,
**eval_kwargs)
+
self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect
@@ -364,8 +425,7 @@ def _do_evaluate(self, runner):
if tmpdir is None:
tmpdir = osp.join(runner.work_dir, '.eval_hook')
- from mmcv.engine import multi_gpu_test
- results = multi_gpu_test(
+ results = self.test_fn(
runner.model,
self.dataloader,
tmpdir=tmpdir,
diff --git a/mmcv/runner/hooks/hook.py b/mmcv/runner/hooks/hook.py
index fa8ce4a49f..419f638c5e 100644
--- a/mmcv/runner/hooks/hook.py
+++ b/mmcv/runner/hooks/hook.py
@@ -1,10 +1,14 @@
# Copyright (c) Open-MMLab. All rights reserved.
-from mmcv.utils import Registry
+from mmcv.utils import Registry, is_method_overridden
HOOKS = Registry('hook')
class Hook:
+ stages = ('before_run', 'before_train_epoch', 'before_train_iter',
+ 'after_train_iter', 'after_train_epoch', 'before_val_epoch',
+ 'before_val_iter', 'after_val_iter', 'after_val_epoch',
+ 'after_run')
def before_run(self, runner):
pass
@@ -65,3 +69,24 @@ def is_last_epoch(self, runner):
def is_last_iter(self, runner):
return runner.iter + 1 == runner._max_iters
+
+ def get_triggered_stages(self):
+ trigger_stages = set()
+ for stage in Hook.stages:
+ if is_method_overridden(stage, Hook, self):
+ trigger_stages.add(stage)
+
+ # some methods will be triggered in multi stages
+ # use this dict to map method to stages.
+ method_stages_map = {
+ 'before_epoch': ['before_train_epoch', 'before_val_epoch'],
+ 'after_epoch': ['after_train_epoch', 'after_val_epoch'],
+ 'before_iter': ['before_train_iter', 'before_val_iter'],
+ 'after_iter': ['after_train_iter', 'after_val_iter'],
+ }
+
+ for method, map_stages in method_stages_map.items():
+ if is_method_overridden(method, Hook, self):
+ trigger_stages.update(map_stages)
+
+ return [stage for stage in Hook.stages if stage in trigger_stages]
diff --git a/mmcv/runner/hooks/logger/__init__.py b/mmcv/runner/hooks/logger/__init__.py
index 8fe4d81492..46beda07f7 100644
--- a/mmcv/runner/hooks/logger/__init__.py
+++ b/mmcv/runner/hooks/logger/__init__.py
@@ -1,6 +1,8 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .base import LoggerHook
+from .dvclive import DvcliveLoggerHook
from .mlflow import MlflowLoggerHook
+from .neptune import NeptuneLoggerHook
from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
@@ -8,5 +10,6 @@
__all__ = [
'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
- 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook'
+ 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
+ 'NeptuneLoggerHook', 'DvcliveLoggerHook'
]
diff --git a/mmcv/runner/hooks/logger/dvclive.py b/mmcv/runner/hooks/logger/dvclive.py
new file mode 100644
index 0000000000..336a652adc
--- /dev/null
+++ b/mmcv/runner/hooks/logger/dvclive.py
@@ -0,0 +1,58 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class DvcliveLoggerHook(LoggerHook):
+ """Class to log metrics with dvclive.
+
+ It requires `dvclive`_ to be installed.
+
+ Args:
+ path (str): Directory where dvclive will write TSV log files.
+ interval (int): Logging interval (every k iterations).
+ Default 10.
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ Default: True.
+ reset_flag (bool): Whether to clear the output buffer after logging.
+ Default: True.
+ by_epoch (bool): Whether EpochBasedRunner is used.
+ Default: True.
+
+ .. _dvclive:
+ https://dvc.org/doc/dvclive
+ """
+
+ def __init__(self,
+ path,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ by_epoch=True):
+
+ super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.path = path
+ self.import_dvclive()
+
+ def import_dvclive(self):
+ try:
+ import dvclive
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install dvclive" to install dvclive')
+ self.dvclive = dvclive
+
+ @master_only
+ def before_run(self, runner):
+ self.dvclive.init(self.path)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ for k, v in tags.items():
+ self.dvclive.log(k, v, step=self.get_iter(runner))
diff --git a/mmcv/runner/hooks/logger/mlflow.py b/mmcv/runner/hooks/logger/mlflow.py
index 4967fec417..4e839340ef 100644
--- a/mmcv/runner/hooks/logger/mlflow.py
+++ b/mmcv/runner/hooks/logger/mlflow.py
@@ -13,7 +13,7 @@ def __init__(self,
log_model=True,
interval=10,
ignore_last=True,
- reset_flag=True,
+ reset_flag=False,
by_epoch=True):
"""Class to log metrics and (optionally) a trained model to MLflow.
@@ -60,6 +60,7 @@ def import_mlflow(self):
@master_only
def before_run(self, runner):
+ super(MlflowLoggerHook, self).before_run(runner)
if self.exp_name is not None:
self.mlflow.set_experiment(self.exp_name)
if self.tags is not None:
diff --git a/mmcv/runner/hooks/logger/neptune.py b/mmcv/runner/hooks/logger/neptune.py
new file mode 100644
index 0000000000..2e695863b1
--- /dev/null
+++ b/mmcv/runner/hooks/logger/neptune.py
@@ -0,0 +1,82 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class NeptuneLoggerHook(LoggerHook):
+ """Class to log metrics to NeptuneAI.
+
+ It requires `neptune-client` to be installed.
+
+ Args:
+ init_kwargs (dict): a dict contains the initialization keys as below:
+ - project (str): Name of a project in a form of
+ namespace/project_name. If None, the value of
+ NEPTUNE_PROJECT environment variable will be taken.
+ - api_token (str): User’s API token.
+ If None, the value of NEPTUNE_API_TOKEN environment
+ variable will be taken. Note: It is strongly recommended
+ to use NEPTUNE_API_TOKEN environment variable rather than
+ placing your API token in plain text in your source code.
+ - name (str, optional, default is 'Untitled'): Editable name of
+ the run. Name is displayed in the run's Details and in
+ Runs table as a column.
+ Check https://docs.neptune.ai/api-reference/neptune#init for
+ more init arguments.
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging
+ by_epoch (bool): Whether EpochBasedRunner is used.
+
+ .. _NeptuneAI:
+ https://docs.neptune.ai/you-should-know/logging-metadata
+ """
+
+ def __init__(self,
+ init_kwargs=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ with_step=True,
+ by_epoch=True):
+
+ super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_neptune()
+ self.init_kwargs = init_kwargs
+ self.with_step = with_step
+
+ def import_neptune(self):
+ try:
+ import neptune.new as neptune
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install neptune-client" to install neptune')
+ self.neptune = neptune
+ self.run = None
+
+ @master_only
+ def before_run(self, runner):
+ if self.init_kwargs:
+ self.run = self.neptune.init(**self.init_kwargs)
+ else:
+ self.run = self.neptune.init()
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ for tag_name, tag_value in tags.items():
+ if self.with_step:
+ self.run[tag_name].log(
+ tag_value, step=self.get_iter(runner))
+ else:
+ tags['global_step'] = self.get_iter(runner)
+ self.run[tag_name].log(tags)
+
+ @master_only
+ def after_run(self, runner):
+ self.run.stop()
diff --git a/mmcv/runner/hooks/logger/pavi.py b/mmcv/runner/hooks/logger/pavi.py
index 17c15b07b0..264d74abcd 100644
--- a/mmcv/runner/hooks/logger/pavi.py
+++ b/mmcv/runner/hooks/logger/pavi.py
@@ -22,7 +22,7 @@ def __init__(self,
add_last_ckpt=False,
interval=10,
ignore_last=True,
- reset_flag=True,
+ reset_flag=False,
by_epoch=True,
img_key='img_info'):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
@@ -34,6 +34,7 @@ def __init__(self,
@master_only
def before_run(self, runner):
+ super(PaviLoggerHook, self).before_run(runner)
try:
from pavi import SummaryWriter
except ImportError:
diff --git a/mmcv/runner/hooks/logger/tensorboard.py b/mmcv/runner/hooks/logger/tensorboard.py
index abb4ac4de5..475d4b5408 100644
--- a/mmcv/runner/hooks/logger/tensorboard.py
+++ b/mmcv/runner/hooks/logger/tensorboard.py
@@ -1,5 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
+from distutils.version import LooseVersion
from mmcv.utils import TORCH_VERSION
from ...dist_utils import master_only
@@ -14,7 +15,7 @@ def __init__(self,
log_dir=None,
interval=10,
ignore_last=True,
- reset_flag=True,
+ reset_flag=False,
by_epoch=True):
super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
@@ -22,7 +23,9 @@ def __init__(self,
@master_only
def before_run(self, runner):
- if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
+ super(TensorboardLoggerHook, self).before_run(runner)
+ if (LooseVersion(TORCH_VERSION) < LooseVersion('1.1')
+ or TORCH_VERSION == 'parrots'):
try:
from tensorboardX import SummaryWriter
except ImportError:
diff --git a/mmcv/runner/hooks/logger/text.py b/mmcv/runner/hooks/logger/text.py
index d43d1481b4..5b0c7f22f0 100644
--- a/mmcv/runner/hooks/logger/text.py
+++ b/mmcv/runner/hooks/logger/text.py
@@ -176,3 +176,4 @@ def log(self, runner):
self._log_info(log_dict, runner)
self._dump_log(log_dict, runner)
+ return log_dict
diff --git a/mmcv/runner/hooks/logger/wandb.py b/mmcv/runner/hooks/logger/wandb.py
index 38b597ae03..81220e644c 100644
--- a/mmcv/runner/hooks/logger/wandb.py
+++ b/mmcv/runner/hooks/logger/wandb.py
@@ -11,7 +11,7 @@ def __init__(self,
init_kwargs=None,
interval=10,
ignore_last=True,
- reset_flag=True,
+ reset_flag=False,
commit=True,
by_epoch=True,
with_step=True):
@@ -32,6 +32,7 @@ def import_wandb(self):
@master_only
def before_run(self, runner):
+ super(WandbLoggerHook, self).before_run(runner)
if self.wandb is None:
self.import_wandb()
if self.init_kwargs:
diff --git a/mmcv/runner/hooks/lr_updater.py b/mmcv/runner/hooks/lr_updater.py
index 9ac00328bd..917c58c9bc 100644
--- a/mmcv/runner/hooks/lr_updater.py
+++ b/mmcv/runner/hooks/lr_updater.py
@@ -1,6 +1,7 @@
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from math import cos, pi
+from typing import Optional
import mmcv
from .hook import HOOKS, Hook
@@ -361,7 +362,7 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
Implement the cyclical learning rate policy (CLR) described in
https://arxiv.org/pdf/1506.01186.pdf
- Different from the original paper, we use cosine anealing rather than
+ Different from the original paper, we use cosine annealing rather than
triangular policy inside a cycle. This improves the performance in the
3D detection area.
@@ -614,3 +615,223 @@ def format_param(name, optim, param):
if name not in param:
raise KeyError(f'{name} is not found in {param.keys()}')
return param[name]
+
+
+@HOOKS.register_module()
+class ReduceLrUpdateHook(LrUpdaterHook):
+ """ReduceLROnPlateau Scheduler.
+
+ Reduce learning rate when a metric has stopped improving. This scheduler
+ reads a metrics quantity and if no improvement is seen for a 'patience'
+ number of epochs, the learning rate is reduced.
+
+ Args:
+ periods (list[int]): Periods that taking the metric value in count.
+ val_metric (str, optional): Metrics to be evaluated. If val_metric is
+ None, the metrics will be loss value. Default: None.
+ mode (str, optional): One of `min`, `max`. In `min` mode, lr will
+ be reduced when the quantity monitored has stopped
+ decreasing; in `max` mode it will be reduced when the
+ quantity monitored has stopped increasing. Default: 'min'.
+ factor (float, optional): Factor by which the learning rate will be
+ reduced. new_lr = lr * factor. Default: 0.1.
+ patience (int, optional): Number of epochs with no improvement after
+ which learning rate will be reduced. For example, if
+ `patience = 2`, then we will ignore the first 2 epochs
+ with no improvement, and will only decrease the LR after the
+ 3rd epoch if the loss still hasn't improved then.
+ Default: 10.
+ threshold (float, optional): Threshold for measuring the new optimum,
+ to only focus on significant changes. Default: 1e-4.
+ threshold_mode (str, optional): One of `rel`, `abs`. In `rel` mode,
+ dynamic_threshold = best * ( 1 + threshold ) in 'max'
+ mode or best * ( 1 - threshold ) in `min` mode.
+ In `abs` mode, dynamic_threshold = best + threshold in
+ `max` mode or best - threshold in `min` mode. Default: 'rel'.
+ cooldown (int, optional): Number of epochs to wait before resuming
+ normal operation after lr has been reduced. Default: 0.
+ min_lr (float, optional): Minimum LR value to keep. If LR after decay
+ is lower than `min_lr`, it will be clipped to this value.
+ Default: 0.
+ eps (float, optional): Minimal decay applied to lr. If the difference
+ between new and old lr is smaller than eps, the update is
+ ignored. Default: 1e-8.
+ """
+
+ def __init__(self,
+ periods: list,
+ val_metric: Optional[str] = None,
+ mode: str = 'min',
+ factor: float = 0.1,
+ patience: int = 10,
+ threshold: float = 1e-4,
+ threshold_mode: str = 'rel',
+ cooldown: int = 0,
+ min_lr: float = 0.,
+ eps: float = 1e-8,
+ **kwargs):
+ assert isinstance(periods, list), '"periods" must be a list'
+ assert mmcv.is_list_of(periods, int) and all([s >= 0 for s in periods])
+ self.periods = periods
+ self.val_metric = val_metric
+
+ if mode not in ['min', 'max']:
+ raise ValueError(
+ 'mode must be one of "min" or "max", instead got {mode}')
+ self.mode = mode
+
+ if factor >= 1.0:
+ raise ValueError('Factor should be < 1.0')
+ self.factor = factor
+
+ self.patience = patience
+ self.threshold = threshold
+
+ if threshold_mode not in ['rel', 'abs']:
+ raise ValueError('thresh_mode must be one of "rel" or "abs",\
+ instead got {threshold_mode}')
+ self.threshold_mode = threshold_mode
+
+ self.cooldown = cooldown
+ self.cooldown_counter = 0
+ self.best = None
+ self.num_bad_epochs = None
+ self.mode_worse = None # the worse value for the chosen mode
+ self.min_lr = min_lr
+ self.eps = eps
+ self.last_epoch = 0
+ self._init_is_better(self.mode)
+ self._reset()
+ super(ReduceLrUpdateHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, regular_lr):
+ if self.num_bad_epochs > self.patience:
+ self.cooldown_counter = self.cooldown
+ self.num_bad_epochs = 0
+ if regular_lr - regular_lr * self.factor > self.eps:
+ new_lr = max(regular_lr * self.factor, self.min_lr)
+ else:
+ new_lr = regular_lr
+ return new_lr
+ else:
+ return regular_lr
+
+ def get_regular_lr(self, runner):
+ if not self.regular_lr:
+ self.regular_lr = self.base_lr
+ if isinstance(runner.optimizer, dict):
+ lr_groups = {}
+ for k in runner.optimizer.keys():
+ _lr_group = [
+ self.get_lr(runner, _regular_lr)
+ for _regular_lr in self.regular_lr[k]
+ ]
+ lr_groups.update({k: _lr_group})
+ return lr_groups
+ else:
+ return [
+ self.get_lr(runner, _regular_lr)
+ for _regular_lr in self.regular_lr
+ ]
+
+ def _init_is_better(self, mode):
+ if mode == 'min':
+ self.mode_worse = float('inf')
+ else:
+ self.mode_worse = float('-inf')
+
+ def _reset(self):
+ self.best = self.mode_worse
+ self.cooldown_counter = 0
+ self.num_bad_epochs = 0
+
+ def is_better(self, a, best):
+ if self.mode == 'min' and self.threshold_mode == 'rel':
+ rel_epsilon = 1. - self.threshold
+ return a < best * rel_epsilon
+ elif self.mode == 'min' and self.threshold_mode == 'abs':
+ return a < best - self.threshold
+ elif self.mode == 'max' and self.threshold_mode == 'rel':
+ rel_epsilon = 1. + self.threshold
+ return a > best * rel_epsilon
+ else:
+ return a > best + self.threshold
+
+ @property
+ def in_cooldown(self):
+ return self.cooldown_counter > 0
+
+ def after_train_epoch(self, runner):
+ if not self.by_epoch:
+ return
+ cur_epoch = runner.epoch
+ if self.warmup is not None and self.warmup_by_epoch:
+ if cur_epoch <= self.warmup_epochs:
+ return
+ if cur_epoch in self.periods and self.val_metric is None:
+ current = runner.outputs['loss']
+ if self.is_better(current, self.best):
+ self.best = current
+ self.num_bad_epochs = 0
+ else:
+ self.num_bad_epochs += 1
+
+ if self.in_cooldown:
+ self.cooldown_counter -= 1
+ self.num_bad_epochs = 0
+ print('epoch--', cur_epoch, ' lr:', self.regular_lr)
+
+ def after_train_iter(self, runner):
+ if self.by_epoch:
+ return
+ cur_iter = runner.iter
+ if self.warmup_epochs is not None and cur_iter <= self.warmup_iters:
+ return
+ if cur_iter in self.periods and self.val_metric is None:
+ current = runner.outputs['loss']
+ if self.is_better(current, self.best):
+ self.best = current
+ self.num_bad_epochs = 0
+ else:
+ self.num_bad_epochs += 1
+
+ if self.in_cooldown:
+ self.cooldown_counter -= 1
+ self.num_bad_epochs = 0
+
+ def after_val_epoch(self, runner):
+ if not self.by_epoch:
+ return
+ cur_epoch = runner.epoch
+ if self.warmup is not None and self.warmup_by_epoch:
+ if cur_epoch <= self.warmup_epochs:
+ return
+ if cur_epoch in self.periods and self.val_metric is not None:
+ current = runner.outputs[self.val_metric]
+ if self.is_better(current, self.best):
+ self.best = current
+ self.num_bad_epochs = 0
+ else:
+ self.num_bad_epochs += 1
+
+ if self.in_cooldown:
+ self.cooldown_counter -= 1
+ self.num_bad_epochs = 0
+
+ def after_val_iter(self, runner):
+ if self.by_epoch:
+ return
+ cur_iter = runner.iter
+ if self.warmup_epochs is not None and cur_iter <= self.warmup_iters:
+ return
+ if cur_iter in self.periods and self.val_metric is not None:
+ current = runner.outputs[self.val_metric]
+ if self.is_better(current, self.best):
+ self.best = current
+ self.num_bad_epochs = 0
+ else:
+ self.num_bad_epochs += 1
+
+ if self.in_cooldown:
+ self.cooldown_counter -= 1
+ self.num_bad_epochs = 0
diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py
index ca21c703b4..a2f8114a7b 100644
--- a/mmcv/runner/hooks/optimizer.py
+++ b/mmcv/runner/hooks/optimizer.py
@@ -1,6 +1,7 @@
# Copyright (c) Open-MMLab. All rights reserved.
import copy
from collections import defaultdict
+from distutils.version import LooseVersion
from itertools import chain
from torch.nn.utils import clip_grad
@@ -42,7 +43,8 @@ def after_train_iter(self, runner):
runner.optimizer.step()
-if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+if (TORCH_VERSION != 'parrots'
+ and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
@@ -59,7 +61,7 @@ class Fp16OptimizerHook(OptimizerHook):
It can also be a dict containing arguments of GradScalar.
Defaults to 512. For Pytorch >= 1.6, mmcv uses official
implementation of GradScaler. If you use a dict version of
- loss_scale to create GradScaler, plese refer to:
+ loss_scale to create GradScaler, please refer to:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
for the parameters.
@@ -70,7 +72,7 @@ class Fp16OptimizerHook(OptimizerHook):
... backoff_factor=0.5,
... growth_interval=2000
... )
- >>> optimizer = Fp16OptimizerHook(loss_scale=loss_scale)
+ >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
"""
def __init__(self,
@@ -99,6 +101,10 @@ def before_run(self, runner):
"""Preparing steps before Mixed Precision Training."""
# wrap model mode to fp16
wrap_fp16_model(runner.model)
+ # resume from state dict
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
+ self.loss_scaler.load_state_dict(scaler_state_dict)
def copy_grads_to_fp32(self, fp16_net, fp32_weights):
"""Copy gradients from fp16 model to fp32 weight copy."""
@@ -125,6 +131,7 @@ def after_train_iter(self, runner):
2. Backward the loss to obtain the gradients.
3. Unscale the optimizer’s gradient tensors.
4. Call optimizer.step() and update scale factor.
+ 5. Save loss_scaler state_dict for resume purpose.
"""
# clear grads of last iteration
runner.model.zero_grad()
@@ -142,6 +149,10 @@ def after_train_iter(self, runner):
# backward and update scaler
self.loss_scaler.step(runner.optimizer)
self.loss_scaler.update(self._scale_update_param)
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
else:
@HOOKS.register_module()
@@ -210,6 +221,10 @@ def before_run(self, runner):
runner.optimizer.state = state
# convert model to fp16
wrap_fp16_model(runner.model)
+ # resume from state dict
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
+ self.loss_scaler.load_state_dict(scaler_state_dict)
def copy_grads_to_fp32(self, fp16_net, fp32_weights):
"""Copy gradients from fp16 model to fp32 weight copy."""
@@ -236,6 +251,7 @@ def after_train_iter(self, runner):
3. Copy gradients from the model to the fp32 weight copy.
4. Scale the gradients back and update the fp32 weight copy.
5. Copy back the params from fp32 weight copy to the fp16 model.
+ 6. Save loss_scaler state_dict for resume purpose.
"""
# clear grads of last iteration
runner.model.zero_grad()
@@ -276,3 +292,7 @@ def after_train_iter(self, runner):
if has_overflow:
runner.logger.warning('Check overflow, downscale loss scale '
f'to {self.loss_scaler.cur_scale}')
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
diff --git a/mmcv/runner/hooks/profiler.py b/mmcv/runner/hooks/profiler.py
index 82aed120a8..6b60915a2f 100644
--- a/mmcv/runner/hooks/profiler.py
+++ b/mmcv/runner/hooks/profiler.py
@@ -10,7 +10,7 @@
@HOOKS.register_module()
class ProfilerHook(Hook):
- """Profiler to analyze perfromance during training.
+ """Profiler to analyze performance during training.
PyTorch Profiler is a tool that allows the collection of the performance
metrics during the training. More details on Profiler can be found at
@@ -67,7 +67,7 @@ def __init__(self,
from torch import profiler # torch version >= 1.8.1
except ImportError:
raise ImportError('profiler is the new feature of torch1.8.1, '
- f'but your verison is {torch.__version__}')
+ f'but your version is {torch.__version__}')
assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
self.by_epoch = by_epoch
@@ -120,10 +120,10 @@ def before_run(self, runner):
trace_type = trace_cfg.pop('type') # log_trace handler
if trace_type == 'log_trace':
- def _log_hanlder(prof):
+ def _log_handler(prof):
print(prof.key_averages().table(**trace_cfg))
- _on_trace_ready = _log_hanlder
+ _on_trace_ready = _log_handler
elif trace_type == 'tb_trace': # tensorboard_trace handler
try:
import torch_tb_profiler # noqa: F401
diff --git a/mmcv/runner/iter_based_runner.py b/mmcv/runner/iter_based_runner.py
index 75133d5ec4..b35d1823e2 100644
--- a/mmcv/runner/iter_based_runner.py
+++ b/mmcv/runner/iter_based_runner.py
@@ -108,6 +108,8 @@ def run(self, data_loaders, workflow, max_iters=None, **kwargs):
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
+ self.logger.info('Hooks will be executed in the following order:\n%s',
+ self.get_hook_info())
self.logger.info('workflow: %s, max: %d iters', workflow,
self._max_iters)
self.call_hook('before_run')
@@ -193,14 +195,17 @@ def save_checkpoint(self,
latest checkpoint file. Defaults to True.
"""
if meta is None:
- meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
- elif isinstance(meta, dict):
- meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
- else:
+ meta = {}
+ elif not isinstance(meta, dict):
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
if self.meta is not None:
meta.update(self.meta)
+ # Note: meta.update(self.meta) should be done before
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+ # there will be problems with resumed checkpoints.
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.iter + 1)
filepath = osp.join(out_dir, filename)
diff --git a/mmcv/runner/optimizer/default_constructor.py b/mmcv/runner/optimizer/default_constructor.py
index 477bf07fa4..6a455ff0a0 100644
--- a/mmcv/runner/optimizer/default_constructor.py
+++ b/mmcv/runner/optimizer/default_constructor.py
@@ -51,7 +51,7 @@ class DefaultOptimizerConstructor:
``dcn_offset_lr_mult``. If you wish to apply both of them to the
offset layer in deformable convs, set ``dcn_offset_lr_mult``
to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
- 2. If the option ``dcn_offset_lr_mult`` is used, the construtor will
+ 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
apply it to all the DCN layers in the model. So be carefull when
the model contains multiple DCN layers in places other than
backbone.
diff --git a/mmcv/runner/priority.py b/mmcv/runner/priority.py
index b58c67e313..4a9383aa4e 100644
--- a/mmcv/runner/priority.py
+++ b/mmcv/runner/priority.py
@@ -5,29 +5,35 @@
class Priority(Enum):
"""Hook priority levels.
- +------------+------------+
- | Level | Value |
- +============+============+
- | HIGHEST | 0 |
- +------------+------------+
- | VERY_HIGH | 10 |
- +------------+------------+
- | HIGH | 30 |
- +------------+------------+
- | NORMAL | 50 |
- +------------+------------+
- | LOW | 70 |
- +------------+------------+
- | VERY_LOW | 90 |
- +------------+------------+
- | LOWEST | 100 |
- +------------+------------+
+ +--------------+------------+
+ | Level | Value |
+ +==============+============+
+ | HIGHEST | 0 |
+ +--------------+------------+
+ | VERY_HIGH | 10 |
+ +--------------+------------+
+ | HIGH | 30 |
+ +--------------+------------+
+ | ABOVE_NORMAL | 40 |
+ +--------------+------------+
+ | NORMAL | 50 |
+ +--------------+------------+
+ | BELOW_NORMAL | 60 |
+ +--------------+------------+
+ | LOW | 70 |
+ +--------------+------------+
+ | VERY_LOW | 90 |
+ +--------------+------------+
+ | LOWEST | 100 |
+ +--------------+------------+
"""
HIGHEST = 0
VERY_HIGH = 10
HIGH = 30
+ ABOVE_NORMAL = 40
NORMAL = 50
+ BELOW_NORMAL = 60
LOW = 70
VERY_LOW = 90
LOWEST = 100
diff --git a/mmcv/tensorrt/__init__.py b/mmcv/tensorrt/__init__.py
index 39a2eba6ea..0a245c058c 100644
--- a/mmcv/tensorrt/__init__.py
+++ b/mmcv/tensorrt/__init__.py
@@ -1,12 +1,29 @@
# flake8: noqa
from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin
-from .tensorrt_utils import (TRTWraper, load_trt_engine, onnx2trt,
- save_trt_engine)
+from .preprocess import preprocess_onnx
-# load tensorrt plugin lib
-load_tensorrt_plugin()
-__all__ = [
- 'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
- 'is_tensorrt_plugin_loaded'
-]
+def is_tensorrt_available():
+ try:
+ import tensorrt
+ del tensorrt
+ return True
+ except ModuleNotFoundError:
+ return False
+
+
+__all__ = []
+
+if is_tensorrt_available():
+ from .tensorrt_utils import (TRTWraper, TRTWrapper, load_trt_engine,
+ onnx2trt, save_trt_engine)
+
+ # load tensorrt plugin lib
+ load_tensorrt_plugin()
+
+ __all__.append([
+ 'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
+ 'TRTWrapper'
+ ])
+
+__all__.append(['is_tensorrt_plugin_loaded', 'preprocess_onnx'])
diff --git a/mmcv/tensorrt/preprocess.py b/mmcv/tensorrt/preprocess.py
new file mode 100644
index 0000000000..d07c67fc99
--- /dev/null
+++ b/mmcv/tensorrt/preprocess.py
@@ -0,0 +1,120 @@
+import numpy as np
+import onnx
+
+
+def preprocess_onnx(onnx_model):
+ """Modify onnx model to match with TensorRT plugins in mmcv.
+
+ There are some conflict between onnx node definition and TensorRT limit.
+ This function perform preprocess on the onnx model to solve the conflicts.
+ For example, onnx `attribute` is loaded in TensorRT on host and onnx
+ `input` is loaded on device. The shape inference is performed on host, so
+ any `input` related to shape (such as `max_output_boxes_per_class` in
+ NonMaxSuppression) should be transformed to `attribute` before conversion.
+
+ Arguments:
+ onnx_model (onnx.ModelProto): Input onnx model.
+
+ Returns:
+ onnx.ModelProto: Modified onnx model.
+ """
+ graph = onnx_model.graph
+ nodes = graph.node
+ initializers = graph.initializer
+ node_dict = {}
+ for node in nodes:
+ node_outputs = node.output
+ for output in node_outputs:
+ if len(output) > 0:
+ node_dict[output] = node
+
+ init_dict = {_.name: _ for _ in initializers}
+
+ nodes_name_to_remove = set()
+
+ def is_node_without_output(name):
+ for node_name, node in node_dict.items():
+ if node_name not in nodes_name_to_remove:
+ if name in node.input:
+ return False
+ return True
+
+ def mark_nodes_to_remove(name):
+ node = node_dict[name]
+ nodes_name_to_remove.add(name)
+ for input_node_name in node.input:
+ if is_node_without_output(input_node_name):
+ mark_nodes_to_remove(input_node_name)
+
+ def parse_data(name, typ, default_value=0):
+ if name in node_dict:
+ node = node_dict[name]
+ if node.op_type == 'Constant':
+ raw_data = node.attribute[0].t.raw_data
+ else:
+ mark_nodes_to_remove(name)
+ return default_value
+ elif name in init_dict:
+ raw_data = init_dict[name].raw_data
+ else:
+ raise ValueError(f'{name} not found in node or initilizer.')
+ return np.frombuffer(raw_data, typ).item()
+
+ nrof_node = len(nodes)
+ for idx in range(nrof_node):
+ node = nodes[idx]
+ node_attributes = node.attribute
+ node_inputs = node.input
+ node_outputs = node.output
+ node_name = node.name
+ # process NonMaxSuppression node
+ if node.op_type == 'NonMaxSuppression':
+ center_point_box = 0
+ max_output_boxes_per_class = 1000000
+ iou_threshold = 0.3
+ score_threshold = 0.0
+ offset = 0
+ for attribute in node_attributes:
+ if attribute.name == 'center_point_box':
+ center_point_box = attribute.i
+ elif attribute.name == 'offset':
+ offset = attribute.i
+
+ if len(node_inputs) >= 3:
+ max_output_boxes_per_class = parse_data(
+ node_inputs[2], np.int64, max_output_boxes_per_class)
+ mark_nodes_to_remove(node_inputs[2])
+
+ if len(node_inputs) >= 4:
+ iou_threshold = parse_data(node_inputs[3], np.float32,
+ iou_threshold)
+ mark_nodes_to_remove(node_inputs[3])
+
+ if len(node_inputs) >= 5:
+ score_threshold = parse_data(node_inputs[4], np.float32)
+ mark_nodes_to_remove(node_inputs[4])
+
+ new_node = onnx.helper.make_node(
+ 'NonMaxSuppression',
+ node_inputs[:2],
+ node_outputs,
+ name=node_name,
+ center_point_box=center_point_box,
+ max_output_boxes_per_class=max_output_boxes_per_class,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ offset=offset)
+
+ for output in node_outputs:
+ if output in node_dict:
+ node_dict[output] = new_node
+ nodes.insert(idx, new_node)
+ nodes.remove(node)
+ elif node.op_type == 'InstanceNormalization':
+ # directly change op name
+ node.op_type = 'MMCVInstanceNormalization'
+
+ for node_name in nodes_name_to_remove:
+ nodes.remove(node_dict[node_name])
+
+ return onnx_model
diff --git a/mmcv/tensorrt/tensorrt_utils.py b/mmcv/tensorrt/tensorrt_utils.py
index 5966881df2..a67aa6e32d 100644
--- a/mmcv/tensorrt/tensorrt_utils.py
+++ b/mmcv/tensorrt/tensorrt_utils.py
@@ -1,96 +1,10 @@
-import numpy as np
+import warnings
+
import onnx
import tensorrt as trt
import torch
-
-def preprocess_onnx(onnx_model):
- """Modify onnx model to match with TensorRT plugins in mmcv.
-
- There are some conflict between onnx node definition and TensorRT limit.
- This function perform preprocess on the onnx model to solve the conflicts.
- For example, onnx `attribute` is loaded in TensorRT on host and onnx
- `input` is loaded on device. The shape inference is performed on host, so
- any `input` related to shape (such as `max_output_boxes_per_class` in
- NonMaxSuppression) should be transformed to `attribute` before conversion.
-
- Arguments:
- onnx_model (onnx.ModelProto): Input onnx model.
-
- Returns:
- onnx.ModelProto: Modified onnx model.
- """
- graph = onnx_model.graph
- nodes = graph.node
- initializers = graph.initializer
- node_dict = {}
- for node in nodes:
- node_outputs = node.output
- for output in node_outputs:
- if len(output) > 0:
- node_dict[output] = node
-
- init_dict = {_.name: _ for _ in initializers}
-
- def parse_data(name, typ):
- if name in node_dict:
- const_node = node_dict[name]
- assert const_node.op_type == 'Constant'
- raw_data = const_node.attribute[0].t.raw_data
- elif name in init_dict:
- raw_data = init_dict[name].raw_data
- else:
- raise ValueError(f'{name} not found in node or initilizer.')
- return np.frombuffer(raw_data, typ).item()
-
- nrof_node = len(nodes)
- for idx in range(nrof_node):
- node = nodes[idx]
- node_attributes = node.attribute
- node_inputs = node.input
- node_outputs = node.output
- node_name = node.name
- # process NonMaxSuppression node
- if node.op_type == 'NonMaxSuppression':
- center_point_box = 0
- max_output_boxes_per_class = 1000000
- iou_threshold = 0.3
- score_threshold = 0.0
- offset = 0
- for attribute in node_attributes:
- if attribute.name == 'center_point_box':
- center_point_box = attribute.i
- elif attribute.name == 'offset':
- offset = attribute.i
-
- if len(node_inputs) >= 3:
- max_output_boxes_per_class = parse_data(
- node_inputs[2], np.int64)
-
- if len(node_inputs) >= 4:
- iou_threshold = parse_data(node_inputs[3], np.float32)
-
- if len(node_inputs) >= 5:
- score_threshold = parse_data(node_inputs[4], np.float32)
-
- new_node = onnx.helper.make_node(
- 'NonMaxSuppression',
- node_inputs[:2],
- node_outputs,
- name=node_name,
- center_point_box=center_point_box,
- max_output_boxes_per_class=max_output_boxes_per_class,
- iou_threshold=iou_threshold,
- score_threshold=score_threshold,
- offset=offset)
-
- for output in node_outputs:
- if output in node_dict:
- node_dict[output] = new_node
- nodes.insert(idx, new_node)
- nodes.remove(node)
-
- return onnx_model
+from .preprocess import preprocess_onnx
def onnx2trt(onnx_model,
@@ -225,8 +139,8 @@ def torch_device_from_trt(device):
return TypeError('%s is not supported by torch' % device)
-class TRTWraper(torch.nn.Module):
- """TensorRT engine Wraper.
+class TRTWrapper(torch.nn.Module):
+ """TensorRT engine Wrapper.
Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to wrap
@@ -238,8 +152,8 @@ class TRTWraper(torch.nn.Module):
output_names should be the same as onnx model.
"""
- def __init__(self, engine, input_names, output_names):
- super(TRTWraper, self).__init__()
+ def __init__(self, engine, input_names=None, output_names=None):
+ super(TRTWrapper, self).__init__()
self.engine = engine
if isinstance(self.engine, str):
self.engine = load_trt_engine(engine)
@@ -247,9 +161,14 @@ def __init__(self, engine, input_names, output_names):
if not isinstance(self.engine, trt.ICudaEngine):
raise TypeError('engine should be str or trt.ICudaEngine')
- self._register_state_dict_hook(TRTWraper._on_state_dict)
+ self._register_state_dict_hook(TRTWrapper._on_state_dict)
self.context = self.engine.create_execution_context()
+ # get input and output names from engine
+ if input_names is None or output_names is None:
+ names = [_ for _ in self.engine]
+ input_names = list(filter(self.engine.binding_is_input, names))
+ output_names = list(set(names) - set(input_names))
self.input_names = input_names
self.output_names = output_names
@@ -305,3 +224,11 @@ def forward(self, inputs):
torch.cuda.current_stream().cuda_stream)
return outputs
+
+
+class TRTWraper(TRTWrapper):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn('TRTWraper will be deprecated in'
+ ' future. Please use TRTWrapper instead')
diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py
index ba2a2c9e94..6ca3452409 100644
--- a/mmcv/utils/__init__.py
+++ b/mmcv/utils/__init__.py
@@ -2,9 +2,11 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
- import_modules_from_strings, is_list_of, is_seq_of, is_str,
- is_tuple_of, iter_cast, list_cast, requires_executable,
- requires_package, slice_list, tuple_cast)
+ import_modules_from_strings, is_list_of,
+ is_method_overridden, is_seq_of, is_str, is_tuple_of,
+ iter_cast, list_cast, requires_executable, requires_package,
+ slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
+ to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
@@ -29,17 +31,19 @@
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
- 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script'
+ 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
+ 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
+ 'is_method_overridden'
]
else:
from .env import collect_env
from .logging import get_logger, print_log
+ from .parrots_jit import jit, skip_no_elena
from .parrots_wrapper import (
CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
- from .parrots_jit import jit, skip_no_elena
from .registry import Registry, build_from_cfg
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
@@ -58,5 +62,6 @@
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
- 'assert_params_all_zeros', 'check_python_script'
+ 'assert_params_all_zeros', 'check_python_script',
+ 'is_method_overridden'
]
diff --git a/mmcv/utils/config.py b/mmcv/utils/config.py
index f48778de97..56c7d9bd93 100644
--- a/mmcv/utils/config.py
+++ b/mmcv/utils/config.py
@@ -1,10 +1,13 @@
# Copyright (c) Open-MMLab. All rights reserved.
import ast
+import copy
+import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
+import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
@@ -120,6 +123,57 @@ def _substitute_predefined_vars(filename, temp_config_name):
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)
+ @staticmethod
+ def _pre_substitute_base_vars(filename, temp_config_name):
+ """Substitute base variable placehoders to string, so that parsing
+ would work."""
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ base_var_dict = {}
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
+ base_vars = set(re.findall(regexp, config_file))
+ for base_var in base_vars:
+ randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
+ base_var_dict[randstr] = base_var
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
+ config_file = re.sub(regexp, f'"{randstr}"', config_file)
+ with open(temp_config_name, 'w') as tmp_config_file:
+ tmp_config_file.write(config_file)
+ return base_var_dict
+
+ @staticmethod
+ def _substitute_base_vars(cfg, base_var_dict, base_cfg):
+ """Substitute variable strings to their actual values."""
+ cfg = copy.deepcopy(cfg)
+
+ if isinstance(cfg, dict):
+ for k, v in cfg.items():
+ if isinstance(v, str) and v in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[v].split('.'):
+ new_v = new_v[new_k]
+ cfg[k] = new_v
+ elif isinstance(v, (list, tuple, dict)):
+ cfg[k] = Config._substitute_base_vars(
+ v, base_var_dict, base_cfg)
+ elif isinstance(cfg, tuple):
+ cfg = tuple(
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
+ for c in cfg)
+ elif isinstance(cfg, list):
+ cfg = [
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
+ for c in cfg
+ ]
+ elif isinstance(cfg, str) and cfg in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[cfg].split('.'):
+ new_v = new_v[new_k]
+ cfg = new_v
+
+ return cfg
+
@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
@@ -140,6 +194,9 @@ def _file2dict(filename, use_predefined_variables=True):
temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
+ # Substitute base variables from placeholders to strings
+ base_var_dict = Config._pre_substitute_base_vars(
+ temp_config_file.name, temp_config_file.name)
if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
@@ -184,6 +241,10 @@ def _file2dict(filename, use_predefined_variables=True):
raise KeyError('Duplicate key is not allowed among bases')
base_cfg_dict.update(c)
+ # Subtitute base variables from strings to their actual values
+ cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
+ base_cfg_dict)
+
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
@@ -275,11 +336,13 @@ def fromstring(cfg_str, file_format):
# check if users specify a wrong suffix for python
warnings.warn(
'Please check "file_format", the file format may be .py')
-
- with tempfile.NamedTemporaryFile('w', suffix=file_format) as temp_file:
+ with tempfile.NamedTemporaryFile(
+ 'w', suffix=file_format, delete=False) as temp_file:
temp_file.write(cfg_str)
- temp_file.flush()
- cfg = Config.fromfile(temp_file.name)
+ # on windows, previous implementation cause error
+ # see PR 1077 for details
+ cfg = Config.fromfile(temp_file.name)
+ os.remove(temp_file.name)
return cfg
@staticmethod
@@ -555,7 +618,7 @@ def _parse_iterable(val):
>>> DictAction._parse_iterable('[a, b, c]')
['a', 'b', 'c']
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
- [(1, 2, 3), ['a', 'b], 'c']
+ [(1, 2, 3), ['a', 'b'], 'c']
"""
def find_next_comma(string):
diff --git a/mmcv/utils/ext_loader.py b/mmcv/utils/ext_loader.py
index 826e70bb16..2a3c223838 100644
--- a/mmcv/utils/ext_loader.py
+++ b/mmcv/utils/ext_loader.py
@@ -1,6 +1,7 @@
import importlib
import os
import pkgutil
+import warnings
from collections import namedtuple
import torch
@@ -14,24 +15,51 @@ def load_ext(name, funcs):
return ext
else:
from parrots import extension
+ from parrots.base import ParrotsException
has_return_value_ops = [
- 'nms', 'softnms', 'nms_match', 'nms_rotated', 'top_pool_forward',
- 'top_pool_backward', 'bottom_pool_forward', 'bottom_pool_backward',
- 'left_pool_forward', 'left_pool_backward', 'right_pool_forward',
- 'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d'
+ 'nms',
+ 'softnms',
+ 'nms_match',
+ 'nms_rotated',
+ 'top_pool_forward',
+ 'top_pool_backward',
+ 'bottom_pool_forward',
+ 'bottom_pool_backward',
+ 'left_pool_forward',
+ 'left_pool_backward',
+ 'right_pool_forward',
+ 'right_pool_backward',
+ 'fused_bias_leakyrelu',
+ 'upfirdn2d',
+ 'ms_deform_attn_forward',
]
+ def get_fake_func(name, e):
+
+ def fake_func(*args, **kwargs):
+ warnings.warn(f'{name} is not supported in parrots now')
+ raise e
+
+ return fake_func
+
def load_ext(name, funcs):
ExtModule = namedtuple('ExtModule', funcs)
ext_list = []
lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
for fun in funcs:
- if fun in has_return_value_ops:
- ext_list.append(extension.load(fun, name, lib_dir=lib_root).op)
+ try:
+ ext_fun = extension.load(fun, name, lib_dir=lib_root)
+ except ParrotsException as e:
+ if 'No element registered' not in e.message:
+ warnings.warn(e.message)
+ ext_fun = get_fake_func(fun, e)
+ ext_list.append(ext_fun)
else:
- ext_list.append(
- extension.load(fun, name, lib_dir=lib_root).op_)
+ if fun in has_return_value_ops:
+ ext_list.append(ext_fun.op)
+ else:
+ ext_list.append(ext_fun.op_)
return ExtModule(*ext_list)
diff --git a/mmcv/utils/misc.py b/mmcv/utils/misc.py
index da70738b80..dee1fa03c9 100644
--- a/mmcv/utils/misc.py
+++ b/mmcv/utils/misc.py
@@ -1,4 +1,5 @@
# Copyright (c) Open-MMLab. All rights reserved.
+import collections.abc
import functools
import itertools
import subprocess
@@ -6,6 +7,25 @@
from collections import abc
from importlib import import_module
from inspect import getfullargspec
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
def is_str(x):
@@ -266,7 +286,7 @@ def requires_executable(prerequisites):
def deprecated_api_warning(name_dict, cls_name=None):
- """A decorator to check if some argments are deprecate and try to replace
+ """A decorator to check if some arguments are deprecate and try to replace
deprecate src_arg_name to dst_arg_name.
Args:
@@ -313,3 +333,22 @@ def new_func(*args, **kwargs):
return new_func
return api_warning_wrapper
+
+
+def is_method_overridden(method, base_class, derived_class):
+ """Check if a method of base class is overridden in derived class.
+
+ Args:
+ method (str): the method name to check.
+ base_class (type): the class of the base class.
+ derived_class (type | Any): the class or instance of the derived class.
+ """
+ assert isinstance(base_class, type), \
+ "base_class doesn't accept instance, Please pass class instead."
+
+ if not isinstance(derived_class, type):
+ derived_class = derived_class.__class__
+
+ base_method = getattr(base_class, method)
+ derived_method = getattr(derived_class, method)
+ return derived_method != base_method
diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py
index 25761be835..ccc22b09e1 100644
--- a/mmcv/utils/parrots_wrapper.py
+++ b/mmcv/utils/parrots_wrapper.py
@@ -82,10 +82,6 @@ def _get_norm():
class SyncBatchNorm(SyncBatchNorm_):
- def _specify_ddp_gpu_num(self, gpu_size):
- if TORCH_VERSION != 'parrots':
- super()._specify_ddp_gpu_num(gpu_size)
-
def _check_input_dim(self, input):
if TORCH_VERSION == 'parrots':
if input.dim() < 2:
diff --git a/mmcv/utils/path.py b/mmcv/utils/path.py
index aed078fe98..3a4d038445 100644
--- a/mmcv/utils/path.py
+++ b/mmcv/utils/path.py
@@ -63,16 +63,12 @@ def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
rel_path = osp.relpath(entry.path, root)
- if suffix is None:
+ if suffix is None or rel_path.endswith(suffix):
yield rel_path
- elif rel_path.endswith(suffix):
- yield rel_path
- else:
- if recursive:
- yield from _scandir(
- entry.path, suffix=suffix, recursive=recursive)
- else:
- continue
+ elif recursive and os.path.isdir(entry.path):
+ # scan recursively if entry.path is a directory
+ yield from _scandir(
+ entry.path, suffix=suffix, recursive=recursive)
return _scandir(dir_path, suffix=suffix, recursive=recursive)
diff --git a/mmcv/version.py b/mmcv/version.py
index 8426b0856f..921a14cf4a 100644
--- a/mmcv/version.py
+++ b/mmcv/version.py
@@ -1,6 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
-__version__ = '1.3.4'
+__version__ = '1.3.8'
def parse_version_info(version_str: str) -> tuple:
diff --git a/mmcv/visualization/image.py b/mmcv/visualization/image.py
index 4d0a2f1ea1..9621d7f47b 100644
--- a/mmcv/visualization/image.py
+++ b/mmcv/visualization/image.py
@@ -15,7 +15,7 @@ def imshow(img, win_name='', wait_time=0):
wait_time (int): Value of waitKey param.
"""
cv2.imshow(win_name, imread(img))
- if wait_time == 0: # prevent from hangning if windows was closed
+ if wait_time == 0: # prevent from hanging if windows was closed
while True:
ret = cv2.waitKey(1)
diff --git a/requirements/docs.txt b/requirements/docs.txt
index e14f32b690..962eec76b0 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,3 +1,6 @@
m2r
+opencv-python
+sphinx
sphinx_markdown_tables
+sphinx_rtd_theme
torch
diff --git a/requirements/test.txt b/requirements/test.txt
index fe41ebe185..ab4ecbd5c1 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -5,4 +5,5 @@ onnxoptimizer
onnxruntime==1.4.0
pytest
PyTurboJPEG
+scipy
tiffile
diff --git a/setup.cfg b/setup.cfg
index 25825f09aa..fbd78ef0e6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools,logging,os,warnings,abc
known_first_party = mmcv
-known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,tensorrt,torch,torchvision,yaml,yapf
+known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,scipy,tensorrt,torch,torchvision,yaml,yapf
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
diff --git a/tests/data/config/t.json b/tests/data/config/t.json
new file mode 100644
index 0000000000..8f7b9b4a17
--- /dev/null
+++ b/tests/data/config/t.json
@@ -0,0 +1,13 @@
+{
+ "_base_": [
+ "./l1.py",
+ "./l2.yaml",
+ "./l3.json",
+ "./l4.py"
+ ],
+ "item3": false,
+ "item4": "test",
+ "item8": "{{fileBasename}}",
+ "item9": {{ _base_.item2 }},
+ "item10": {{ _base_.item7.b.c }}
+}
diff --git a/tests/data/config/t.py b/tests/data/config/t.py
new file mode 100644
index 0000000000..9f085ae675
--- /dev/null
+++ b/tests/data/config/t.py
@@ -0,0 +1,6 @@
+_base_ = ['./l1.py', './l2.yaml', './l3.json', './l4.py']
+item3 = False
+item4 = 'test'
+item8 = '{{fileBasename}}'
+item9 = {{ _base_.item2 }}
+item10 = {{ _base_.item7.b.c }}
diff --git a/tests/data/config/t.yaml b/tests/data/config/t.yaml
new file mode 100644
index 0000000000..ab42859ec9
--- /dev/null
+++ b/tests/data/config/t.yaml
@@ -0,0 +1,6 @@
+_base_ : ['./l1.py', './l2.yaml', './l3.json', './l4.py']
+item3 : False
+item4 : 'test'
+item8 : '{{fileBasename}}'
+item9 : {{ _base_.item2 }}
+item10 : {{ _base_.item7.b.c }}
diff --git a/tests/data/config/u.json b/tests/data/config/u.json
new file mode 100644
index 0000000000..f6a01e3c08
--- /dev/null
+++ b/tests/data/config/u.json
@@ -0,0 +1,26 @@
+{
+ "_base_": [
+ "./t.py"
+ ],
+ "base": "_base_.item8",
+ "item11": {{ _base_.item8 }},
+ "item12": {{ _base_.item9 }},
+ "item13": {{ _base_.item10 }},
+ "item14": {{ _base_.item1 }},
+ "item15": {
+ "a": {
+ "b": {{ _base_.item2 }}
+ },
+ "b": [
+ {{ _base_.item3 }}
+ ],
+ "c": [{{ _base_.item4 }}],
+ "d": [[
+ {
+ "e": {{ _base_.item5.a }}
+ }
+ ],
+ {{ _base_.item6 }}],
+ "e": {{ _base_.item1 }}
+ }
+}
diff --git a/tests/data/config/u.py b/tests/data/config/u.py
new file mode 100644
index 0000000000..bdd96a7e46
--- /dev/null
+++ b/tests/data/config/u.py
@@ -0,0 +1,13 @@
+_base_ = ['./t.py']
+base = '_base_.item8'
+item11 = {{ _base_.item8 }}
+item12 = {{ _base_.item9 }}
+item13 = {{ _base_.item10 }}
+item14 = {{ _base_.item1 }}
+item15 = dict(
+ a = dict( b = {{ _base_.item2 }} ),
+ b = [{{ _base_.item3 }}],
+ c = [{{ _base_.item4 }}],
+ d = [[dict(e = {{ _base_.item5.a }})],{{ _base_.item6 }}],
+ e = {{ _base_.item1 }}
+)
diff --git a/tests/data/config/u.yaml b/tests/data/config/u.yaml
new file mode 100644
index 0000000000..d201cb926d
--- /dev/null
+++ b/tests/data/config/u.yaml
@@ -0,0 +1,15 @@
+_base_: ["./t.py"]
+base: "_base_.item8"
+item11: {{ _base_.item8 }}
+item12: {{ _base_.item9 }}
+item13: {{ _base_.item10 }}
+item14: {{ _base_.item1 }}
+item15:
+ a:
+ b: {{ _base_.item2 }}
+ b: [{{ _base_.item3 }}]
+ c: [{{ _base_.item4 }}]
+ d:
+ - [e: {{ _base_.item5.a }}]
+ - {{ _base_.item6 }}
+ e: {{ _base_.item1 }}
diff --git a/tests/data/config/v.py b/tests/data/config/v.py
new file mode 100644
index 0000000000..3d2a1a436c
--- /dev/null
+++ b/tests/data/config/v.py
@@ -0,0 +1,11 @@
+_base_ = ['./u.py']
+item21 = {{ _base_.item11 }}
+item22 = item21
+item23 = {{ _base_.item10 }}
+item24 = item23
+item25 = dict(
+ a = dict( b = item24 ),
+ b = [item24],
+ c = [[dict(e = item22)],{{ _base_.item6 }}],
+ e = item21
+)
diff --git a/tests/data/for_scan/.file b/tests/data/for_scan/.file
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/test_cnn/test_generalized_attention.py b/tests/test_cnn/test_generalized_attention.py
index bec3288c82..27207c9241 100644
--- a/tests/test_cnn/test_generalized_attention.py
+++ b/tests/test_cnn/test_generalized_attention.py
@@ -60,3 +60,16 @@ def test_context_block():
assert gen_attention_block.kv_downsample is not None
out = gen_attention_block(imgs)
assert out.shape == imgs.shape
+
+ # test fp16 with attention_type='1111'
+ if torch.cuda.is_available():
+ imgs = torch.randn(2, 16, 20, 20).cuda().to(torch.half)
+ gen_attention_block = GeneralizedAttention(
+ 16,
+ spatial_range=-1,
+ num_heads=8,
+ attention_type='1111',
+ kv_stride=2)
+ gen_attention_block.cuda().type(torch.half)
+ out = gen_attention_block(imgs)
+ assert out.shape == imgs.shape
diff --git a/tests/test_cnn/test_transformer.py b/tests/test_cnn/test_transformer.py
new file mode 100644
index 0000000000..a4a5f62e9c
--- /dev/null
+++ b/tests/test_cnn/test_transformer.py
@@ -0,0 +1,177 @@
+import pytest
+import torch
+
+from mmcv.cnn.bricks.drop import DropPath
+from mmcv.cnn.bricks.transformer import (FFN, BaseTransformerLayer,
+ MultiheadAttention,
+ TransformerLayerSequence)
+
+
+def test_multiheadattention():
+ MultiheadAttention(
+ embed_dims=5,
+ num_heads=5,
+ attn_drop=0,
+ proj_drop=0,
+ dropout_layer=dict(type='Dropout', drop_prob=0.),
+ batch_first=True)
+ batch_dim = 2
+ embed_dim = 5
+ num_query = 100
+ attn_batch_first = MultiheadAttention(
+ embed_dims=5,
+ num_heads=5,
+ attn_drop=0,
+ proj_drop=0,
+ dropout_layer=dict(type='DropPath', drop_prob=0.),
+ batch_first=True)
+
+ attn_query_first = MultiheadAttention(
+ embed_dims=5,
+ num_heads=5,
+ attn_drop=0,
+ proj_drop=0,
+ dropout_layer=dict(type='DropPath', drop_prob=0.),
+ batch_first=False)
+
+ param_dict = dict(attn_query_first.named_parameters())
+ for n, v in attn_batch_first.named_parameters():
+ param_dict[n].data = v.data
+
+ input_batch_first = torch.rand(batch_dim, num_query, embed_dim)
+ input_query_first = input_batch_first.transpose(0, 1)
+
+ assert torch.allclose(
+ attn_query_first(input_query_first).sum(),
+ attn_batch_first(input_batch_first).sum())
+
+ key_batch_first = torch.rand(batch_dim, num_query, embed_dim)
+ key_query_first = key_batch_first.transpose(0, 1)
+
+ assert torch.allclose(
+ attn_query_first(input_query_first, key_query_first).sum(),
+ attn_batch_first(input_batch_first, key_batch_first).sum())
+
+ identity = torch.ones_like(input_query_first)
+
+ # check deprecated arguments can be used normally
+
+ assert torch.allclose(
+ attn_query_first(
+ input_query_first, key_query_first, residual=identity).sum(),
+ attn_batch_first(input_batch_first, key_batch_first).sum() +
+ identity.sum() - input_batch_first.sum())
+
+ assert torch.allclose(
+ attn_query_first(
+ input_query_first, key_query_first, identity=identity).sum(),
+ attn_batch_first(input_batch_first, key_batch_first).sum() +
+ identity.sum() - input_batch_first.sum())
+
+ attn_query_first(
+ input_query_first, key_query_first, identity=identity).sum(),
+
+
+def test_ffn():
+ with pytest.raises(AssertionError):
+ # num_fcs should be no less than 2
+ FFN(num_fcs=1)
+ FFN(dropout=0, add_residual=True)
+ ffn = FFN(dropout=0, add_identity=True)
+
+ input_tensor = torch.rand(2, 20, 256)
+ input_tensor_nbc = input_tensor.transpose(0, 1)
+ assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
+ residual = torch.rand_like(input_tensor)
+ torch.allclose(
+ ffn(input_tensor, residual=residual).sum(),
+ ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
+
+ torch.allclose(
+ ffn(input_tensor, identity=residual).sum(),
+ ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
+
+
+def test_basetransformerlayer():
+ attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
+ feedforward_channels = 2048
+ ffn_dropout = 0.1
+ operation_order = ('self_attn', 'norm', 'ffn', 'norm')
+
+ # test deprecated_args
+ baselayer = BaseTransformerLayer(
+ attn_cfgs=attn_cfgs,
+ feedforward_channels=feedforward_channels,
+ ffn_dropout=ffn_dropout,
+ operation_order=operation_order)
+ assert baselayer.batch_first is False
+ assert baselayer.ffns[0].feedforward_channels == feedforward_channels
+
+ attn_cfgs = dict(type='MultiheadAttention', num_heads=8, embed_dims=256),
+ feedforward_channels = 2048
+ ffn_dropout = 0.1
+ operation_order = ('self_attn', 'norm', 'ffn', 'norm')
+ baselayer = BaseTransformerLayer(
+ attn_cfgs=attn_cfgs,
+ feedforward_channels=feedforward_channels,
+ ffn_dropout=ffn_dropout,
+ operation_order=operation_order,
+ batch_first=True)
+ assert baselayer.attentions[0].batch_first
+ in_tensor = torch.rand(2, 10, 256)
+ baselayer(in_tensor)
+
+
+def test_transformerlayersequence():
+ squeue = TransformerLayerSequence(
+ num_layers=6,
+ transformerlayers=dict(
+ type='BaseTransformerLayer',
+ attn_cfgs=[
+ dict(
+ type='MultiheadAttention',
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1),
+ dict(type='MultiheadAttention', embed_dims=256, num_heads=4)
+ ],
+ feedforward_channels=1024,
+ ffn_dropout=0.1,
+ operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
+ 'norm')))
+ assert len(squeue.layers) == 6
+ assert squeue.pre_norm is False
+ with pytest.raises(AssertionError):
+ # if transformerlayers is a list, len(transformerlayers)
+ # should be equal to num_layers
+ TransformerLayerSequence(
+ num_layers=6,
+ transformerlayers=[
+ dict(
+ type='BaseTransformerLayer',
+ attn_cfgs=[
+ dict(
+ type='MultiheadAttention',
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1),
+ dict(type='MultiheadAttention', embed_dims=256)
+ ],
+ feedforward_channels=1024,
+ ffn_dropout=0.1,
+ operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
+ 'ffn', 'norm'))
+ ])
+
+
+def test_drop_path():
+ drop_path = DropPath(drop_prob=0)
+ test_in = torch.rand(2, 3, 4, 5)
+ assert test_in is drop_path(test_in)
+
+ drop_path = DropPath(drop_prob=0.1)
+ drop_path.training = False
+ test_in = torch.rand(2, 3, 4, 5)
+ assert test_in is drop_path(test_in)
+ drop_path.training = True
+ assert test_in is not drop_path(test_in)
diff --git a/tests/test_cnn/test_weight_init.py b/tests/test_cnn/test_weight_init.py
index 343079c45e..82ce6423bf 100644
--- a/tests/test_cnn/test_weight_init.py
+++ b/tests/test_cnn/test_weight_init.py
@@ -1,16 +1,18 @@
# Copyright (c) Open-MMLab. All rights reserved.
+import random
from tempfile import TemporaryDirectory
import numpy as np
import pytest
import torch
+from scipy import stats
from torch import nn
from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
- PretrainedInit, UniformInit, XavierInit,
+ PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init,
- initialize, kaiming_init, normal_init, uniform_init,
- xavier_init)
+ initialize, kaiming_init, normal_init, trunc_normal_init,
+ uniform_init, xavier_init)
def test_constant_init():
@@ -47,6 +49,35 @@ def test_normal_init():
# TODO: sanity check distribution, e.g. mean, std
+def test_trunc_normal_init():
+
+ def _random_float(a, b):
+ return (b - a) * random.random() + a
+
+ def _is_trunc_normal(tensor, mean, std, a, b):
+ # scipy's trunc norm is suited for data drawn from N(0, 1),
+ # so we need to transform our data to test it using scipy.
+ z_samples = (tensor.view(-1) - mean) / std
+ z_samples = z_samples.tolist()
+ a0 = (a - mean) / std
+ b0 = (b - mean) / std
+ p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
+ return p_value > 0.0001
+
+ conv_module = nn.Conv2d(3, 16, 3)
+ mean = _random_float(-3, 3)
+ std = _random_float(.01, 1)
+ a = _random_float(mean - 2 * std, mean)
+ b = _random_float(mean, mean + 2 * std)
+ trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
+ assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
+ assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
+
+ conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
+ trunc_normal_init(conv_module_no_bias)
+ # TODO: sanity check distribution, e.g. mean, std
+
+
def test_uniform_init():
conv_module = nn.Conv2d(3, 16, 3)
uniform_init(conv_module, bias=0.1)
@@ -103,6 +134,15 @@ def test_constaninit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
+ # test layer key with base class name
+ model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+ func = ConstantInit(val=4., bias=5., layer='_ConvNd')
+ func(model)
+ assert torch.all(model[0].weight == 4.)
+ assert torch.all(model[2].weight == 4.)
+ assert torch.all(model[0].bias == 5.)
+ assert torch.all(model[2].bias == 5.)
+
# test bias input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias='1')
@@ -139,6 +179,22 @@ def test_xavierinit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
+ # test layer key with base class name
+ model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+ func = ConstantInit(val=4., bias=5., layer='_ConvNd')
+ func(model)
+ assert torch.all(model[0].weight == 4.)
+ assert torch.all(model[2].weight == 4.)
+ assert torch.all(model[0].bias == 5.)
+ assert torch.all(model[2].bias == 5.)
+
+ func = XavierInit(gain=100, bias_prob=0.01, layer='_ConvNd')
+ func(model)
+ assert not torch.all(model[0].weight == 4.)
+ assert not torch.all(model[2].weight == 4.)
+ assert torch.all(model[0].bias == res)
+ assert torch.all(model[2].bias == res)
+
# test bias input type
with pytest.raises(TypeError):
func = XavierInit(bias='0.1', layer='Conv2d')
@@ -167,6 +223,54 @@ def test_normalinit():
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
+ # test layer key with base class name
+ model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+
+ func = NormalInit(mean=300, std=1e-5, bias_prob=0.01, layer='_ConvNd')
+ func(model)
+ assert model[0].weight.allclose(torch.tensor(300.))
+ assert model[2].weight.allclose(torch.tensor(300.))
+ assert torch.all(model[0].bias == res)
+ assert torch.all(model[2].bias == res)
+
+
+def test_truncnormalinit():
+ """test TruncNormalInit class."""
+ model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
+
+ func = TruncNormalInit(
+ mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
+ func(model)
+ assert model[0].weight.allclose(torch.tensor(100.))
+ assert model[2].weight.allclose(torch.tensor(100.))
+ assert model[0].bias.allclose(torch.tensor(200.))
+ assert model[2].bias.allclose(torch.tensor(200.))
+
+ func = TruncNormalInit(
+ mean=300,
+ std=1e-5,
+ a=100,
+ b=400,
+ bias_prob=0.01,
+ layer=['Conv2d', 'Linear'])
+ res = bias_init_with_prob(0.01)
+ func(model)
+ assert model[0].weight.allclose(torch.tensor(300.))
+ assert model[2].weight.allclose(torch.tensor(300.))
+ assert model[0].bias.allclose(torch.tensor(res))
+ assert model[2].bias.allclose(torch.tensor(res))
+
+ # test layer key with base class name
+ model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+
+ func = TruncNormalInit(
+ mean=300, std=1e-5, a=100, b=400, bias_prob=0.01, layer='_ConvNd')
+ func(model)
+ assert model[0].weight.allclose(torch.tensor(300.))
+ assert model[2].weight.allclose(torch.tensor(300.))
+ assert torch.all(model[0].bias == res)
+ assert torch.all(model[2].bias == res)
+
def test_uniforminit():
""""test UniformInit class."""
@@ -187,6 +291,17 @@ def test_uniforminit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
+ # test layer key with base class name
+ model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+
+ func = UniformInit(a=100, b=100, bias_prob=0.01, layer='_ConvNd')
+ res = bias_init_with_prob(0.01)
+ func(model)
+ assert torch.all(model[0].weight == 100.)
+ assert torch.all(model[2].weight == 100.)
+ assert torch.all(model[0].bias == res)
+ assert torch.all(model[2].bias == res)
+
def test_kaiminginit():
"""test KaimingInit class."""
@@ -212,6 +327,29 @@ def test_kaiminginit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
+ # test layer key with base class name
+ model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+ func = KaimingInit(bias=0.1, layer='_ConvNd')
+ func(model)
+ assert torch.all(model[0].bias == 0.1)
+ assert torch.all(model[2].bias == 0.1)
+
+ func = KaimingInit(a=100, bias=10, layer='_ConvNd')
+ constant_func = ConstantInit(val=0, bias=0, layer='_ConvNd')
+ model.apply(constant_func)
+ assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
+ assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
+ assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
+ assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
+
+ func(model)
+ assert not torch.equal(model[0].weight,
+ torch.full(model[0].weight.shape, 0.))
+ assert not torch.equal(model[2].weight,
+ torch.full(model[2].weight.shape, 0.))
+ assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
+ assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
+
def test_caffe2xavierinit():
"""test Caffe2XavierInit."""
diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py
index 326cfd2d2a..ffc933fec2 100644
--- a/tests/test_cnn/test_wrappers.py
+++ b/tests/test_cnn/test_wrappers.py
@@ -330,7 +330,7 @@ def test_linear(in_w, in_h, in_feature, out_feature):
wrapper(x_empty)
-@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 8))
+@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10))
def test_nn_op_forward_called():
for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
@@ -347,6 +347,20 @@ def test_nn_op_forward_called():
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)
+ for m in ['Conv3d', 'ConvTranspose3d', 'MaxPool3d']:
+ with patch(f'torch.nn.{m}.forward') as nn_module_forward:
+ # randn input
+ x_empty = torch.randn(0, 3, 10, 10, 10)
+ wrapper = eval(m)(3, 2, 1)
+ wrapper(x_empty)
+ nn_module_forward.assert_called_with(x_empty)
+
+ # non-randn input
+ x_normal = torch.randn(1, 3, 10, 10, 10)
+ wrapper = eval(m)(3, 2, 1)
+ wrapper(x_normal)
+ nn_module_forward.assert_called_with(x_normal)
+
with patch('torch.nn.Linear.forward') as nn_module_forward:
# randn input
x_empty = torch.randn(0, 3)
diff --git a/tests/test_image/test_geometric.py b/tests/test_image/test_geometric.py
index 56e9b9f938..1048ea5a4d 100644
--- a/tests/test_image/test_geometric.py
+++ b/tests/test_image/test_geometric.py
@@ -47,6 +47,55 @@ def test_imresize(self):
with pytest.raises(ValueError):
mmcv.imresize(self.img, (1000, 600), backend='not support')
+ def test_imresize_to_multiple(self):
+ # test size and keep_ratio = False
+ resized_img = mmcv.imresize_to_multiple(
+ self.img, divisor=16, size=(511, 513), keep_ratio=False)
+ assert resized_img.shape == (528, 512, 3)
+ resized_img = mmcv.imresize_to_multiple(
+ self.img, divisor=(16, 32), size=(511, 513), keep_ratio=False)
+ assert resized_img.shape == (544, 512, 3)
+
+ # test size, keep_ratio = True, and return_scale
+ resized_img, w_scale, h_scale = mmcv.imresize_to_multiple(
+ self.img,
+ divisor=16,
+ size=(1000, 600),
+ keep_ratio=True,
+ return_scale=True)
+ assert resized_img.shape == (
+ 608, 800, 3) and h_scale == 608 / 300 and w_scale == 800 / 400
+ resized_img, w_scale, h_scale = mmcv.imresize_to_multiple(
+ self.img,
+ divisor=(18, 16),
+ size=(1000, 600),
+ keep_ratio=True,
+ return_scale=True)
+ assert resized_img.shape == (
+ 608, 810, 3) and h_scale == 608 / 300 and w_scale == 810 / 400
+
+ # test scale_factor and return_scale
+ resized_img, w_scale, h_scale = mmcv.imresize_to_multiple(
+ self.img, divisor=16, scale_factor=2, return_scale=True)
+ assert resized_img.shape == (
+ 608, 800, 3) and h_scale == 608 / 300 and w_scale == 800 / 400
+ resized_img, w_scale, h_scale = mmcv.imresize_to_multiple(
+ self.img, divisor=16, scale_factor=(2, 3), return_scale=True)
+ assert resized_img.shape == (
+ 912, 800, 3) and h_scale == 912 / 300 and w_scale == 800 / 400
+ resized_img, w_scale, h_scale = mmcv.imresize_to_multiple(
+ self.img, divisor=(18, 16), scale_factor=(2, 3), return_scale=True)
+ assert resized_img.shape == (
+ 912, 810, 3) and h_scale == 912 / 300 and w_scale == 810 / 400
+
+ # one of size and scale_factor shuld be given
+ with pytest.raises(ValueError):
+ mmcv.imresize_to_multiple(
+ self.img, divisor=16, size=(1000, 600), scale_factor=2)
+ with pytest.raises(ValueError):
+ mmcv.imresize_to_multiple(
+ self.img, divisor=16, size=None, scale_factor=None)
+
def test_imresize_like(self):
a = np.zeros((100, 200, 3))
resized_img = mmcv.imresize_like(self.img, a)
diff --git a/tests/test_image/test_io.py b/tests/test_image/test_io.py
index 1658c4657d..869a9a7add 100644
--- a/tests/test_image/test_io.py
+++ b/tests/test_image/test_io.py
@@ -184,12 +184,30 @@ def test_imread(self):
# consistent exif behaviour
img_cv2_exif = mmcv.imread(self.exif_img_path)
img_pil_exif = mmcv.imread(self.exif_img_path, backend='pillow')
- assert img_cv2_exif.shape == img_pil_exif.shape
+ assert img_cv2_exif.shape == (400, 300, 3)
+ assert img_pil_exif.shape == (400, 300, 3)
img_cv2_exif_unchanged = mmcv.imread(
self.exif_img_path, flag='unchanged')
img_pil_exif_unchanged = mmcv.imread(
self.exif_img_path, backend='pillow', flag='unchanged')
- assert img_cv2_exif_unchanged.shape == img_pil_exif_unchanged.shape
+ assert img_cv2_exif_unchanged.shape == (300, 400, 3)
+ assert img_pil_exif_unchanged.shape == (300, 400, 3)
+ img_cv2_color_ignore_exif = mmcv.imread(
+ self.exif_img_path, flag='color_ignore_orientation')
+ img_pil_color_ignore_exif = mmcv.imread(
+ self.exif_img_path,
+ backend='pillow',
+ flag='color_ignore_orientation')
+ assert img_cv2_color_ignore_exif.shape == (300, 400, 3)
+ assert img_pil_color_ignore_exif.shape == (300, 400, 3)
+ img_cv2_grayscale_ignore_exif = mmcv.imread(
+ self.exif_img_path, flag='grayscale_ignore_orientation')
+ img_pil_grayscale_ignore_exif = mmcv.imread(
+ self.exif_img_path,
+ backend='pillow',
+ flag='grayscale_ignore_orientation')
+ assert img_cv2_grayscale_ignore_exif.shape == (300, 400)
+ assert img_pil_grayscale_ignore_exif.shape == (300, 400)
def test_imfrombytes(self):
# backend cv2, channel order: bgr
diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py
index f08bf69132..400864700d 100644
--- a/tests/test_load_model_zoo.py
+++ b/tests/test_load_model_zoo.py
@@ -11,6 +11,7 @@
_load_checkpoint,
get_deprecated_model_names,
get_external_models)
+from mmcv.utils import TORCH_VERSION
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@@ -77,13 +78,23 @@ def load(filepath, map_location=None):
def test_load_external_url():
# test modelzoo://
url = _load_checkpoint('modelzoo://resnet50')
- assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
- '.pth'
+ if TORCH_VERSION < '1.9.0':
+ assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
+ '357.pth')
+ else:
+ # filename of checkpoint is renamed in torch1.9.0
+ assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
+ 'a61.pth')
# test torchvision://
url = _load_checkpoint('torchvision://resnet50')
- assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
- '.pth'
+ if TORCH_VERSION < '1.9.0':
+ assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
+ '357.pth')
+ else:
+ # filename of checkpoint is renamed in torch1.9.0
+ assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
+ 'a61.pth')
# test open-mmlab:// with default MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)
diff --git a/tests/test_ops/test_bilinear_grid_sample.py b/tests/test_ops/test_bilinear_grid_sample.py
new file mode 100644
index 0000000000..cf0bf437de
--- /dev/null
+++ b/tests/test_ops/test_bilinear_grid_sample.py
@@ -0,0 +1,40 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TestBilinearGridSample(object):
+
+ def _test_bilinear_grid_sample(self,
+ dtype=torch.float,
+ align_corners=False,
+ multiplier=1,
+ precision=1e-3):
+ from mmcv.ops.point_sample import bilinear_grid_sample
+
+ input = torch.rand(1, 1, 20, 20, dtype=dtype)
+ grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
+ grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+ grid *= multiplier
+
+ out = bilinear_grid_sample(input, grid, align_corners=align_corners)
+ ref_out = F.grid_sample(input, grid, align_corners=align_corners)
+
+ assert np.allclose(out.data.detach().cpu().numpy(),
+ ref_out.data.detach().cpu().numpy(), precision)
+
+ def test_bilinear_grid_sample(self):
+ self._test_bilinear_grid_sample(torch.double, False)
+ self._test_bilinear_grid_sample(torch.double, True)
+ self._test_bilinear_grid_sample(torch.float, False)
+ self._test_bilinear_grid_sample(torch.float, True)
+ self._test_bilinear_grid_sample(torch.float, False)
+ self._test_bilinear_grid_sample(torch.float, True, 5)
+ self._test_bilinear_grid_sample(torch.float, False, 10)
+ self._test_bilinear_grid_sample(torch.float, True, -6)
+ self._test_bilinear_grid_sample(torch.float, False, -10)
+ self._test_bilinear_grid_sample(torch.double, True, 5)
+ self._test_bilinear_grid_sample(torch.double, False, 10)
+ self._test_bilinear_grid_sample(torch.double, True, -6)
+ self._test_bilinear_grid_sample(torch.double, False, -10)
diff --git a/tests/test_ops/test_border_align.py b/tests/test_ops/test_border_align.py
new file mode 100644
index 0000000000..4821f3a9c1
--- /dev/null
+++ b/tests/test_ops/test_border_align.py
@@ -0,0 +1,90 @@
+import copy
+
+import numpy as np
+import pytest
+import torch
+
+# [1,4c,h,w]
+input_arr = [[[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]],
+ [[6, 7, 5, 8], [2, 1, 3, 4], [12, 9, 11, 10]],
+ [[-2, -3, 2, 0], [-4, -5, 1, -1], [-1, -1, -1, -1]],
+ [[0, -1, 2, 1], [-4, -3, -2, -1], [-1, -2, -3, -4]]]]
+# [1,h*w,4]
+boxes_arr = [[[0, 0, 2, 1], [1, 0, 3, 1], [1, 0, 2, 1], [0, 0, 3, 1],
+ [0, 0, 1, 2], [0, 0, 2, 2], [1, 0, 2, 1], [1, 0, 3, 1],
+ [0, 1, 1, 2], [0, 0, 3, 2], [1, 0, 3, 2], [2, 0, 3, 2]]]
+output_dict = {
+ # [1,c,h*w,4] for each value,
+ # the ouput is manually checked for its correctness
+
+ # pool_size=1
+ 1: [[[[3., 6., 1., 2.], [4., 7., -1., 1.], [3., 7., 1., 2.],
+ [4., 6., -1., 1.], [2., 12., -1., -1.], [3., 12., -1., 2.],
+ [3., 7., 1., 2.], [4., 7., -1., 1.], [6., 12., -1., -2.],
+ [4., 12., -1., 1.], [4., 9., -1., 1.], [4., 11., -1., 1.]]]],
+
+ # pool_size=2
+ 2: [[[[3., 6., 1., 2.], [4., 7., 1., 1.], [3., 7., 1., 2.],
+ [4., 6., -1., 1.], [2., 12., -1., -1.], [3., 12., -1., 2.],
+ [3., 7., 1., 2.], [4., 7., 1., 1.], [6., 12., -1., -2.],
+ [4., 12., -1., 1.], [4., 9., -1., 1.], [4., 11., -1., 1.]]]],
+}
+input_grad_dict = {
+ # [1,4c,h,w] for each value
+ # the grad is manually checked for its correctness
+
+ # pool_size=1
+ 1: [[[[0., 1., 4., 6.], [0., 1., 0., 0.], [0., 0., 0., 0.]],
+ [[2., 4., 0., 0.], [0., 0., 0., 0.], [4., 1., 1., 0.]],
+ [[0., 0., 0., 0.], [0., 0., 3., 3.], [0., 2., 1., 3.]],
+ [[0., 1., 4., 6.], [0., 0., 0., 0.], [0., 1., 0., 0.]]]],
+
+ # pool_size=2
+ 2: [[[[0., 1., 4., 6.], [0., 1., 0., 0.], [0., 0., 0., 0.]],
+ [[2., 4., 0., 0.], [0., 0., 0., 0.], [4., 1., 1., 0.]],
+ [[0., 0., 0., 0.], [0., 0., 5., 1.], [0., 2., 1., 3.]],
+ [[0., 1., 4., 6.], [0., 0., 0., 0.], [0., 1., 0., 0.]]]],
+}
+
+
+def _test_border_align_allclose(device, dtype, pool_size):
+ if not torch.cuda.is_available() and device == 'cuda':
+ pytest.skip('test requires GPU')
+ try:
+ from mmcv.ops import border_align, BorderAlign
+ except ModuleNotFoundError:
+ pytest.skip('BorderAlign op is not successfully compiled')
+
+ np_input = np.array(input_arr)
+ np_boxes = np.array(boxes_arr)
+ np_output = np.array(output_dict[pool_size])
+ np_grad = np.array(input_grad_dict[pool_size])
+
+ input = torch.tensor(
+ np_input, dtype=dtype, device=device, requires_grad=True)
+ boxes = torch.tensor(np_boxes, dtype=dtype, device=device)
+
+ # test for border_align
+ input_cp = copy.deepcopy(input)
+ output = border_align(input_cp, boxes, pool_size)
+ output.backward(torch.ones_like(output))
+ assert np.allclose(
+ output.data.type(dtype).cpu().numpy(), np_output, atol=1e-5)
+ assert np.allclose(
+ input_cp.grad.data.type(dtype).cpu().numpy(), np_grad, atol=1e-5)
+
+ # test for BorderAlign
+ pool_module = BorderAlign(pool_size)
+ output = pool_module(input, boxes)
+ output.backward(torch.ones_like(output))
+ assert np.allclose(
+ output.data.type(dtype).cpu().numpy(), np_output, atol=1e-5)
+ assert np.allclose(
+ input.grad.data.type(dtype).cpu().numpy(), np_grad, atol=1e-5)
+
+
+@pytest.mark.parametrize('device', ['cuda'])
+@pytest.mark.parametrize('dtype', [torch.float, torch.half, torch.double])
+@pytest.mark.parametrize('pool_size', [1, 2])
+def test_border_align(device, dtype, pool_size):
+ _test_border_align_allclose(device, dtype, pool_size)
diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py
index b99df8d011..ea6e429d2e 100644
--- a/tests/test_ops/test_deform_conv.py
+++ b/tests/test_ops/test_deform_conv.py
@@ -1,7 +1,18 @@
+from distutils.version import LooseVersion
+
import numpy as np
import pytest
import torch
+from mmcv.utils import TORCH_VERSION
+
+try:
+ # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+ # would be imported and used; we should test if our modules support it.
+ from torch.cuda.amp import autocast
+except ImportError:
+ pass
+
input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
[[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
@@ -71,7 +82,69 @@ def _test_deformconv(self, dtype=torch.float, threshold=1e-3):
with pytest.raises(AssertionError):
model = DeformConv2d(3, 4, 3, groups=3)
+ def _test_amp_deformconv(self, input_dtype, threshold=1e-3):
+ """The function to test amp released on pytorch 1.6.0.
+
+ The type of input data might be torch.float or torch.half,
+ so we should test deform_conv in both cases. With amp, the
+ data type of model will NOT be set manually.
+
+ Args:
+ input_dtype: torch.float or torch.half.
+ threshold: the same as above function.
+ """
+ if not torch.cuda.is_available():
+ return
+ from mmcv.ops import DeformConv2dPack
+ c_in = 1
+ c_out = 1
+ x = torch.Tensor(input).cuda().type(input_dtype)
+ x.requires_grad = True
+ model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
+ model.conv_offset.weight.data = torch.nn.Parameter(
+ torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
+ model.conv_offset.bias.data = torch.nn.Parameter(
+ torch.Tensor(offset_bias).reshape(8))
+ model.weight.data = torch.nn.Parameter(
+ torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
+ model.cuda()
+
+ out = model(x)
+ out.backward(torch.ones_like(out))
+
+ assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold)
+ assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold)
+ assert np.allclose(
+ model.conv_offset.weight.grad.detach().cpu().numpy(),
+ gt_offset_weight_grad, threshold)
+ assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(),
+ gt_offset_bias_grad, threshold)
+ assert np.allclose(model.weight.grad.detach().cpu().numpy(),
+ gt_deform_weight_grad, threshold)
+
+ from mmcv.ops import DeformConv2d
+ # test bias
+ model = DeformConv2d(1, 1, 2, stride=1, padding=0)
+ assert not hasattr(model, 'bias')
+ # test bias=True
+ with pytest.raises(AssertionError):
+ model = DeformConv2d(1, 1, 2, stride=1, padding=0, bias=True)
+ # test in_channels % group != 0
+ with pytest.raises(AssertionError):
+ model = DeformConv2d(3, 2, 3, groups=2)
+ # test out_channels % group != 0
+ with pytest.raises(AssertionError):
+ model = DeformConv2d(3, 4, 3, groups=3)
+
def test_deformconv(self):
self._test_deformconv(torch.double)
self._test_deformconv(torch.float)
self._test_deformconv(torch.half, 1e-1)
+
+ # test amp when torch version >= '1.6.0', the type of
+ # input data for deformconv might be torch.float or torch.half
+ if (TORCH_VERSION != 'parrots'
+ and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
+ with autocast(enabled=True):
+ self._test_amp_deformconv(torch.float, 1e-1)
+ self._test_amp_deformconv(torch.half, 1e-1)
diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py
index 43ddd66707..73032f0a45 100644
--- a/tests/test_ops/test_modulated_deform_conv.py
+++ b/tests/test_ops/test_modulated_deform_conv.py
@@ -1,8 +1,18 @@
import os
+from distutils.version import LooseVersion
import numpy
import torch
+from mmcv.utils import TORCH_VERSION
+
+try:
+ # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+ # would be imported and used; we should test if our modules support it.
+ from torch.cuda.amp import autocast
+except ImportError:
+ pass
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
input_t = [[[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]]]
@@ -58,7 +68,53 @@ def _test_mdconv(self, dtype=torch.float):
assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
dcn_offset_b_grad, 1e-2)
+ def _test_amp_mdconv(self, input_dtype=torch.float):
+ """The function to test amp released on pytorch 1.6.0.
+
+ The type of input data might be torch.float or torch.half,
+ so we should test mdconv in both cases. With amp, the data
+ type of model will NOT be set manually.
+
+ Args:
+ input_dtype: torch.float or torch.half.
+ """
+ if not torch.cuda.is_available():
+ return
+ from mmcv.ops import ModulatedDeformConv2dPack
+ input = torch.tensor(input_t).cuda().type(input_dtype)
+ input.requires_grad = True
+
+ dcn = ModulatedDeformConv2dPack(
+ 1,
+ 1,
+ kernel_size=(2, 2),
+ stride=1,
+ padding=1,
+ deform_groups=1,
+ bias=False).cuda()
+ dcn.weight.data.fill_(1.)
+ output = dcn(input)
+ output.sum().backward()
+ assert numpy.allclose(output.cpu().detach().numpy(), output_t, 1e-2)
+ assert numpy.allclose(input.grad.cpu().detach().numpy(), input_grad,
+ 1e-2)
+ assert numpy.allclose(dcn.weight.grad.cpu().detach().numpy(),
+ dcn_w_grad, 1e-2)
+ assert numpy.allclose(
+ dcn.conv_offset.weight.grad.cpu().detach().numpy(),
+ dcn_offset_w_grad, 1e-2)
+ assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
+ dcn_offset_b_grad, 1e-2)
+
def test_mdconv(self):
self._test_mdconv(torch.double)
self._test_mdconv(torch.float)
self._test_mdconv(torch.half)
+
+ # test amp when torch version >= '1.6.0', the type of
+ # input data for mdconv might be torch.float or torch.half
+ if (TORCH_VERSION != 'parrots'
+ and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
+ with autocast(enabled=True):
+ self._test_amp_mdconv(torch.float)
+ self._test_amp_mdconv(torch.half)
diff --git a/tests/test_ops/test_ms_deformable_attn.py b/tests/test_ops/test_ms_deformable_attn.py
index 39d371fcb3..72aefcd108 100644
--- a/tests/test_ops/test_ms_deformable_attn.py
+++ b/tests/test_ops/test_ms_deformable_attn.py
@@ -1,9 +1,16 @@
import pytest
import torch
-from torch.autograd import gradcheck
from mmcv.ops.multi_scale_deform_attn import (
- MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
+ MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
+ multi_scale_deformable_attn_pytorch)
+
+_USING_PARROTS = True
+try:
+ from parrots.autograd import gradcheck
+except ImportError:
+ from torch.autograd import gradcheck
+ _USING_PARROTS = False
def test_forward_multi_scale_deformable_attn_pytorch():
@@ -92,7 +99,14 @@ def test_forward_equal_with_pytorch_float():
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
-@pytest.mark.parametrize('channels', [4, 30, 32, 64, 71, 1025, 2048, 3096])
+@pytest.mark.parametrize('channels', [
+ 4,
+ 30,
+ 32,
+ 64,
+ 71,
+ 1025,
+])
def test_gradient_numerical(channels,
grad_value=True,
grad_sampling_loc=True,
@@ -118,8 +132,30 @@ def test_gradient_numerical(channels,
value.requires_grad = grad_value
sampling_locations.requires_grad = grad_sampling_loc
attention_weights.requires_grad = grad_attn_weight
-
- assert gradcheck(
- func,
- (value.double(), shapes, level_start_index,
- sampling_locations.double(), attention_weights.double(), im2col_step))
+ if _USING_PARROTS:
+ assert gradcheck(
+ func, (value.double(), shapes, level_start_index,
+ sampling_locations.double(), attention_weights.double(),
+ im2col_step),
+ no_grads=[shapes, level_start_index])
+ else:
+ assert gradcheck(func, (value.double(), shapes, level_start_index,
+ sampling_locations.double(),
+ attention_weights.double(), im2col_step))
+
+
+def test_multiscale_deformable_attention():
+ with pytest.raises(ValueError):
+ # embed_dims must be divisible by num_heads,
+ MultiScaleDeformableAttention(
+ embed_dims=256,
+ num_heads=7,
+ )
+ with pytest.raises(ValueError):
+ # embed_dims must be divisible by num_heads,
+ MultiScaleDeformableAttention(
+ embed_dims=256,
+ num_heads=7,
+ )
+
+ MultiScaleDeformableAttention(embed_dims=256, num_heads=8)
diff --git a/tests/test_ops/test_nms.py b/tests/test_ops/test_nms.py
index 29090a94dc..3c59204b1b 100644
--- a/tests/test_ops/test_nms.py
+++ b/tests/test_ops/test_nms.py
@@ -138,7 +138,12 @@ def test_batched_nms(self):
from mmcv.ops import batched_nms
results = mmcv.load('./tests/data/batched_nms_data.pkl')
- nms_cfg = dict(type='nms', iou_threshold=0.7)
+ nms_max_num = 100
+ nms_cfg = dict(
+ type='nms',
+ iou_threshold=0.7,
+ score_threshold=0.5,
+ max_num=nms_max_num)
boxes, keep = batched_nms(
torch.from_numpy(results['boxes']),
torch.from_numpy(results['scores']),
@@ -156,7 +161,8 @@ def test_batched_nms(self):
assert torch.equal(keep, seq_keep)
assert torch.equal(boxes, seq_boxes)
- assert torch.equal(keep, torch.from_numpy(results['keep']))
+ assert torch.equal(keep,
+ torch.from_numpy(results['keep'][:nms_max_num]))
nms_cfg = dict(type='soft_nms', iou_threshold=0.7)
boxes, keep = batched_nms(
diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py
index c07cd908d9..9e5fcb543e 100644
--- a/tests/test_ops/test_onnx.py
+++ b/tests/test_ops/test_onnx.py
@@ -23,31 +23,7 @@ def forward(self, *args, **kwargs):
return self.wrapped_function(*args, **kwargs)
-@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
-@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
-@pytest.mark.parametrize('align_corners', [True, False])
-def test_grid_sample(mode, padding_mode, align_corners):
- from mmcv.onnx.symbolic import register_extra_symbolics
- opset_version = 11
- register_extra_symbolics(opset_version)
-
- from mmcv.ops import get_onnxruntime_op_path
- ort_custom_op_path = get_onnxruntime_op_path()
- if not os.path.exists(ort_custom_op_path):
- pytest.skip('custom ops for onnxruntime are not compiled.')
-
- input = torch.rand(1, 1, 10, 10)
- grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
- grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
-
- def func(input, grid):
- return nn.functional.grid_sample(
- input,
- grid,
- mode=mode,
- padding_mode=padding_mode,
- align_corners=align_corners)
-
+def process_grid_sample(func, input, grid, ort_custom_op_path=''):
wrapped_model = WrapFunction(func).eval()
input_names = ['input', 'grid']
@@ -66,7 +42,8 @@ def func(input, grid):
onnx_model = onnx.load(onnx_file)
session_options = rt.SessionOptions()
- session_options.register_custom_ops_library(ort_custom_op_path)
+ if ort_custom_op_path:
+ session_options.register_custom_ops_library(ort_custom_op_path)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
@@ -83,6 +60,51 @@ def func(input, grid):
assert np.allclose(pytorch_results, ort_result, atol=1e-3)
+@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
+@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
+@pytest.mark.parametrize('align_corners', [True, False])
+def test_grid_sample(mode, padding_mode, align_corners):
+ from mmcv.onnx.symbolic import register_extra_symbolics
+ opset_version = 11
+ register_extra_symbolics(opset_version)
+
+ from mmcv.ops import get_onnxruntime_op_path
+ ort_custom_op_path = get_onnxruntime_op_path()
+ if not os.path.exists(ort_custom_op_path):
+ pytest.skip('custom ops for onnxruntime are not compiled.')
+
+ input = torch.rand(1, 1, 10, 10)
+ grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
+ grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+
+ def func(input, grid):
+ return nn.functional.grid_sample(
+ input,
+ grid,
+ mode=mode,
+ padding_mode=padding_mode,
+ align_corners=align_corners)
+
+ return process_grid_sample(func, input, grid, ort_custom_op_path)
+
+
+@pytest.mark.parametrize('align_corners', [True, False])
+def test_bilinear_grid_sample(align_corners):
+ from mmcv.ops.point_sample import bilinear_grid_sample
+ # only support pytorch >= 1.5.0
+ if version.parse(torch.__version__) < version.parse('1.5.0'):
+ pytest.skip('Only support PyTorch >= 1.5.0')
+
+ input = torch.rand(1, 1, 10, 10)
+ grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
+ grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+
+ def func(input, grid):
+ return bilinear_grid_sample(input, grid, align_corners=align_corners)
+
+ return process_grid_sample(func, input, grid)
+
+
def test_nms():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
@@ -93,9 +115,12 @@ def test_nms():
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
- pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0)
+
+ nms = partial(
+ nms, iou_threshold=0.3, offset=0, score_threshold=0, max_num=0)
+ pytorch_dets, _ = nms(boxes, scores)
pytorch_score = pytorch_dets[:, 4]
- nms = partial(nms, iou_threshold=0.3, offset=0)
+
wrapped_model = WrapFunction(nms)
wrapped_model.cpu().eval()
with torch.no_grad():
@@ -106,14 +131,12 @@ def test_nms():
keep_initializers_as_inputs=True,
input_names=['boxes', 'scores'],
opset_version=11)
- onnx_model = onnx.load(onnx_file)
+ onnx_model = onnx.load(onnx_file)
ort_custom_op_path = get_onnxruntime_op_path()
- if not os.path.exists(ort_custom_op_path):
- pytest.skip('nms for onnxruntime is not compiled.')
-
session_options = rt.SessionOptions()
- session_options.register_custom_ops_library(ort_custom_op_path)
+ if os.path.exists(ort_custom_op_path):
+ session_options.register_custom_ops_library(ort_custom_op_path)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py
index 3f8fe473c8..d65308ba8a 100644
--- a/tests/test_ops/test_tensorrt.py
+++ b/tests/test_ops/test_tensorrt.py
@@ -1,5 +1,6 @@
import os
from functools import partial
+from typing import Callable
import numpy as np
import onnx
@@ -8,7 +9,7 @@
import torch.nn as nn
try:
- from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
+ from mmcv.tensorrt import (TRTWrapper, is_tensorrt_plugin_loaded, onnx2trt,
save_trt_engine)
except ImportError:
pytest.skip(
@@ -94,7 +95,7 @@ def test_roialign():
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
- trt_model = TRTWraper(trt_file, ['input', 'rois'], ['roi_feat'])
+ trt_model = TRTWrapper(trt_file, ['input', 'rois'], ['roi_feat'])
with torch.no_grad():
trt_outputs = trt_model({'input': input, 'rois': rois})
@@ -125,7 +126,8 @@ def test_nms():
data = mmcv.load('./tests/data/batched_nms_data.pkl')
boxes = torch.from_numpy(data['boxes']).cuda()
scores = torch.from_numpy(data['scores']).cuda()
- nms = partial(nms, iou_threshold=0.7, offset=0)
+ nms = partial(
+ nms, iou_threshold=0.7, offset=0, score_threshold=0.1, max_num=100)
wrapped_model = WrapFunction(nms)
wrapped_model.cpu().eval()
with torch.no_grad():
@@ -154,7 +156,7 @@ def test_nms():
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
- trt_model = TRTWraper(trt_file, ['boxes', 'scores'], ['dets', 'inds'])
+ trt_model = TRTWrapper(trt_file, ['boxes', 'scores'], ['dets', 'inds'])
with torch.no_grad():
trt_outputs = trt_model({'boxes': boxes, 'scores': scores})
@@ -194,7 +196,7 @@ def test_batched_nms():
fp16_mode = False
max_workspace_size = 1 << 30
data = mmcv.load('./tests/data/batched_nms_data.pkl')
- nms_cfg = dict(type='nms', iou_threshold=0.7)
+ nms_cfg = dict(type='nms', iou_threshold=0.7, score_threshold=0.1)
boxes = torch.from_numpy(data['boxes']).cuda()
scores = torch.from_numpy(data['scores']).cuda()
idxs = torch.from_numpy(data['idxs']).cuda()
@@ -236,7 +238,7 @@ def test_batched_nms():
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
- trt_model = TRTWraper(trt_file, input_names, output_names)
+ trt_model = TRTWrapper(trt_file, input_names, output_names)
with torch.no_grad():
trt_outputs = trt_model({
@@ -310,7 +312,7 @@ def func(data):
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
- trt_model = TRTWraper(trt_file, input_names, output_names)
+ trt_model = TRTWrapper(trt_file, input_names, output_names)
with torch.no_grad():
trt_outputs = trt_model({'input': data.clone()})
@@ -386,7 +388,7 @@ def test_deform_conv():
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
- trt_model = TRTWraper(trt_file, input_names, output_names)
+ trt_model = TRTWrapper(trt_file, input_names, output_names)
with torch.no_grad():
trt_outputs = trt_model({'input': x.clone()})
@@ -404,6 +406,77 @@ def test_deform_conv():
assert torch.allclose(pytorch_results, trt_results)
+@pytest.mark.parametrize('with_bias', [True, False])
+def test_modulated_deform_conv(with_bias):
+ try:
+ from mmcv.ops import ModulatedDeformConv2dPack
+ except (ImportError, ModuleNotFoundError):
+ pytest.skip('test requires compilation')
+
+ input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
+
+ x = torch.Tensor(input).cuda()
+ model = ModulatedDeformConv2dPack(
+ 1,
+ 1,
+ kernel_size=(2, 2),
+ stride=1,
+ padding=1,
+ deform_groups=1,
+ bias=with_bias)
+ model.weight.data.fill_(1.)
+ model.type(torch.float32)
+ model = model.cuda().eval()
+
+ input_names = ['input']
+ output_names = ['output']
+
+ with torch.no_grad():
+ torch.onnx.export(
+ model, (x.clone(), ),
+ onnx_file,
+ export_params=True,
+ keep_initializers_as_inputs=True,
+ input_names=input_names,
+ output_names=output_names,
+ opset_version=11)
+
+ onnx_model = onnx.load(onnx_file)
+
+ # create trt engine and wraper
+ opt_shape_dict = {
+ 'input': [list(x.shape), list(x.shape),
+ list(x.shape)],
+ }
+ # trt config
+ fp16_mode = False
+ max_workspace_size = 1 << 30
+
+ trt_engine = onnx2trt(
+ onnx_model,
+ opt_shape_dict,
+ fp16_mode=fp16_mode,
+ max_workspace_size=max_workspace_size)
+
+ save_trt_engine(trt_engine, trt_file)
+ trt_model = TRTWrapper(trt_file, input_names, output_names)
+
+ with torch.no_grad():
+ trt_outputs = trt_model({'input': x.clone()})
+ trt_results = trt_outputs['output']
+
+ # compute pytorch_output
+ with torch.no_grad():
+ pytorch_results = model(x.clone())
+
+ # allclose
+ if os.path.exists(onnx_file):
+ os.remove(onnx_file)
+ if os.path.exists(trt_file):
+ os.remove(trt_file)
+ torch.testing.assert_allclose(pytorch_results, trt_results)
+
+
@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
@pytest.mark.parametrize('align_corners', [True, False])
@@ -462,7 +535,7 @@ def func(input, grid):
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
- trt_model = TRTWraper(trt_file, input_names, output_names)
+ trt_model = TRTWrapper(trt_file, input_names, output_names)
with torch.no_grad():
trt_outputs = trt_model({'input': input.clone(), 'grid': grid.clone()})
@@ -478,3 +551,179 @@ def func(input, grid):
if os.path.exists(trt_file):
os.remove(trt_file)
assert torch.allclose(pytorch_results, trt_results)
+
+
+@pytest.mark.parametrize('func', [torch.cummax, torch.cummin])
+def test_cummin_cummax(func: Callable):
+ # Note generally `cummax` or `cummin` is exportable to ONNX
+ # as long as the pytorch version >= 1.5.0, since `torch.cummax`
+ # is only supported with torch >= 1.5.0.
+ # But when `cummax` or `cummin` serves as an intermediate component
+ # whose outputs is used as inputs for another modules, it's expected
+ # that pytorch version must be >= 1.7.0. Otherwise error appears like:
+ # `RuntimeError: tuple appears in op that does not forward tuples,
+ # unsupported 'kind: prim::PythonOp`.
+ from packaging import version
+ if version.parse(torch.__version__) < version.parse('1.7.0'):
+ pytest.skip('test_cummax_cummin should be ran with pytorch >= 1.7.0')
+
+ opset = 11
+ # register custom op `mmcv::cummax` and `mmcv::cummin`
+ from mmcv.onnx.symbolic import register_extra_symbolics
+ register_extra_symbolics(opset)
+
+ input_list = [
+ # arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
+ torch.rand((2, 3, 4, 1, 5)).cuda(),
+ torch.rand((1)).cuda()
+ ]
+
+ input_names = ['input']
+ output_names = ['output', 'indices']
+
+ for input in input_list:
+ ndims = input.dim()
+ # valid dim range is [-ndims, ndims-1]
+ # test for all `dim` value which is valid
+ for dim in range(-ndims, ndims):
+ cummax_func = partial(func, dim=dim)
+ wrapped_model = WrapFunction(cummax_func).eval().cuda()
+
+ with torch.no_grad():
+ torch.onnx.export(
+ wrapped_model,
+ input,
+ onnx_file,
+ export_params=True,
+ keep_initializers_as_inputs=False,
+ input_names=input_names,
+ output_names=output_names,
+ opset_version=opset)
+
+ onnx_model = onnx.load(onnx_file)
+
+ # create trt engine and wraper
+ opt_shape_dict = {
+ 'input':
+ [list(input.shape),
+ list(input.shape),
+ list(input.shape)]
+ }
+ # trt config
+ fp16_mode = False
+ max_workspace_size = 1 << 30
+
+ trt_engine = onnx2trt(
+ onnx_model,
+ opt_shape_dict,
+ fp16_mode=fp16_mode,
+ max_workspace_size=max_workspace_size)
+
+ # remove ONNX model after conversion
+ if os.path.exists(onnx_file):
+ os.remove(onnx_file)
+
+ # save TensorRT model
+ save_trt_engine(trt_engine, trt_file)
+
+ # load and wrap TensorRT model
+ trt_model = TRTWrapper(trt_file)
+
+ # remove trt model after loading
+ if os.path.exists(trt_file):
+ os.remove(trt_file)
+
+ # compute trt output
+ with torch.no_grad():
+ trt_results = trt_model({'input': input.contiguous().clone()})
+ trt_output = trt_results['output']
+ trt_indices = trt_results['indices']
+
+ # compute pytorch output
+ with torch.no_grad():
+ pytorch_results = wrapped_model(input.clone())
+ pytorch_output = pytorch_results[0]
+ pytorch_indices = pytorch_results[1]
+
+ torch.testing.assert_allclose(trt_output, pytorch_output)
+ torch.testing.assert_allclose(trt_indices, pytorch_indices)
+
+
+@pytest.mark.parametrize('dynamic_export', [True, False])
+@pytest.mark.parametrize('fp16_mode', [True, False])
+def test_instance_norm(dynamic_export, fp16_mode):
+
+ n, c, h, w = 2, 3, 10, 10
+ data = torch.randn(n, c, h, w).cuda()
+ norm = nn.InstanceNorm2d(c, affine=True)
+
+ wrapped_model = WrapFunction(norm).eval().cuda()
+
+ input_names = ['input']
+ output_names = ['output']
+ dynamic_axes = None
+ if dynamic_export:
+ dynamic_axes = {
+ 'input': {
+ 0: 'n',
+ 2: 'h',
+ 3: 'w',
+ },
+ 'output': {
+ 0: 'n',
+ 2: 'h',
+ 3: 'w',
+ },
+ }
+ with torch.no_grad():
+ torch.onnx.export(
+ wrapped_model, (data.clone(), ),
+ onnx_file,
+ export_params=True,
+ keep_initializers_as_inputs=True,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ opset_version=11)
+
+ onnx_model = onnx.load(onnx_file)
+
+ # create trt engine and wraper
+ if dynamic_export:
+ opt_shape_dict = {
+ 'input':
+ [list(data.shape),
+ list(data.shape), [2 * n, c, 2 * h, 2 * w]],
+ }
+ else:
+ opt_shape_dict = {
+ 'input': [list(data.shape),
+ list(data.shape),
+ list(data.shape)],
+ }
+ # trt config
+ max_workspace_size = 1 << 30
+
+ trt_engine = onnx2trt(
+ onnx_model,
+ opt_shape_dict,
+ fp16_mode=fp16_mode,
+ max_workspace_size=max_workspace_size)
+
+ save_trt_engine(trt_engine, trt_file)
+ trt_model = TRTWrapper(trt_file, input_names, output_names)
+
+ with torch.no_grad():
+ trt_outputs = trt_model({'input': data.clone()})
+ trt_results = trt_outputs['output']
+
+ # compute pytorch_output
+ with torch.no_grad():
+ pytorch_results = wrapped_model(data.clone())
+
+ # allclose
+ if os.path.exists(onnx_file):
+ os.remove(onnx_file)
+ if os.path.exists(trt_file):
+ os.remove(trt_file)
+ assert torch.allclose(pytorch_results, trt_results)
diff --git a/tests/test_ops/test_tensorrt_preprocess.py b/tests/test_ops/test_tensorrt_preprocess.py
new file mode 100644
index 0000000000..b5ade24b4b
--- /dev/null
+++ b/tests/test_ops/test_tensorrt_preprocess.py
@@ -0,0 +1,75 @@
+import os
+from functools import wraps
+
+import onnx
+import torch
+
+from mmcv.ops import nms
+from mmcv.tensorrt.preprocess import preprocess_onnx
+
+
+def remove_tmp_file(func):
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ onnx_file = 'tmp.onnx'
+ kwargs['onnx_file'] = onnx_file
+ try:
+ result = func(*args, **kwargs)
+ finally:
+ if os.path.exists(onnx_file):
+ os.remove(onnx_file)
+ return result
+
+ return wrapper
+
+
+@remove_tmp_file
+def export_nms_module_to_onnx(module, onnx_file):
+ torch_model = module()
+ torch_model.eval()
+
+ input = (torch.rand([100, 4], dtype=torch.float32),
+ torch.rand([100], dtype=torch.float32))
+
+ torch.onnx.export(
+ torch_model,
+ input,
+ onnx_file,
+ opset_version=11,
+ input_names=['boxes', 'scores'],
+ output_names=['output'])
+
+ onnx_model = onnx.load(onnx_file)
+ return onnx_model
+
+
+def test_can_handle_nms_with_constant_maxnum():
+
+ class ModuleNMS(torch.nn.Module):
+
+ def forward(self, boxes, scores):
+ return nms(boxes, scores, iou_threshold=0.4, max_num=10)
+
+ onnx_model = export_nms_module_to_onnx(ModuleNMS)
+ preprocess_onnx_model = preprocess_onnx(onnx_model)
+ for node in preprocess_onnx_model.graph.node:
+ if 'NonMaxSuppression' in node.name:
+ assert len(node.attribute) == 5, 'The NMS must have 5 attributes.'
+
+
+def test_can_handle_nms_with_undefined_maxnum():
+
+ class ModuleNMS(torch.nn.Module):
+
+ def forward(self, boxes, scores):
+ return nms(boxes, scores, iou_threshold=0.4)
+
+ onnx_model = export_nms_module_to_onnx(ModuleNMS)
+ preprocess_onnx_model = preprocess_onnx(onnx_model)
+ for node in preprocess_onnx_model.graph.node:
+ if 'NonMaxSuppression' in node.name:
+ assert len(node.attribute) == 5, \
+ 'The NMS must have 5 attributes.'
+ assert node.attribute[2].i > 0, \
+ 'The max_output_boxes_per_class is not defined correctly.'
diff --git a/tests/test_parallel.py b/tests/test_parallel.py
index 93c8f57054..7d73aa81d8 100644
--- a/tests/test_parallel.py
+++ b/tests/test_parallel.py
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch
+import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
@@ -15,7 +16,7 @@ def mock(*args, **kwargs):
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
-@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock)
+@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():
class Model(nn.Module):
@@ -27,6 +28,12 @@ def __init__(self):
def forward(self, x):
return self.conv(x)
+ # _verify_model_across_ranks is added in torch1.9.0 so we should check
+ # wether _verify_model_across_ranks is the member of torch.distributed
+ # before mocking
+ if hasattr(torch.distributed, '_verify_model_across_ranks'):
+ torch.distributed._verify_model_across_ranks = mock
+
model = Model()
assert not is_module_wrapper(model)
diff --git a/tests/test_runner/test_eval_hook.py b/tests/test_runner/test_eval_hook.py
index b778cf2526..004a2ad113 100644
--- a/tests/test_runner/test_eval_hook.py
+++ b/tests/test_runner/test_eval_hook.py
@@ -84,8 +84,8 @@ def _build_iter_runner():
class EvalHook(BaseEvalHook):
- greater_keys = ['acc', 'top']
- less_keys = ['loss', 'loss_top']
+ _default_greater_keys = ['acc', 'top']
+ _default_less_keys = ['loss', 'loss_top']
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -273,6 +273,31 @@ def test_eval_hook():
assert runner.meta['hook_msgs']['best_score'] == 7
assert not osp.exists(old_ckpt_path)
+ # test EvalHook with customer test_fn and greater/less keys
+ loader = DataLoader(EvalDataset())
+ model = Model()
+ data_loader = DataLoader(EvalDataset())
+
+ eval_hook = EvalHook(
+ data_loader,
+ save_best='acc',
+ test_fn=mock.MagicMock(return_value={}),
+ greater_keys=[],
+ less_keys=['acc'])
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ logger = get_logger('test_eval')
+ runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
+ runner.register_checkpoint_hook(dict(interval=1))
+ runner.register_hook(eval_hook)
+ runner.run([loader], [('train', 1)], 8)
+
+ ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')
+
+ assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
+ assert osp.exists(ckpt_path)
+ assert runner.meta['hook_msgs']['best_score'] == -3
+
@patch('mmcv.engine.single_gpu_test', MagicMock)
@patch('mmcv.engine.multi_gpu_test', MagicMock)
diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py
index 13a0514feb..5a2e0d906a 100644
--- a/tests/test_runner/test_hooks.py
+++ b/tests/test_runner/test_hooks.py
@@ -6,6 +6,7 @@
"""
import logging
import os.path as osp
+import random
import re
import shutil
import sys
@@ -15,16 +16,18 @@
import pytest
import torch
import torch.nn as nn
+import torch.utils.data as Data
from torch.nn.init import constant_
from torch.utils.data import DataLoader
-from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
- MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
- build_runner)
+from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
+ IterTimerHook, MlflowLoggerHook, NeptuneLoggerHook,
+ PaviLoggerHook, WandbLoggerHook, build_runner)
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook,
OneCycleLrUpdaterHook,
+ ReduceLrUpdateHook,
StepLrUpdaterHook)
@@ -149,10 +152,27 @@ def __init__(self, info, *args, **kwargs):
assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
shutil.rmtree(runner.work_dir)
+ runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
+ # test custom_hooks with string priority setting
+ priority_ranks = [
+ 'HIGHEST', 'VERY_HIGH', 'HIGH', 'ABOVE_NORMAL', 'NORMAL',
+ 'BELOW_NORMAL', 'LOW', 'VERY_LOW', 'LOWEST'
+ ]
+ random_priority_ranks = priority_ranks.copy()
+ random.shuffle(random_priority_ranks)
+ custom_hooks_cfg = [
+ dict(type='ToyHook', priority=rank, info=rank)
+ for rank in random_priority_ranks
+ ]
+ runner.register_custom_hooks(custom_hooks_cfg)
+ assert [hook.info for hook in runner.hooks] == priority_ranks
+ shutil.rmtree(runner.work_dir)
+
runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test register_training_hooks order
custom_hooks_cfg = [
dict(type='ToyHook', priority=1, info='custom 1'),
+ dict(type='ToyHook', priority='NORMAL', info='custom normal'),
dict(type='ToyHook', priority=89, info='custom 89')
]
runner.register_training_hooks(
@@ -163,9 +183,11 @@ def __init__(self, info, *args, **kwargs):
momentum_config=ToyHook('momentum'),
timer_config=ToyHook('timer'),
custom_hooks_config=custom_hooks_cfg)
+ # If custom hooks have same priority with default hooks, custom hooks
+ # will be triggered after default hooks.
hooks_order = [
- 'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer',
- 'custom 89', 'log'
+ 'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint',
+ 'custom normal', 'timer', 'custom 89', 'log'
]
assert [hook.info for hook in runner.hooks] == hooks_order
shutil.rmtree(runner.work_dir)
@@ -869,6 +891,116 @@ def test_cyclic_lr_update_hook(multi_optimizers, max_iters):
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
+@pytest.mark.parametrize('multi_optimziers', (True, False))
+def test_reduce_lr_update_hook(multi_optimziers):
+ """Test ReduceLrUpdateHook."""
+ with pytest.raises(TypeError):
+ # periods should be specified
+ ReduceLrUpdateHook()
+
+ with pytest.raises(AssertionError):
+ # periods should be list
+ ReduceLrUpdateHook(periods=1)
+
+ with pytest.raises(AssertionError):
+ # periods should all be positive
+ ReduceLrUpdateHook(periods=[1, 2, -2])
+
+ with pytest.raises(ValueError):
+ # mode should be either 'min' or 'max'
+ ReduceLrUpdateHook(periods=[0, 1], mode='sum')
+
+ with pytest.raises(ValueError):
+ # factor should be < 1.0
+ ReduceLrUpdateHook(periods=[0, 1], mode='min', factor=1.0)
+
+ with pytest.raises(ValueError):
+ # threshold_mode should be 'rel' or 'abs'
+ ReduceLrUpdateHook(
+ periods=[0, 1], mode='min', factor=0.1, threshold_mode='sum')
+
+ sys.modules['pavi'] = MagicMock()
+ x = torch.ones((30, 1))
+ y = torch.ones((30, 1)) * 5
+ loader = DataLoader(Data.TensorDataset(x, y))
+ runner = _build_reduceLR_runner(
+ runner_type='IterBasedRunner',
+ multi_optimziers=multi_optimziers,
+ max_iters=30,
+ max_epochs=None)
+
+ hook = ReduceLrUpdateHook(
+ periods=list(range(30)),
+ mode='min',
+ factor=0.1,
+ patience=2,
+ threshold=1e-4,
+ threshold_mode='rel',
+ by_epoch=False,
+ eps=1e-4)
+ runner.register_hook(hook)
+ runner.register_hook_from_cfg(dict(type='IterTimerHook'))
+ runner.register_hook(IterTimerHook())
+ # add pavi hook
+ hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
+ runner.register_hook(hook)
+ runner.run([loader], [('train', 1)])
+ shutil.rmtree(runner.work_dir)
+
+ assert hasattr(hook, 'writer')
+ if multi_optimziers:
+ calls = [
+ call(
+ 'train', {
+ 'learning_rate/model1': 0.5,
+ 'learning_rate/model2': 0.01,
+ 'momentum/model1': 0.9,
+ 'momentum/model2': 0.95,
+ }, 1),
+ call(
+ 'train', {
+ 'learning_rate/model1': 0.05,
+ 'learning_rate/model2': 0.01,
+ 'momentum/model1': 0.9,
+ 'momentum/model2': 0.95,
+ }, 19),
+ call(
+ 'train', {
+ 'learning_rate/model1': 0.005000000000000001,
+ 'learning_rate/model2': 0.01,
+ 'momentum/model1': 0.9,
+ 'momentum/model2': 0.95,
+ }, 22),
+ call(
+ 'train', {
+ 'learning_rate/model1': 5.0000000000000016e-05,
+ 'learning_rate/model2': 0.01,
+ 'momentum/model1': 0.9,
+ 'momentum/model2': 0.95,
+ }, 28)
+ ]
+ else:
+ calls = [
+ call('train', {
+ 'learning_rate': 0.5,
+ 'momentum': 0.9
+ }, 1),
+ call('train', {
+ 'learning_rate': 0.05,
+ 'momentum': 0.9
+ }, 19),
+ call('train', {
+ 'learning_rate': 0.005000000000000001,
+ 'momentum': 0.9
+ }, 22),
+ call('train', {
+ 'learning_rate': 5.0000000000000016e-05,
+ 'momentum': 0.9
+ }, 28)
+ ]
+ hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
+
+
@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
sys.modules['mlflow'] = MagicMock()
@@ -915,6 +1047,40 @@ def test_wandb_hook():
hook.wandb.join.assert_called_with()
+def test_neptune_hook():
+ sys.modules['neptune'] = MagicMock()
+ sys.modules['neptune.new'] = MagicMock()
+ runner = _build_demo_runner()
+ hook = NeptuneLoggerHook()
+
+ loader = DataLoader(torch.ones((5, 2)))
+
+ runner.register_hook(hook)
+ runner.run([loader, loader], [('train', 1), ('val', 1)])
+ shutil.rmtree(runner.work_dir)
+
+ hook.neptune.init.assert_called_with()
+ hook.run['momentum'].log.assert_called_with(0.95, step=6)
+ hook.run.stop.assert_called_with()
+
+
+def test_dvclive_hook(tmp_path):
+ sys.modules['dvclive'] = MagicMock()
+ runner = _build_demo_runner()
+
+ (tmp_path / 'dvclive').mkdir()
+ hook = DvcliveLoggerHook(str(tmp_path / 'dvclive'))
+ loader = DataLoader(torch.ones((5, 2)))
+
+ runner.register_hook(hook)
+ runner.run([loader, loader], [('train', 1), ('val', 1)])
+ shutil.rmtree(runner.work_dir)
+
+ hook.dvclive.init.assert_called_with(str(tmp_path / 'dvclive'))
+ hook.dvclive.log.assert_called_with('momentum', 0.95, step=6)
+ hook.dvclive.log.assert_any_call('learning_rate', 0.02, step=6)
+
+
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
@@ -961,6 +1127,69 @@ def val_step(self, x, optimizer, **kwargs):
return runner
+def _build_reduceLR_runner_without_hook(runner_type='EpochBasedRunner',
+ max_epochs=1,
+ max_iters=None,
+ multi_optimziers=False):
+
+ class Model(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.linear = nn.Linear(1, 1)
+ self.conv = nn.Conv2d(3, 3, 3)
+ torch.nn.init.constant_(self.linear.weight, 1)
+ torch.nn.init.constant_(self.linear.bias, 1)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ def train_step(self, x, optimizer, **kwargs):
+ if isinstance(optimizer, dict):
+ for name, optim in optimizer.items():
+ optim.zero_grad()
+ else:
+ optimizer.zero_grad()
+ loss_fn = torch.nn.MSELoss()
+ pred = self.forward(x[0])
+ loss_ = loss_fn(pred, x[1])
+ loss_.backward()
+ if isinstance(optimizer, dict):
+ for name, optim in optimizer.items():
+ optim.step()
+ else:
+ optimizer.step()
+ return dict(loss=loss_)
+
+ def val_step(self, x, optimizer, **kwargs):
+ loss_fn = torch.nn.MSELoss()
+ return dict(loss=loss_fn(self.forward(x[0]), x[1]))
+
+ model = Model()
+
+ if multi_optimziers:
+ optimizer = {
+ 'model1':
+ torch.optim.SGD(model.linear.parameters(), lr=0.5, momentum=0.9),
+ 'model2':
+ torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.95),
+ }
+ else:
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.5, momentum=0.9)
+
+ tmp_dir = tempfile.mkdtemp()
+ runner = build_runner(
+ dict(type=runner_type),
+ default_args=dict(
+ model=model,
+ work_dir=tmp_dir,
+ optimizer=optimizer,
+ logger=logging.getLogger(),
+ max_epochs=max_epochs,
+ max_iters=max_iters))
+ return runner
+
+
def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
@@ -979,6 +1208,24 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
return runner
+def _build_reduceLR_runner(runner_type='EpochBasedRunner',
+ max_epochs=1,
+ max_iters=None,
+ multi_optimziers=False):
+
+ log_config = dict(
+ interval=1, hooks=[
+ dict(type='TextLoggerHook'),
+ ])
+
+ runner = _build_reduceLR_runner_without_hook(runner_type, max_epochs,
+ max_iters, multi_optimziers)
+
+ runner.register_checkpoint_hook(dict(interval=1))
+ runner.register_logger_hooks(log_config)
+ return runner
+
+
def test_runner_with_revise_keys():
import os
@@ -1016,3 +1263,20 @@ def __init__(self):
key_stripped = re.sub(r'^backbone\.', '', key)
assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
os.remove(checkpoint_path)
+
+
+def test_get_triggered_stages():
+
+ class ToyHook(Hook):
+ # test normal stage
+ def before_run():
+ pass
+
+ # test the method mapped to multi stages.
+ def after_epoch():
+ pass
+
+ hook = ToyHook()
+ # stages output have order, so here is list instead of set.
+ expected_stages = ['before_run', 'after_train_epoch', 'after_val_epoch']
+ assert hook.get_triggered_stages() == expected_stages
diff --git a/tests/test_utils/test_config.py b/tests/test_utils/test_config.py
index 5abafe80b8..44a67ba500 100644
--- a/tests/test_utils/test_config.py
+++ b/tests/test_utils/test_config.py
@@ -224,6 +224,81 @@ def test_merge_from_multiple_bases():
Config.fromfile(osp.join(data_path, 'config/m.py'))
+def test_base_variables():
+ for file in ['t.py', 't.json', 't.yaml']:
+ cfg_file = osp.join(data_path, f'config/{file}')
+ cfg = Config.fromfile(cfg_file)
+ assert isinstance(cfg, Config)
+ assert cfg.filename == cfg_file
+ # cfg.field
+ assert cfg.item1 == [1, 2]
+ assert cfg.item2.a == 0
+ assert cfg.item3 is False
+ assert cfg.item4 == 'test'
+ assert cfg.item5 == dict(a=0, b=1)
+ assert cfg.item6 == [dict(a=0), dict(b=1)]
+ assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
+ assert cfg.item8 == file
+ assert cfg.item9 == dict(a=0)
+ assert cfg.item10 == [3.1, 4.2, 5.3]
+
+ # test nested base
+ for file in ['u.py', 'u.json', 'u.yaml']:
+ cfg_file = osp.join(data_path, f'config/{file}')
+ cfg = Config.fromfile(cfg_file)
+ assert isinstance(cfg, Config)
+ assert cfg.filename == cfg_file
+ # cfg.field
+ assert cfg.base == '_base_.item8'
+ assert cfg.item1 == [1, 2]
+ assert cfg.item2.a == 0
+ assert cfg.item3 is False
+ assert cfg.item4 == 'test'
+ assert cfg.item5 == dict(a=0, b=1)
+ assert cfg.item6 == [dict(a=0), dict(b=1)]
+ assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
+ assert cfg.item8 == 't.py'
+ assert cfg.item9 == dict(a=0)
+ assert cfg.item10 == [3.1, 4.2, 5.3]
+ assert cfg.item11 == 't.py'
+ assert cfg.item12 == dict(a=0)
+ assert cfg.item13 == [3.1, 4.2, 5.3]
+ assert cfg.item14 == [1, 2]
+ assert cfg.item15 == dict(
+ a=dict(b=dict(a=0)),
+ b=[False],
+ c=['test'],
+ d=[[{
+ 'e': 0
+ }], [{
+ 'a': 0
+ }, {
+ 'b': 1
+ }]],
+ e=[1, 2])
+
+ # test reference assignment for py
+ cfg_file = osp.join(data_path, 'config/v.py')
+ cfg = Config.fromfile(cfg_file)
+ assert isinstance(cfg, Config)
+ assert cfg.filename == cfg_file
+ assert cfg.item21 == 't.py'
+ assert cfg.item22 == 't.py'
+ assert cfg.item23 == [3.1, 4.2, 5.3]
+ assert cfg.item24 == [3.1, 4.2, 5.3]
+ assert cfg.item25 == dict(
+ a=dict(b=[3.1, 4.2, 5.3]),
+ b=[[3.1, 4.2, 5.3]],
+ c=[[{
+ 'e': 't.py'
+ }], [{
+ 'a': 0
+ }, {
+ 'b': 1
+ }]],
+ e='t.py')
+
+
def test_merge_recursive_bases():
cfg_file = osp.join(data_path, 'config/f.py')
cfg = Config.fromfile(cfg_file)
diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py
index adcd26ea0d..7b056554af 100644
--- a/tests/test_utils/test_misc.py
+++ b/tests/test_utils/test_misc.py
@@ -4,6 +4,31 @@
import mmcv
+def test_to_ntuple():
+ single_number = 2
+ assert mmcv.utils.to_1tuple(single_number) == (single_number, )
+ assert mmcv.utils.to_2tuple(single_number) == (single_number,
+ single_number)
+ assert mmcv.utils.to_3tuple(single_number) == (single_number,
+ single_number,
+ single_number)
+ assert mmcv.utils.to_4tuple(single_number) == (single_number,
+ single_number,
+ single_number,
+ single_number)
+ assert mmcv.utils.to_ntuple(5)(single_number) == (single_number,
+ single_number,
+ single_number,
+ single_number,
+ single_number)
+ assert mmcv.utils.to_ntuple(6)(single_number) == (single_number,
+ single_number,
+ single_number,
+ single_number,
+ single_number,
+ single_number)
+
+
def test_iter_cast():
assert mmcv.list_cast([1, 2, 3], int) == [1, 2, 3]
assert mmcv.list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0]
@@ -105,6 +130,7 @@ def func_c():
def test_import_modules_from_strings():
# multiple imports
import os.path as osp_
+
import sys as sys_
osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys'])
assert osp == osp_
@@ -134,3 +160,33 @@ def test_import_modules_from_strings():
['os.path', '_not_implemented'], allow_failed_imports=True)
assert imported[0] == osp
assert imported[1] is None
+
+
+def test_is_method_overridden():
+
+ class Base:
+
+ def foo1():
+ pass
+
+ def foo2():
+ pass
+
+ class Sub(Base):
+
+ def foo1():
+ pass
+
+ # test passing sub class directly
+ assert mmcv.is_method_overridden('foo1', Base, Sub)
+ assert not mmcv.is_method_overridden('foo2', Base, Sub)
+
+ # test passing instance of sub class
+ sub_instance = Sub()
+ assert mmcv.is_method_overridden('foo1', Base, sub_instance)
+ assert not mmcv.is_method_overridden('foo2', Base, sub_instance)
+
+ # base_class should be a class, not instance
+ base_instance = Base()
+ with pytest.raises(AssertionError):
+ mmcv.is_method_overridden('foo1', base_instance, sub_instance)
diff --git a/tests/test_utils/test_path.py b/tests/test_utils/test_path.py
index 42f308ef66..aa6537eafa 100644
--- a/tests/test_utils/test_path.py
+++ b/tests/test_utils/test_path.py
@@ -40,12 +40,13 @@ def test_scandir():
filenames_recursive = [
'a.bin', '1.txt', '2.txt', '1.json', '2.json', 'sub/1.json',
- 'sub/1.txt'
+ 'sub/1.txt', '.file'
]
- assert set(mmcv.scandir(folder,
- recursive=True)) == set(filenames_recursive)
- assert set(mmcv.scandir(Path(folder),
- recursive=True)) == set(filenames_recursive)
+ # .file starts with '.' and is a file so it will not be scanned
+ assert set(mmcv.scandir(folder, recursive=True)) == set(
+ [filename for filename in filenames_recursive if filename != '.file'])
+ assert set(mmcv.scandir(Path(folder), recursive=True)) == set(
+ [filename for filename in filenames_recursive if filename != '.file'])
assert set(mmcv.scandir(folder, '.txt', recursive=True)) == set([
filename for filename in filenames_recursive
if filename.endswith('.txt')