Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Introduce the ENABLE_CUDA_RTC build option #9428

Merged
merged 1 commit into from
Jan 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ mxnet_option(USE_CPP_PACKAGE "Build C++ Package" OFF)
mxnet_option(USE_MXNET_LIB_NAMING "Use MXNet library naming conventions." ON)
mxnet_option(USE_GPROF "Compile with gprof (profiling) flag" OFF)
mxnet_option(USE_VTUNE "Enable use of Intel Amplifier XE (VTune)" OFF) # one could set VTUNE_ROOT for search path
mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support" ON)
mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF)
mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." OFF)

Expand Down Expand Up @@ -452,24 +453,35 @@ if(USE_CUDA)
string(REPLACE ";" " " NVCC_FLAGS_ARCH "${NVCC_FLAGS_ARCH}")
set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS_ARCH}")
set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_ARCH} -use_fast_math")
list(APPEND mxnet_LINKER_LIBS nvrtc cuda cublas cufft cusolver curand)
list(APPEND mxnet_LINKER_LIBS cublas cufft cusolver curand)
if(ENABLE_CUDA_RTC)
list(APPEND mxnet_LINKER_LIBS nvrtc cuda)
add_definitions(-DMXNET_ENABLE_CUDA_RTC=1)
endif()
list(APPEND SOURCE ${CUDA})
add_definitions(-DMXNET_USE_CUDA=1)
else()
list(APPEND CUDA_INCLUDE_DIRS ${INCLUDE_DIRECTORIES})
# define preprocessor macro so that we will not include the generated forcelink header
mshadow_cuda_compile(cuda_objs ${CUDA})
if(MSVC)
FIND_LIBRARY(CUDA_nvrtc_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64" "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32")
list(APPEND mxnet_LINKER_LIBS ${CUDA_nvrtc_LIBRARY})
set(CUDA_cuda_LIBRARY "${CUDA_nvrtc_LIBRARY}/../cuda.lib")
list(APPEND mxnet_LINKER_LIBS ${CUDA_cuda_LIBRARY})
if(ENABLE_CUDA_RTC)
FIND_LIBRARY(CUDA_nvrtc_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64" "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32")
list(APPEND mxnet_LINKER_LIBS ${CUDA_nvrtc_LIBRARY})
set(CUDA_cuda_LIBRARY "${CUDA_nvrtc_LIBRARY}/../cuda.lib")
list(APPEND mxnet_LINKER_LIBS ${CUDA_cuda_LIBRARY})
add_definitions(-DMXNET_ENABLE_CUDA_RTC=1)
endif()
FIND_LIBRARY(CUDA_cufft_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64" "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32")
list(APPEND mxnet_LINKER_LIBS "${CUDA_cufft_LIBRARY}/../cufft.lib") # For fft operator
FIND_LIBRARY(CUDA_cusolver_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64" "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32")
list(APPEND mxnet_LINKER_LIBS "${CUDA_cusolver_LIBRARY}/../cusolver.lib") # For cusolver
else(MSVC)
list(APPEND mxnet_LINKER_LIBS nvrtc cuda cufft cusolver)
list(APPEND mxnet_LINKER_LIBS cufft cusolver)
if(ENABLE_CUDA_RTC)
list(APPEND mxnet_LINKER_LIBS nvrtc cuda)
add_definitions(-DMXNET_ENABLE_CUDA_RTC=1)
endif()
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
endif()
list(APPEND SOURCE ${cuda_objs} ${CUDA})
Expand Down
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,11 @@ ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(PLUGIN_OBJ) $(LIB_DEP)
ifeq ($(USE_CUDA), 1)
CFLAGS += -I$(ROOTDIR)/3rdparty/cub
ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ) $(PLUGIN_CUOBJ)
LDFLAGS += -lcuda -lcufft -lnvrtc
LDFLAGS += -lcufft
ifeq ($(ENABLE_CUDA_RTC), 1)
LDFLAGS += -lcuda -lnvrtc
CFLAGS += -DMXNET_ENABLE_CUDA_RTC=1
endif
# Make sure to add stubs as fallback in order to be able to build
# without full CUDA install (especially if run without nvidia-docker)
LDFLAGS += -L/usr/local/cuda/lib64/stubs
Expand Down
4 changes: 2 additions & 2 deletions include/mxnet/rtc.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#ifndef MXNET_RTC_H_
#define MXNET_RTC_H_
#include "./base.h"
#if MXNET_USE_CUDA
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
#include <nvrtc.h>
#include <cuda.h>

Expand Down Expand Up @@ -132,5 +132,5 @@ class CudaModule {
} // namespace rtc
} // namespace mxnet

#endif // MXNET_USE_CUDA
#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
#endif // MXNET_RTC_H_
3 changes: 3 additions & 0 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ USE_CUDA = 0
# USE_CUDA_PATH = /usr/local/cuda
USE_CUDA_PATH = NONE

# whether to enable CUDA runtime compilation
ENABLE_CUDA_RTC = 1

# whether use CuDNN R3 library
USE_CUDNN = 0

Expand Down
3 changes: 3 additions & 0 deletions make/osx.mk
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ USE_CUDA = 0
# USE_CUDA_PATH = /usr/local/cuda
USE_CUDA_PATH = NONE

# whether to enable CUDA runtime compilation
ENABLE_CUDA_RTC = 1

# whether use CUDNN R3 library
USE_CUDNN = 0

Expand Down
20 changes: 10 additions & 10 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1179,24 +1179,24 @@ int MXRtcCudaModuleCreate(const char* source, int num_options,
const char** options, int num_exports,
const char** exports, CudaModuleHandle *out) {
API_BEGIN();
#if MXNET_USE_CUDA
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
std::vector<std::string> str_opts;
for (int i = 0; i < num_options; ++i) str_opts.emplace_back(options[i]);
std::vector<std::string> str_exports;
for (int i = 0; i < num_exports; ++i) str_exports.emplace_back(exports[i]);
*out = new rtc::CudaModule(source, str_opts, str_exports);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}

int MXRtcCudaModuleFree(CudaModuleHandle handle) {
API_BEGIN();
#if MXNET_USE_CUDA
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
delete reinterpret_cast<rtc::CudaModule*>(handle);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}
Expand All @@ -1205,7 +1205,7 @@ int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, int num_arg
int* is_ndarray, int* is_const, int* arg_types,
CudaKernelHandle *out) {
API_BEGIN();
#if MXNET_USE_CUDA
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
auto module = reinterpret_cast<rtc::CudaModule*>(handle);
std::vector<rtc::CudaModule::ArgType> signature;
for (int i = 0; i < num_args; ++i) {
Expand All @@ -1216,17 +1216,17 @@ int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, int num_arg
auto kernel = module->GetKernel(name, signature);
*out = new std::shared_ptr<rtc::CudaModule::Kernel>(kernel);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}

int MXRtcCudaKernelFree(CudaKernelHandle handle) {
API_BEGIN();
#if MXNET_USE_CUDA
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
delete reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}
Expand All @@ -1237,7 +1237,7 @@ int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
mx_uint block_dim_y, mx_uint block_dim_z,
mx_uint shared_mem) {
API_BEGIN();
#if MXNET_USE_CUDA
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
auto kernel = reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
const auto& signature = (*kernel)->signature();
std::vector<dmlc::any> any_args;
Expand All @@ -1253,7 +1253,7 @@ int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
(*kernel)->Launch(Context::GPU(dev_id), any_args, grid_dim_x, grid_dim_y,
grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}
Expand Down
4 changes: 2 additions & 2 deletions src/common/rtc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "../common/cuda_utils.h"
#include "../operator/operator_common.h"

#if MXNET_USE_CUDA
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC

namespace mxnet {
namespace rtc {
Expand Down Expand Up @@ -185,4 +185,4 @@ void CudaModule::Kernel::Launch(
} // namespace rtc
} // namespace mxnet

#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC