
[Feature] Porting mmcv for hip #1022

Merged (14 commits) on Jul 9, 2021
19 changes: 18 additions & 1 deletion mmcv/ops/csrc/carafe_cuda_kernel.cuh
@@ -7,7 +7,11 @@
#include "pytorch_cuda_helper.hpp"
#endif

#ifdef HIP_DIFF
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#define THREADS_PER_PIXEL 32
#define MAX_SHARED_MEMORY 49152
#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
@@ -24,6 +28,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
int index = w + (h + (c + n * channel_num) * height) * width;
return index;
}
#ifndef HIP_DIFF
/* TODO: move this to a common place */
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
@@ -34,19 +39,27 @@ template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
return a > b ? a : b;
}

#endif
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
for (int offset = 16; offset > 0; offset /= 2)
#ifdef HIP_DIFF
val += __shfl_down(FULL_MASK, val, offset);
#else
val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
return val;
}

template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) {
for (int offset = 16; offset > 0; offset /= 2)
#ifdef HIP_DIFF
__PHALF(val) += __shfl_down(FULL_MASK, val, offset);
#else
__PHALF(val) +=
__shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
#endif
return val;
}

@@ -302,7 +315,11 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
output_val += top_diff[top_id] * bottom_data[bottom_id];
}
}
#ifdef HIP_DIFF
__syncthreads();
#else
__syncwarp();
#endif
output_val = warpReduceSum(output_val);
if (lane_id == 0) {
const int mask_id =
6 changes: 6 additions & 0 deletions mmcv/ops/csrc/pytorch/info.cpp
@@ -3,12 +3,15 @@
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
#endif

std::string get_compiling_cuda_version() {
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -20,6 +23,9 @@ std::string get_compiling_cuda_version() {
};
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("rocm not vailable");
#endif
#else
return std::string("not available");
#endif
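With this change, a ROCm build of the extension reports a placeholder string rather than a CUDA runtime version. A minimal runtime check (a sketch, assuming an mmcv wheel built from this branch is installed; get_compiling_cuda_version and get_compiler_version are the existing helpers exported by mmcv.ops):

# Sketch: query which backend the compiled extension reports (assumes an mmcv
# build from this branch is installed).
import torch
from mmcv.ops import get_compiler_version, get_compiling_cuda_version

# torch.version.hip is None on CUDA builds and a version string on ROCm builds.
print('torch.version.hip :', torch.version.hip)
print('torch.version.cuda:', torch.version.cuda)

# On a ROCm build this returns the placeholder string from info.cpp above;
# on a CUDA build it returns the CUDA runtime version used at compile time.
print('mmcv compiled with:', get_compiling_cuda_version())
print('compiler          :', get_compiler_version())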
12 changes: 11 additions & 1 deletion mmcv/utils/parrots_wrapper.py
@@ -4,12 +4,22 @@

TORCH_VERSION = torch.__version__

is_rocm_pytorch = False
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False


def _get_cuda_home():
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import CUDA_HOME
else:
from torch.utils.cpp_extension import CUDA_HOME
if is_rocm_pytorch:
from torch.utils.cpp_extension import ROCM_HOME
CUDA_HOME = ROCM_HOME
else:
from torch.utils.cpp_extension import CUDA_HOME
return CUDA_HOME


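The detection added above can be exercised on its own; the following is an illustrative sketch (not part of the diff) of the same pattern, assuming torch >= 1.5 so that ROCM_HOME is importable from torch.utils.cpp_extension:

# Sketch of the ROCm-PyTorch detection pattern used above (illustrative only).
import torch

is_rocm_pytorch = False
if torch.__version__ >= '1.5':  # same version gate as in the diff
    from torch.utils.cpp_extension import ROCM_HOME
    # A ROCm build of PyTorch sets torch.version.hip and ships ROCM_HOME.
    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

if is_rocm_pytorch:
    # On ROCm, the HIP/ROCm install root plays the role of CUDA_HOME.
    from torch.utils.cpp_extension import ROCM_HOME as CUDA_HOME
else:
    from torch.utils.cpp_extension import CUDA_HOME

print('is_rocm_pytorch:', is_rocm_pytorch, '| toolchain home:', CUDA_HOME)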
30 changes: 28 additions & 2 deletions setup.py
@@ -220,18 +220,44 @@ def get_extensions():
define_macros = []
extra_compile_args = {'cxx': []}

if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False

this_dir = 'mmcv/ops/csrc/'
if is_rocm_pytorch:
from torch.utils.hipify import hipify_python

hipify_python.hipify(
project_directory=this_dir,
output_directory=this_dir,
includes='mmcv/ops/csrc/*',
show_detailed=True,
is_pytorch_extension=True,
)
define_macros += [('MMCV_WITH_CUDA', None)]
define_macros += [('HIP_DIFF', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/hip/*')
extension = CUDAExtension
include_path = os.path.abspath('./mmcv/ops/csrc/hip')

elif torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('MMCV_WITH_CUDA', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
extension = CUDAExtension
include_path = os.path.abspath('./mmcv/ops/csrc')
else:
print(f'Compiling {ext_name} without CUDA')
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp')
extension = CppExtension
include_path = os.path.abspath('./mmcv/ops/csrc')

include_path = os.path.abspath('./mmcv/ops/csrc')
ext_ops = extension(
name=ext_name,
sources=op_files,
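For reference, the hipify translation step can also be run outside of setup.py. The sketch below mirrors the hipify_python.hipify call added above (paths and argument values are taken from the diff; argument handling may differ across torch versions), assuming a ROCm build of PyTorch:

# Sketch: run the hipify translation manually, mirroring the setup.py change above.
# Assumes a ROCm build of PyTorch (torch.version.hip is not None).
import glob

import torch

if torch.version.hip is not None:
    from torch.utils.hipify import hipify_python

    this_dir = 'mmcv/ops/csrc/'
    # Translate the CUDA sources under mmcv/ops/csrc into HIP sources in place.
    hipify_python.hipify(
        project_directory=this_dir,
        output_directory=this_dir,
        includes='mmcv/ops/csrc/*',
        show_detailed=True,
        is_pytorch_extension=True,
    )
    # The hipified op sources are then picked up from mmcv/ops/csrc/pytorch/hip/.
    print(glob.glob('./mmcv/ops/csrc/pytorch/hip/*'))
else:
    print('Not a ROCm build of PyTorch; nothing to hipify.')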