Skip to content

Commit

Permalink
Add cudnn switch (PaddlePaddle#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozech authored May 28, 2021
1 parent 3e3e225 commit 4278b3a
Show file tree
Hide file tree
Showing 19 changed files with 586 additions and 58 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ include(${CMAKE_BINARY_DIR}/config.cmake)
if (WITH_CUDA)
message(STATUS "Enable CUDA")
add_definitions(-DCINN_WITH_CUDA)
add_definitions(-DCINN_WITH_CUDNN)
if (WITH_CUDNN)
message(STATUS "Enable CUDNN")
add_definitions(-DCINN_WITH_CUDNN)
endif()
enable_language(CUDA)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
Expand Down
11 changes: 11 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@ build_dir=$workspace/${build_dir_name}
JOBS=8

cuda_config=OFF
cudnn_config=OFF

function gpu_on {
cuda_config=ON
cudnn_config=ON
}

function cudnn_off {
cudnn_config=OFF
}

function check_style {
Expand Down Expand Up @@ -71,6 +77,7 @@ function cmake_ {
echo "set(ISL_HOME /usr/local)" >> $build_dir/config.cmake
# To enable Cuda backend, set(WITH_CUDA ON)
echo "set(WITH_CUDA $cuda_config)" >> $build_dir/config.cmake
echo "set(WITH_CUDNN $cudnn_config)" >> $build_dir/config.cmake
echo "set(WITH_MKL_CBLAS ON)" >> $build_dir/config.cmake
cd $build_dir
cmake .. -DLLVM11_DIR=${LLVM11_DIR} -DLLVM_DIR=${LLVM11_DIR}/lib/cmake/llvm -DMLIR_DIR=${LLVM11_DIR}/lib/cmake/mlir -DPUBLISH_LIBS=ON
Expand Down Expand Up @@ -176,6 +183,10 @@ function main {
gpu_on
shift
;;
cudnn_off)
cudnn_off
shift
;;
check_style)
check_style
shift
Expand Down
119 changes: 117 additions & 2 deletions cinn/backends/codegen_cuda_dev_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
#include "cinn/common/cuda_test_helper.h"
#include "cinn/common/ir_util.h"
#include "cinn/common/test_helper.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/runtime/cpu/use_extern_funcs.h"
#include "cinn/runtime/cuda/cuda_module.h"
#include "cinn/runtime/cuda/cuda_util.h"
#include "cinn/runtime/use_extern_funcs.h"
#include "cinn/utils/timer.h"

namespace cinn {
namespace backends {
Expand Down Expand Up @@ -170,6 +172,119 @@ TEST(CodeGenCUDA2, compile_run_jit2) {
}
}

TEST(CodeGenCUDA2, test_schedule_conv2d_0) {
Expr N(1);
Expr C(128);
Expr H(28);
Expr W(256);

Target target = common::DefaultNVGPUTarget();

Placeholder<float> A("X", {N, C, H, H});
Placeholder<float> B("Y", {W, C, N, N});

auto res = hlir::pe::Conv2d_NCHW(A, B, 0, 0, 2, 2, 1, 1, "COD");

auto stages = CreateStages(res);

auto pad_data = res[0];
auto kernel = res[1];
auto conv = res[2];

stages[pad_data]->ComputeInline();
stages[kernel]->ComputeInline();

auto OL = stages[conv]->CacheWrite2("local", stages, conv);

auto tx = stages[conv]->axis(3);
auto by = stages[conv]->axis(2);
auto [tem, fi] = stages[conv]->Split(1, 2);
auto [bz, tz] = stages[conv]->Split(1, 16);

stages[conv]->Reorder({bz, by, tz, tx, fi});

stages[conv]->Bind(1, "blockIdx.z");
stages[conv]->Bind(2, "blockIdx.y");
stages[conv]->Bind(3, "threadIdx.z");
stages[conv]->Bind(4, "threadIdx.x");

stages[OL]->ComputeAt3(stages[conv], 4);

auto on = stages[OL]->axis(0);
auto obz = stages[OL]->axis(1);
auto oby = stages[OL]->axis(2);
auto otz = stages[OL]->axis(3);
auto otx = stages[OL]->axis(4);
auto ofi = stages[OL]->axis(5);
auto orc = stages[OL]->axis(6);
auto ory = stages[OL]->axis(7);
auto orx = stages[OL]->axis(8);

stages[OL]->Reorder({orc, ory, orx, on, obz, oby, otz, otx, ofi});
stages[OL]->Split(0, 8);

stages[OL]->Bind(5, "blockIdx.z");
stages[OL]->Bind(6, "blockIdx.y");
stages[OL]->Bind(7, "threadIdx.z");
stages[OL]->Bind(8, "threadIdx.x");

CodeGenCUDA_Dev codegen(target);

auto func = Lower("schedule_conv2d_0", stages, {A, B, conv}, {}, {}, nullptr, target);

Module::Builder builder("module", target);
builder.AddFunction(func);

auto source_code = codegen.Compile(builder.Build());

LOG(INFO) << "compiled schedule_conv2d_0 code:\n\n\n" << source_code;

using runtime::cuda::CUDAModule;

backends::NVRTC_Compiler compiler;

auto ptx = compiler(source_code);
CHECK(!ptx.empty());

CUDAModule cuda_module(ptx, CUDAModule::Kind::PTX);

CUDA_CALL(cudaDeviceSynchronize());

CUdeviceptr Ad, Bd, Cd;
cuMemAlloc(&Ad, 128 * 28 * 28 * sizeof(float));
cuMemAlloc(&Bd, 256 * 128 * sizeof(float));
cuMemAlloc(&Cd, 256 * 14 * 14 * sizeof(float));

std::vector<float> host_data1(128 * 28 * 28, 0);
std::vector<float> host_data2(256 * 128, 0);
std::vector<float> host_data3(256 * 14 * 14, 0);
for (float& v : host_data1) v = static_cast<float>(rand()) / INT_MAX; // NOLINT
for (float& v : host_data2) v = static_cast<float>(rand()) / INT_MAX; // NOLINT

CUDA_CALL(cudaMemcpy(
reinterpret_cast<void*>(Ad), host_data1.data(), 128 * 28 * 28 * sizeof(float), cudaMemcpyHostToDevice));
CUDA_CALL(
cudaMemcpy(reinterpret_cast<void*>(Bd), host_data2.data(), 256 * 128 * sizeof(float), cudaMemcpyHostToDevice));

// launch the kernel

void* args[] = {&Ad, &Bd, &Cd};

dim3 grid(1, 14, 8);
dim3 block(14, 1, 16);
int repeat = 100;

utils::Timer time1;
time1.Start();
for (int i = 0; i < repeat; i++) {
cuda_module.LaunchKernel(0, "schedule_conv2d_0", grid, block, args);
}
LOG(INFO) << "Conv2d op with schedule repeats " << repeat
<< " times, average time cost is : " << time1.Stop() / float(repeat) << "ms. ";
CUDA_CALL(cudaMemcpy(
host_data3.data(), reinterpret_cast<void*>(Cd), 256 * 14 * 14 * sizeof(float), cudaMemcpyDeviceToHost));
}

TEST(CodeGenCUDA, compile_run_jit) {
Expr M(100);
Expr N(200);
Expand Down Expand Up @@ -1954,7 +2069,7 @@ void fn5(const float* __restrict__ A, const float* __restrict__ B, float* __rest
TestElementwiseAddPrecisionBasic(
builder.Build(), "fn5", M, N, [](float a, float b) { return std::tanh(a) + std::cos(b); });
}

#ifdef CINN_WITH_CUDNN
TEST(Cudnn, external_function_cudnn) {
Context::Global().ResetNameId();

Expand Down Expand Up @@ -2017,6 +2132,6 @@ TEST(Cudnn, external_function_cudnn3) {

runtime::cuda::cinn_gpu_cudnn_softmax({2, 1000, -1}, dev_bufs[0], dev_bufs[1]);
}

#endif
} // namespace backends
} // namespace cinn
13 changes: 13 additions & 0 deletions cinn/frontend/paddle_model_to_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,19 @@ void PaddleModelToProgram::TransposeVar(const std::string& name) {
} else if (target_.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
// To use cublas mul api, there is no need to transpose data.
#ifndef CINN_WITH_CUDNN
std::vector<float> data(tensor->shape().numel());
CUDA_CALL(cudaMemcpy(data.data(),
reinterpret_cast<void*>(tensor->mutable_data<float>(target_)),
tensor->shape().numel() * sizeof(float),
cudaMemcpyDeviceToHost));
CHECK(tensor->shape().size() == 2) << "The y data's shape size of op [mul] is not equal to 2! Please check.";
TransposeData(data.data(), tensor->shape().data()[0], tensor->shape().data()[1]);
CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(tensor->mutable_data<float>(target_)),
data.data(),
tensor->shape().numel() * sizeof(float),
cudaMemcpyHostToDevice));
#endif
#else
LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
#endif
Expand Down
4 changes: 4 additions & 0 deletions cinn/hlir/framework/cuda_graph_compiler_test.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ std::vector<float> test_mul(const std::vector<float>& A, const std::vector<float
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
for (int k = 0; k < K; k++) {
#ifdef CINN_WITH_CUDNN
C_target[i * N + j] += A[i * K + k] * B[k * N + j];
#else
C_target[i * N + j] += A[i * K + k] * B[j * N + k];
#endif
}
}
}
Expand Down
16 changes: 9 additions & 7 deletions cinn/hlir/op/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,15 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
stages[weights_dilation.as_tensor_ref()]->ComputeInline();
}
if (target.arch == Target::Arch::NVGPU) {
Expr Out = arg_pack[2];
Expr Out = arg_pack[2];
Expr input_pad = arg_pack[0];
Expr weights_dilation = arg_pack[1];
ir::Tensor out_t = Out.as_tensor_ref();
ir::Tensor input_t = input_pad.as_tensor_ref();
ir::Tensor weights_t = weights_dilation.as_tensor_ref();
CHECK(Out.as_tensor());
stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
stages[Out.as_tensor_ref()]->Bind(1, "blockIdx.y");
stages[Out.as_tensor_ref()]->Bind(2, "blockIdx.z");
stages[Out.as_tensor_ref()]->Bind(3, "threadIdx.x");
pe::CudaScheduleConv(stages, input_t, weights_t, out_t, target);
arg_pack[2] = Expr(out_t);
}
if (arg_pack.size() == 4UL) {
*ret = CINNValuePack{{arg_pack[arg_pack.size() - 2], CINNValue(stages)}};
Expand Down Expand Up @@ -1387,8 +1390,7 @@ CINN_REGISTER_HELPER(nn_ops) {
#ifdef CINN_WITH_CUDNN
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
#else
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern",
cinn::hlir::framework::OpPatternKind::kOutEWiseFusable)
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
#endif
.set_support_level(4);

Expand Down
91 changes: 82 additions & 9 deletions cinn/hlir/pe/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,40 @@
#include <numeric>
#include <utility>

#include "cinn/optim/ir_simplify.h"
#include "cinn/poly/isl_utils.h"

namespace cinn {
namespace hlir {
namespace pe {

int GetInnerSplitter(int origin, int other_axis) {
int two_exp = 1;
while (origin % two_exp == 0) {
two_exp *= 2;
}
two_exp = two_exp / 2;
int a = SplitEven(two_exp);
int b = two_exp / a;
while (a * other_axis >= 1024 || b * other_axis >= 1024) {
two_exp = two_exp / 2;
a = SplitEven(two_exp);
b = two_exp / a;
}
if (origin == two_exp) {
return 2;
}
return origin / two_exp;
}

int SplitEven(int origin) {
int res = 1;
while (origin % res == 0 && res * res < origin) {
res *= 2;
}
res = res / 2;
return res;
}

int GetBasicFactor(const Type &type, const common::Target &target) {
int target_native_vector_bits = target.get_target_bits() * 8;
int type_bits = type.bits();
Expand Down Expand Up @@ -225,15 +253,60 @@ void MulScheduleCPU(poly::StageMap stages,
}

void CudaScheduleConv(poly::StageMap stages,
ir::Tensor input_pad,
ir::Tensor kernel_dilation,
ir::Tensor output,
ir::Tensor &input_pad,
ir::Tensor &kernel_dilation,
ir::Tensor &output,
const common::Target &target) {
int num_thread = target.max_num_threads();
stages[output]->Fuse(0, 1);
auto [Block_x, Thread_x] = stages[output]->Split(0, num_thread);
stages[output]->Bind(0, "blockIdx.x");
stages[output]->Bind(1, "threadIdx.x");
int n = output->shape[0].as_int32();
int c = output->shape[1].as_int32();
optim::Simplify(&(output->shape[2]));
int h = output->shape[2].as_int32();
optim::Simplify(&(output->shape[3]));
int w = output->shape[3].as_int32();
int rc = kernel_dilation->shape[1].as_int32();
int ry = kernel_dilation->shape[2].as_int32();
int rx = kernel_dilation->shape[3].as_int32();

int f_inner = GetInnerSplitter(c, h);
int block_z = SplitEven(c / f_inner);
int thread_z = c / f_inner / block_z;

int rc_factor = SplitEven(rc);

auto OL = stages[output]->CacheWrite2("local", stages, output);

auto tx = stages[output]->axis(3);
auto by = stages[output]->axis(2);
auto [tem, fi] = stages[output]->Split(1, f_inner);
auto [bz, tz] = stages[output]->Split(1, thread_z);
stages[output]->Reorder({bz, by, tz, tx, fi});
stages[output]->Bind(1, "blockIdx.z");
stages[output]->Bind(2, "blockIdx.y");
stages[output]->Bind(3, "threadIdx.z");
stages[output]->Bind(4, "threadIdx.x");
stages[OL]->ComputeAt3(stages[output], 4);
auto on = stages[OL]->axis(0);
auto obz = stages[OL]->axis(1);
auto oby = stages[OL]->axis(2);
auto otz = stages[OL]->axis(3);
auto otx = stages[OL]->axis(4);
auto ofi = stages[OL]->axis(5);
auto orc = stages[OL]->axis(6);
auto ory = stages[OL]->axis(7);
auto orx = stages[OL]->axis(8);
stages[OL]->Reorder({orc, ory, orx, on, obz, oby, otz, otx, ofi});
if (rc_factor > 1) {
stages[OL]->Split(0, rc_factor);
stages[OL]->Bind(5, "blockIdx.z");
stages[OL]->Bind(6, "blockIdx.y");
stages[OL]->Bind(7, "threadIdx.z");
stages[OL]->Bind(8, "threadIdx.x");
} else {
stages[OL]->Bind(4, "blockIdx.z");
stages[OL]->Bind(5, "blockIdx.y");
stages[OL]->Bind(6, "threadIdx.z");
stages[OL]->Bind(7, "threadIdx.x");
}

return;
}
Expand Down
9 changes: 6 additions & 3 deletions cinn/hlir/pe/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
namespace cinn {
namespace hlir {
namespace pe {
int GetInnerSplitter(int origin, int other_axis);

int SplitEven(int origin);

int GetBasicFactor(const Type &type, const common::Target &target);

Expand All @@ -35,9 +38,9 @@ void CudaScheduleMul(poly::StageMap stages,
const common::Target &target);

void CudaScheduleConv(poly::StageMap stages,
ir::Tensor input_pad,
ir::Tensor kernel_dilation,
ir::Tensor output,
ir::Tensor &input_pad,
ir::Tensor &kernel_dilation,
ir::Tensor &output,
const common::Target &target);

void CudaScheduleInjective(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target);
Expand Down
Loading

0 comments on commit 4278b3a

Please sign in to comment.