[Feature] Porting mmcv for hip (#1022)
* porting mmcv for hip

* add nvcc

* fix format

* fix format

* fix bug for carafe

* fix test_utils because ROCm PyTorch does not allow setting torch.backends.cudnn.benchmark to false

* add LOOSEVERSION

* fix format

* fix format of version

* fix code format

* test for yaml

* fix bug for citest

* fix bug in how torch.__version__ is obtained in setup.py
XuanBaby authored Jul 9, 2021
1 parent db580dd commit 2dc0a21
Showing 5 changed files with 79 additions and 8 deletions.
23 changes: 20 additions & 3 deletions mmcv/ops/csrc/carafe_cuda_kernel.cuh
@@ -7,7 +7,11 @@
#include "pytorch_cuda_helper.hpp"
#endif

#ifdef HIP_DIFF
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#define THREADS_PER_PIXEL 32
#define MAX_SHARED_MEMORY 49152
#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
@@ -24,6 +28,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
int index = w + (h + (c + n * channel_num) * height) * width;
return index;
}
#ifndef HIP_DIFF
/* TODO: move this to a common place */
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
@@ -34,19 +39,27 @@ template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
return a > b ? a : b;
}

#endif
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
for (int offset = 16; offset > 0; offset /= 2)
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef HIP_DIFF
val += __shfl_down(val, offset);
#else
val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
return val;
}

template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) {
for (int offset = 16; offset > 0; offset /= 2)
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef HIP_DIFF
__PHALF(val) += __shfl_down(FULL_MASK, val, offset);
#else
__PHALF(val) +=
__shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
#endif
return val;
}

@@ -302,7 +315,11 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
output_val += top_diff[top_id] * bottom_data[bottom_id];
}
}
#ifdef HIP_DIFF
__syncthreads();
#else
__syncwarp();
#endif
output_val = warpReduceSum(output_val);
if (lane_id == 0) {
const int mask_id =
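The hunk above widens the warp-shuffle reduction for AMD GPUs: ROCm wavefronts on GCN/CDNA hardware are 64 lanes wide while NVIDIA warps are 32, so the hard-coded `offset = 16` start (the pre-change loop line) becomes `WARP_SIZE / 2`, and the mask-free `__shfl_down` intrinsic is used because HIP has no `__shfl_down_sync`. A minimal Python sketch of the same device distinction, not part of this commit, with a hypothetical helper name:

```python
# Hypothetical helper (not in mmcv): choose the warp/wavefront width that the
# WARP_SIZE macro in carafe_cuda_kernel.cuh encodes at compile time.
import torch


def get_warp_size() -> int:
    # ROCm builds of PyTorch report a HIP version; GCN/CDNA wavefronts have
    # 64 lanes, NVIDIA warps have 32.
    return 64 if torch.version.hip is not None else 32


print(get_warp_size())
```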
6 changes: 6 additions & 0 deletions mmcv/ops/csrc/pytorch/info.cpp
@@ -3,12 +3,15 @@
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
#endif

std::string get_compiling_cuda_version() {
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -20,6 +23,9 @@ std::string get_compiling_cuda_version() {
};
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("rocm not available");
#endif
#else
return std::string("not available");
#endif
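With `HIP_DIFF` defined, `cuda_runtime_api.h` and `CUDART_VERSION` are unavailable, so `get_compiling_cuda_version()` returns a fixed string instead of formatting the CUDA runtime version. A hedged usage sketch, assuming the function is exported through `mmcv.ops` as in mainline mmcv:

```python
# Usage sketch (binding point assumed): a ROCm/HIP build returns the fixed
# string above, while a CUDA build formats CUDART_VERSION, e.g. '10.2'.
from mmcv.ops import get_compiling_cuda_version

print(get_compiling_cuda_version())
```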
13 changes: 12 additions & 1 deletion mmcv/utils/parrots_wrapper.py
@@ -1,15 +1,26 @@
from functools import partial
from pkg_resources import parse_version

import torch

TORCH_VERSION = torch.__version__

is_rocm_pytorch = False
if parse_version(TORCH_VERSION) >= parse_version('1.5'):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False


def _get_cuda_home():
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import CUDA_HOME
else:
from torch.utils.cpp_extension import CUDA_HOME
if is_rocm_pytorch:
from torch.utils.cpp_extension import ROCM_HOME
CUDA_HOME = ROCM_HOME
else:
from torch.utils.cpp_extension import CUDA_HOME
return CUDA_HOME


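parrots_wrapper now detects ROCm builds of PyTorch and hands back `ROCM_HOME` wherever callers expect `CUDA_HOME`, so downstream build code keeps working unchanged. A minimal standalone sketch of the same detection, written as a plain boolean expression (equivalent to the `True if ... else False` form in the diff):

```python
# Standalone sketch of the ROCm detection above; behaviour matches the
# committed `True if ... else False` expression.
import torch
from pkg_resources import parse_version

is_rocm_pytorch = False
if parse_version(torch.__version__) >= parse_version('1.5'):
    from torch.utils.cpp_extension import ROCM_HOME
    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

print(is_rocm_pytorch)
```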
32 changes: 29 additions & 3 deletions setup.py
@@ -1,7 +1,7 @@
import glob
import os
import re
from pkg_resources import DistributionNotFound, get_distribution
from pkg_resources import DistributionNotFound, get_distribution, parse_version
from setuptools import find_packages, setup

EXT_TYPE = ''
@@ -220,18 +220,44 @@ def get_extensions():
define_macros = []
extra_compile_args = {'cxx': []}

if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
is_rocm_pytorch = False
if parse_version(torch.__version__) >= parse_version('1.5'):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False

this_dir = 'mmcv/ops/csrc/'
if is_rocm_pytorch:
from torch.utils.hipify import hipify_python

hipify_python.hipify(
project_directory=this_dir,
output_directory=this_dir,
includes='mmcv/ops/csrc/*',
show_detailed=True,
is_pytorch_extension=True,
)
define_macros += [('MMCV_WITH_CUDA', None)]
define_macros += [('HIP_DIFF', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/hip/*')
extension = CUDAExtension
include_path = os.path.abspath('./mmcv/ops/csrc/hip')

elif torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('MMCV_WITH_CUDA', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
extension = CUDAExtension
include_path = os.path.abspath('./mmcv/ops/csrc')
else:
print(f'Compiling {ext_name} without CUDA')
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp')
extension = CppExtension
include_path = os.path.abspath('./mmcv/ops/csrc')

include_path = os.path.abspath('./mmcv/ops/csrc')
ext_ops = extension(
name=ext_name,
sources=op_files,
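On a ROCm machine, the new branch first runs PyTorch's hipify tool over the CUDA sources, then compiles the generated files under `mmcv/ops/csrc/pytorch/hip/` with the usual `CUDAExtension`, which PyTorch's cpp_extension machinery drives with the HIP toolchain when `ROCM_HOME` is set. Below is a hedged sketch of roughly what that branch constructs; the hipified source filename is hypothetical and the extension name is assumed to be `mmcv._ext`, as elsewhere in setup.py:

```python
# Rough sketch (not the committed code) of the extension built by the ROCm
# branch above; the source filename is hypothetical.
import os

from torch.utils.cpp_extension import CUDAExtension

ext_ops = CUDAExtension(
    name='mmcv._ext',  # assumed extension name used elsewhere in setup.py
    sources=['./mmcv/ops/csrc/pytorch/hip/carafe_hip.cpp'],  # hypothetical hipified file
    include_dirs=[os.path.abspath('./mmcv/ops/csrc/hip')],
    define_macros=[('MMCV_WITH_CUDA', None), ('HIP_DIFF', None)],
    extra_compile_args={'cxx': [], 'nvcc': []})
```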
13 changes: 12 additions & 1 deletion tests/test_runner/test_utils.py
@@ -1,10 +1,18 @@
import os
import random
from pkg_resources import parse_version

import numpy as np
import torch

from mmcv.runner import set_random_seed
from mmcv.utils import TORCH_VERSION

is_rocm_pytorch = False
if parse_version(TORCH_VERSION) >= parse_version('1.5'):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False


def test_set_random_seed():
@@ -21,7 +29,10 @@ def test_set_random_seed():
b_np_random = np.random.rand(2, 2)
b_torch_random = torch.rand(2, 2)
assert torch.backends.cudnn.deterministic is True
assert torch.backends.cudnn.benchmark is False
if is_rocm_pytorch:
assert torch.backends.cudnn.benchmark is True
else:
assert torch.backends.cudnn.benchmark is False

assert a_random == b_random
assert np.equal(a_np_random, b_np_random).all()
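Per the commit message, ROCm builds of PyTorch do not keep `torch.backends.cudnn.benchmark` at False even when `set_random_seed(..., deterministic=True)` requests it, so the assertion now branches on the build type. A small sketch of the same conditional check outside the test:

```python
# Sketch mirroring the conditional assertion above: the expected flag value
# after a deterministic seed depends on whether this is a ROCm build.
import torch

from mmcv.runner import set_random_seed

set_random_seed(0, deterministic=True)
expected_benchmark = torch.version.hip is not None  # True on ROCm, else False
assert torch.backends.cudnn.benchmark == expected_benchmark
```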
