diff --git a/mmcv/ops/csrc/carafe_cuda_kernel.cuh b/mmcv/ops/csrc/carafe_cuda_kernel.cuh
index e9b569d3b5..4bf11694f3 100644
--- a/mmcv/ops/csrc/carafe_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/carafe_cuda_kernel.cuh
@@ -7,7 +7,11 @@
 #include "pytorch_cuda_helper.hpp"
 #endif
 
+#ifdef HIP_DIFF
+#define WARP_SIZE 64
+#else
 #define WARP_SIZE 32
+#endif
 #define THREADS_PER_PIXEL 32
 #define MAX_SHARED_MEMORY 49152
 #define MAX_SHARED_SCALAR_T 6144  // 49152 / 8 = 6144
@@ -24,6 +28,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
   int index = w + (h + (c + n * channel_num) * height) * width;
   return index;
 }
+#ifndef HIP_DIFF
 /* TODO: move this to a common place */
 template <typename scalar_t>
 __device__ inline scalar_t min(scalar_t a, scalar_t b) {
@@ -34,19 +39,27 @@ template <typename scalar_t>
 __device__ inline scalar_t max(scalar_t a, scalar_t b) {
   return a > b ? a : b;
 }
-
+#endif
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
-  for (int offset = 16; offset > 0; offset /= 2)
+  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
+#ifdef HIP_DIFF
+    val += __shfl_down(val, offset);
+#else
     val += __shfl_down_sync(FULL_MASK, val, offset);
+#endif
   return val;
 }
 
 template <>
 __device__ __forceinline__ phalf warpReduceSum(phalf val) {
-  for (int offset = 16; offset > 0; offset /= 2)
+  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
+#ifdef HIP_DIFF
+    __PHALF(val) += __shfl_down(static_cast<__half>(__PHALF(val)), offset);
+#else
     __PHALF(val) +=
         __shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
+#endif
   return val;
 }
 
@@ -302,7 +315,11 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
         output_val += top_diff[top_id] * bottom_data[bottom_id];
       }
     }
+#ifdef HIP_DIFF
+    __syncthreads();
+#else
     __syncwarp();
+#endif
     output_val = warpReduceSum(output_val);
     if (lane_id == 0) {
       const int mask_id =
diff --git a/mmcv/ops/csrc/pytorch/info.cpp b/mmcv/ops/csrc/pytorch/info.cpp
index a2ebafa843..fd01c2e371 100644
--- a/mmcv/ops/csrc/pytorch/info.cpp
+++ b/mmcv/ops/csrc/pytorch/info.cpp
@@ -3,12 +3,15 @@
 #include "pytorch_cpp_helper.hpp"
 
 #ifdef MMCV_WITH_CUDA
+#ifndef HIP_DIFF
 #include <cuda_runtime_api.h>
 int get_cudart_version() { return CUDART_VERSION; }
 #endif
+#endif
 
 std::string get_compiling_cuda_version() {
 #ifdef MMCV_WITH_CUDA
+#ifndef HIP_DIFF
   std::ostringstream oss;
   // copied from
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -20,6 +23,9 @@ std::string get_compiling_cuda_version() {
   };
   printCudaStyleVersion(get_cudart_version());
   return oss.str();
+#else
+  return std::string("rocm not available");
+#endif
 #else
   return std::string("not available");
 #endif
diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py
index ccc22b09e1..4cb4be98f2 100644
--- a/mmcv/utils/parrots_wrapper.py
+++ b/mmcv/utils/parrots_wrapper.py
@@ -1,15 +1,26 @@
 from functools import partial
 
+from pkg_resources import parse_version
+
 import torch
 
 TORCH_VERSION = torch.__version__
 
+is_rocm_pytorch = False
+if parse_version(TORCH_VERSION) >= parse_version('1.5'):
+    from torch.utils.cpp_extension import ROCM_HOME
+    is_rocm_pytorch = True if ((torch.version.hip is not None) and
+                               (ROCM_HOME is not None)) else False
+
 
 def _get_cuda_home():
     if TORCH_VERSION == 'parrots':
         from parrots.utils.build_extension import CUDA_HOME
     else:
-        from torch.utils.cpp_extension import CUDA_HOME
+        if is_rocm_pytorch:
+            from torch.utils.cpp_extension import ROCM_HOME
+            CUDA_HOME = ROCM_HOME
+        else:
+            from torch.utils.cpp_extension import CUDA_HOME
     return CUDA_HOME
diff --git a/setup.py b/setup.py
index 2b1ae6dc18..4122c814d2 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
 import glob
 import os
 import re
-from pkg_resources import DistributionNotFound, get_distribution
+from pkg_resources import DistributionNotFound, get_distribution, parse_version
 from setuptools import find_packages, setup
 
 EXT_TYPE = ''
@@ -220,18 +220,44 @@ def get_extensions():
     define_macros = []
     extra_compile_args = {'cxx': []}
 
-    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
+    is_rocm_pytorch = False
+    if parse_version(torch.__version__) >= parse_version('1.5'):
+        from torch.utils.cpp_extension import ROCM_HOME
+        is_rocm_pytorch = True if ((torch.version.hip is not None) and
+                                   (ROCM_HOME is not None)) else False
+
+    this_dir = 'mmcv/ops/csrc/'
+    if is_rocm_pytorch:
+        from torch.utils.hipify import hipify_python
+
+        hipify_python.hipify(
+            project_directory=this_dir,
+            output_directory=this_dir,
+            includes='mmcv/ops/csrc/*',
+            show_detailed=True,
+            is_pytorch_extension=True,
+        )
+        define_macros += [('MMCV_WITH_CUDA', None)]
+        define_macros += [('HIP_DIFF', None)]
+        cuda_args = os.getenv('MMCV_CUDA_ARGS')
+        extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
+        op_files = glob.glob('./mmcv/ops/csrc/pytorch/hip/*')
+        extension = CUDAExtension
+        include_path = os.path.abspath('./mmcv/ops/csrc/hip')
+
+    elif torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
         define_macros += [('MMCV_WITH_CUDA', None)]
         cuda_args = os.getenv('MMCV_CUDA_ARGS')
         extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
         op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
         extension = CUDAExtension
+        include_path = os.path.abspath('./mmcv/ops/csrc')
     else:
         print(f'Compiling {ext_name} without CUDA')
         op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp')
         extension = CppExtension
+        include_path = os.path.abspath('./mmcv/ops/csrc')
 
-    include_path = os.path.abspath('./mmcv/ops/csrc')
     ext_ops = extension(
         name=ext_name,
         sources=op_files,
diff --git a/tests/test_runner/test_utils.py b/tests/test_runner/test_utils.py
index 3983e80cd7..88e0629c28 100644
--- a/tests/test_runner/test_utils.py
+++ b/tests/test_runner/test_utils.py
@@ -1,10 +1,18 @@
 import os
 import random
 
+from pkg_resources import parse_version
+
 import numpy as np
 import torch
 
 from mmcv.runner import set_random_seed
+from mmcv.utils import TORCH_VERSION
+
+is_rocm_pytorch = False
+if parse_version(TORCH_VERSION) >= parse_version('1.5'):
+    from torch.utils.cpp_extension import ROCM_HOME
+    is_rocm_pytorch = True if ((torch.version.hip is not None) and
+                               (ROCM_HOME is not None)) else False
 
 
 def test_set_random_seed():
@@ -21,7 +29,10 @@ def test_set_random_seed():
     b_np_random = np.random.rand(2, 2)
     b_torch_random = torch.rand(2, 2)
     assert torch.backends.cudnn.deterministic is True
-    assert torch.backends.cudnn.benchmark is False
+    if is_rocm_pytorch:
+        assert torch.backends.cudnn.benchmark is True
+    else:
+        assert torch.backends.cudnn.benchmark is False
     assert a_random == b_random
     assert np.equal(a_np_random, b_np_random).all()
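
Note: the ROCm/CUDA selection repeated in setup.py, mmcv/utils/parrots_wrapper.py
and tests/test_runner/test_utils.py above reduces to the small check sketched
below. This is an illustrative standalone sketch, not part of the patch; it
assumes PyTorch >= 1.5, where torch.utils.cpp_extension exposes ROCM_HOME and
torch.version.hip is non-None on ROCm builds.

    # Hypothetical condensed sketch of the detection logic used in this patch.
    from pkg_resources import parse_version

    import torch

    is_rocm_pytorch = False
    if parse_version(torch.__version__) >= parse_version('1.5'):
        from torch.utils.cpp_extension import ROCM_HOME
        # HIP runtime present and ROCm toolchain installed -> hipify and build
        # with the HIP_DIFF code paths; otherwise fall back to CUDA/CPU.
        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    print('ROCm (HIP) build' if is_rocm_pytorch else 'CUDA/CPU build')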