From 4278b3a4522b4fd2745fcae7c3c0c62135f95958 Mon Sep 17 00:00:00 2001
From: haozech
Date: Fri, 28 May 2021 13:18:26 +0800
Subject: [PATCH] Add cudnn switch (#389)

---
 CMakeLists.txt                                |   5 +-
 build.sh                                      |  11 +
 cinn/backends/codegen_cuda_dev_test.cc        | 119 ++++++++-
 cinn/frontend/paddle_model_to_program.cc      |  13 +
 .../framework/cuda_graph_compiler_test.cc     |   4 +
 cinn/hlir/op/nn.cc                            |  16 +-
 cinn/hlir/pe/schedule.cc                      |  91 ++++++-
 cinn/hlir/pe/schedule.h                       |   9 +-
 cinn/ir/tensor.cc                             |   2 +
 cinn/optim/replace_var_with_expr.cc           |   0
 cinn/poly/isl_utils.cc                        |  44 +++-
 cinn/poly/isl_utils.h                         |  26 +-
 cinn/poly/stage.cc                            |  43 +++-
 cinn/poly/stage.h                             |   6 +-
 cinn/runtime/cuda/cuda_util.cc                |   0
 python/tests/fake_model/resnet_model.py       |   4 +-
 python/tests/test_frontend.py                 |  13 +-
 python/tests/test_op_benchmark.py             | 236 +++++++++++++++++-
 python/tests/test_resnet.py                   |   2 +-
 19 files changed, 586 insertions(+), 58 deletions(-)
 mode change 100755 => 100644 cinn/hlir/framework/cuda_graph_compiler_test.cc
 mode change 100644 => 100755 cinn/optim/replace_var_with_expr.cc
 mode change 100644 => 100755 cinn/poly/stage.cc
 mode change 100644 => 100755 cinn/runtime/cuda/cuda_util.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7ce9f83e903b11..cd3c5010cdae95 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,7 +26,10 @@ include(${CMAKE_BINARY_DIR}/config.cmake)
 if (WITH_CUDA)
   message(STATUS "Enable CUDA")
   add_definitions(-DCINN_WITH_CUDA)
-  add_definitions(-DCINN_WITH_CUDNN)
+  if (WITH_CUDNN)
+    message(STATUS "Enable CUDNN")
+    add_definitions(-DCINN_WITH_CUDNN)
+  endif()
   enable_language(CUDA)
   find_package(CUDA REQUIRED)
   include_directories(${CUDA_INCLUDE_DIRS})
diff --git a/build.sh b/build.sh
index 9c57fa6860d9d9..d2ee52058d5388 100755
--- a/build.sh
+++ b/build.sh
@@ -8,9 +8,15 @@ build_dir=$workspace/${build_dir_name}
 JOBS=8

 cuda_config=OFF
+cudnn_config=OFF

 function gpu_on {
   cuda_config=ON
+  cudnn_config=ON
+}
+
+function cudnn_off {
+  cudnn_config=OFF
 }

 function check_style {
@@ -71,6 +77,7 @@ function cmake_ {
   echo "set(ISL_HOME /usr/local)" >> $build_dir/config.cmake
   # To enable Cuda backend, set(WITH_CUDA ON)
   echo "set(WITH_CUDA $cuda_config)" >> $build_dir/config.cmake
+  echo "set(WITH_CUDNN $cudnn_config)" >> $build_dir/config.cmake
   echo "set(WITH_MKL_CBLAS ON)" >> $build_dir/config.cmake
   cd $build_dir
   cmake .. -DLLVM11_DIR=${LLVM11_DIR} -DLLVM_DIR=${LLVM11_DIR}/lib/cmake/llvm -DMLIR_DIR=${LLVM11_DIR}/lib/cmake/mlir -DPUBLISH_LIBS=ON
@@ -176,6 +183,10 @@ function main {
       gpu_on
       shift
       ;;
+    cudnn_off)
+      cudnn_off
+      shift
+      ;;
     check_style)
       check_style
       shift
diff --git a/cinn/backends/codegen_cuda_dev_test.cc b/cinn/backends/codegen_cuda_dev_test.cc
index d7fca7e0a1b5c6..b087a50b83c0cb 100755
--- a/cinn/backends/codegen_cuda_dev_test.cc
+++ b/cinn/backends/codegen_cuda_dev_test.cc
@@ -17,11 +17,13 @@
 #include "cinn/common/cuda_test_helper.h"
 #include "cinn/common/ir_util.h"
 #include "cinn/common/test_helper.h"
+#include "cinn/hlir/pe/nn.h"
 #include "cinn/ir/ir_printer.h"
 #include "cinn/runtime/cpu/use_extern_funcs.h"
 #include "cinn/runtime/cuda/cuda_module.h"
 #include "cinn/runtime/cuda/cuda_util.h"
 #include "cinn/runtime/use_extern_funcs.h"
+#include "cinn/utils/timer.h"

 namespace cinn {
 namespace backends {
@@ -170,6 +172,119 @@ TEST(CodeGenCUDA2, compile_run_jit2) {
   }
 }

+TEST(CodeGenCUDA2, test_schedule_conv2d_0) {
+  Expr N(1);
+  Expr C(128);
+  Expr H(28);
+  Expr W(256);
+
+  Target target = common::DefaultNVGPUTarget();
+
+  Placeholder<float> A("X", {N, C, H, H});
+  Placeholder<float> B("Y", {W, C, N, N});
+
+  auto res = hlir::pe::Conv2d_NCHW(A, B, 0, 0, 2, 2, 1, 1, "COD");
+
+  auto stages = CreateStages(res);
+
+  auto pad_data = res[0];
+  auto kernel = res[1];
+  auto conv = res[2];
+
+  stages[pad_data]->ComputeInline();
+  stages[kernel]->ComputeInline();
+
+  auto OL = stages[conv]->CacheWrite2("local", stages, conv);
+
+  auto tx = stages[conv]->axis(3);
+  auto by = stages[conv]->axis(2);
+  auto [tem, fi] = stages[conv]->Split(1, 2);
+  auto [bz, tz] = stages[conv]->Split(1, 16);
+
+  stages[conv]->Reorder({bz, by, tz, tx, fi});
+
+  stages[conv]->Bind(1, "blockIdx.z");
+  stages[conv]->Bind(2, "blockIdx.y");
+  stages[conv]->Bind(3, "threadIdx.z");
+  stages[conv]->Bind(4, "threadIdx.x");
+
+  stages[OL]->ComputeAt3(stages[conv], 4);
+
+  auto on = stages[OL]->axis(0);
+  auto obz = stages[OL]->axis(1);
+  auto oby = stages[OL]->axis(2);
+  auto otz = stages[OL]->axis(3);
+  auto otx = stages[OL]->axis(4);
+  auto ofi = stages[OL]->axis(5);
+  auto orc = stages[OL]->axis(6);
+  auto ory = stages[OL]->axis(7);
+  auto orx = stages[OL]->axis(8);
+
+  stages[OL]->Reorder({orc, ory, orx, on, obz, oby, otz, otx, ofi});
+  stages[OL]->Split(0, 8);
+
+  stages[OL]->Bind(5, "blockIdx.z");
+  stages[OL]->Bind(6, "blockIdx.y");
+  stages[OL]->Bind(7, "threadIdx.z");
+  stages[OL]->Bind(8, "threadIdx.x");
+
+  CodeGenCUDA_Dev codegen(target);
+
+  auto func = Lower("schedule_conv2d_0", stages, {A, B, conv}, {}, {}, nullptr, target);
+
+  Module::Builder builder("module", target);
+  builder.AddFunction(func);
+
+  auto source_code = codegen.Compile(builder.Build());
+
+  LOG(INFO) << "compiled schedule_conv2d_0 code:\n\n\n" << source_code;
+
+  using runtime::cuda::CUDAModule;
+
+  backends::NVRTC_Compiler compiler;
+
+  auto ptx = compiler(source_code);
+  CHECK(!ptx.empty());
+
+  CUDAModule cuda_module(ptx, CUDAModule::Kind::PTX);
+
+  CUDA_CALL(cudaDeviceSynchronize());
+
+  CUdeviceptr Ad, Bd, Cd;
+  cuMemAlloc(&Ad, 128 * 28 * 28 * sizeof(float));
+  cuMemAlloc(&Bd, 256 * 128 * sizeof(float));
+  cuMemAlloc(&Cd, 256 * 14 * 14 * sizeof(float));
+
+  std::vector<float> host_data1(128 * 28 * 28, 0);
+  std::vector<float> host_data2(256 * 128, 0);
+  std::vector<float> host_data3(256 * 14 * 14, 0);
+  for (float& v : host_data1) v = static_cast<float>(rand()) / INT_MAX;  // NOLINT
+  for (float& v : host_data2) v = static_cast<float>(rand()) / INT_MAX;  // NOLINT
+
+  CUDA_CALL(cudaMemcpy(
+      reinterpret_cast<void*>(Ad), host_data1.data(), 128 * 28 * 28 * sizeof(float), cudaMemcpyHostToDevice));
+  CUDA_CALL(
+      cudaMemcpy(reinterpret_cast<void*>(Bd), host_data2.data(), 256 * 128 * sizeof(float), cudaMemcpyHostToDevice));
+
+  // launch the kernel
+
+  void* args[] = {&Ad, &Bd, &Cd};
+
+  dim3 grid(1, 14, 8);
+  dim3 block(14, 1, 16);
+  int repeat = 100;
+
+  utils::Timer time1;
+  time1.Start();
+  for (int i = 0; i < repeat; i++) {
+    cuda_module.LaunchKernel(0, "schedule_conv2d_0", grid, block, args);
+  }
+  LOG(INFO) << "Conv2d op with schedule repeats " << repeat
+            << " times, average time cost is : " << time1.Stop() / float(repeat) << "ms. ";
+  CUDA_CALL(cudaMemcpy(
+      host_data3.data(), reinterpret_cast<void*>(Cd), 256 * 14 * 14 * sizeof(float), cudaMemcpyDeviceToHost));
+}
+
 TEST(CodeGenCUDA, compile_run_jit) {
   Expr M(100);
   Expr N(200);
@@ -1954,7 +2069,7 @@ void fn5(const float* __restrict__ A, const float* __restrict__ B, float* __rest
   TestElementwiseAddPrecisionBasic(
       builder.Build(), "fn5", M, N, [](float a, float b) { return std::tanh(a) + std::cos(b); });
 }
-
+#ifdef CINN_WITH_CUDNN
 TEST(Cudnn, external_function_cudnn) {
   Context::Global().ResetNameId();
@@ -2017,6 +2132,6 @@ TEST(Cudnn, external_function_cudnn3) {

   runtime::cuda::cinn_gpu_cudnn_softmax({2, 1000, -1}, dev_bufs[0], dev_bufs[1]);
 }
-
+#endif
 }  // namespace backends
 }  // namespace cinn
diff --git a/cinn/frontend/paddle_model_to_program.cc b/cinn/frontend/paddle_model_to_program.cc
index d0fc305fc84c35..e315ae2e15c663 100755
--- a/cinn/frontend/paddle_model_to_program.cc
+++ b/cinn/frontend/paddle_model_to_program.cc
@@ -420,6 +420,19 @@ void PaddleModelToProgram::TransposeVar(const std::string& name) {
   } else if (target_.arch == Target::Arch::NVGPU) {
 #ifdef CINN_WITH_CUDA
     // To use cublas mul api, there is no need to transpose data.
+#ifndef CINN_WITH_CUDNN
+    std::vector<float> data(tensor->shape().numel());
+    CUDA_CALL(cudaMemcpy(data.data(),
+                         reinterpret_cast<void*>(tensor->mutable_data<float>(target_)),
+                         tensor->shape().numel() * sizeof(float),
+                         cudaMemcpyDeviceToHost));
+    CHECK(tensor->shape().size() == 2) << "The y data's shape size of op [mul] is not equal to 2! Please check.";
+    TransposeData(data.data(), tensor->shape().data()[0], tensor->shape().data()[1]);
+    CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(tensor->mutable_data<float>(target_)),
+                         data.data(),
+                         tensor->shape().numel() * sizeof(float),
+                         cudaMemcpyHostToDevice));
+#endif
 #else
     LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
 #endif
diff --git a/cinn/hlir/framework/cuda_graph_compiler_test.cc b/cinn/hlir/framework/cuda_graph_compiler_test.cc
old mode 100755
new mode 100644
index 53d48b5855dec8..d130fbe391d570
--- a/cinn/hlir/framework/cuda_graph_compiler_test.cc
+++ b/cinn/hlir/framework/cuda_graph_compiler_test.cc
@@ -43,7 +43,11 @@ std::vector<float> test_mul(const std::vector<float>& A, const std::vector<float>& B,
[...]
diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc
--- a/cinn/hlir/op/nn.cc
+++ b/cinn/hlir/op/nn.cc
@@ ... @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
       stages[weights_dilation.as_tensor_ref()]->ComputeInline();
     }
     if (target.arch == Target::Arch::NVGPU) {
-      Expr Out = arg_pack[2];
+      Expr Out = arg_pack[2];
+      Expr input_pad = arg_pack[0];
+      Expr weights_dilation = arg_pack[1];
+      ir::Tensor out_t = Out.as_tensor_ref();
+      ir::Tensor input_t = input_pad.as_tensor_ref();
+      ir::Tensor weights_t = weights_dilation.as_tensor_ref();
       CHECK(Out.as_tensor());
-      stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
-      stages[Out.as_tensor_ref()]->Bind(1, "blockIdx.y");
-      stages[Out.as_tensor_ref()]->Bind(2, "blockIdx.z");
-      stages[Out.as_tensor_ref()]->Bind(3, "threadIdx.x");
+      pe::CudaScheduleConv(stages, input_t, weights_t, out_t, target);
+      arg_pack[2] = Expr(out_t);
     }
     if (arg_pack.size() == 4UL) {
       *ret = CINNValuePack{{arg_pack[arg_pack.size() - 2], CINNValue(stages)}};
@@ -1387,8 +1390,7 @@ CINN_REGISTER_HELPER(nn_ops) {
 #ifdef CINN_WITH_CUDNN
       .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
 #else
-      .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern",
-                                                      cinn::hlir::framework::OpPatternKind::kOutEWiseFusable)
+      .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
 #endif
       .set_support_level(4);
diff --git a/cinn/hlir/pe/schedule.cc b/cinn/hlir/pe/schedule.cc
index e153d4f52d98a9..d3a1a1ae2258d1 100755
--- a/cinn/hlir/pe/schedule.cc
+++ b/cinn/hlir/pe/schedule.cc
@@ -6,12 +6,40 @@
 #include
 #include
+#include "cinn/optim/ir_simplify.h"
 #include "cinn/poly/isl_utils.h"
-
 namespace cinn {
 namespace hlir {
 namespace pe {

+int GetInnerSplitter(int origin, int other_axis) {
+  int two_exp = 1;
+  while (origin % two_exp == 0) {
+    two_exp *= 2;
+  }
+  two_exp = two_exp / 2;
+  int a = SplitEven(two_exp);
+  int b = two_exp / a;
+  while (a * other_axis >= 1024 || b * other_axis >= 1024) {
+    two_exp = two_exp / 2;
+    a = SplitEven(two_exp);
+    b = two_exp / a;
+  }
+  if (origin == two_exp) {
+    return 2;
+  }
+  return origin / two_exp;
+}
+
+int SplitEven(int origin) {
+  int res = 1;
+  while (origin % res == 0 && res * res < origin) {
+    res *= 2;
+  }
+  res = res / 2;
+  return res;
+}
+
 int GetBasicFactor(const Type &type, const common::Target &target) {
   int target_native_vector_bits = target.get_target_bits() * 8;
   int type_bits = type.bits();
@@ -225,15 +253,60 @@ void MulScheduleCPU(poly::StageMap stages,
 }

 void CudaScheduleConv(poly::StageMap stages,
-                      ir::Tensor input_pad,
-                      ir::Tensor kernel_dilation,
-                      ir::Tensor output,
+                      ir::Tensor &input_pad,
+                      ir::Tensor &kernel_dilation,
+                      ir::Tensor &output,
                       const common::Target &target) {
-  int num_thread = target.max_num_threads();
-  stages[output]->Fuse(0, 1);
-  auto [Block_x, Thread_x] = stages[output]->Split(0, num_thread);
-  stages[output]->Bind(0, "blockIdx.x");
-  stages[output]->Bind(1, "threadIdx.x");
+  int n = output->shape[0].as_int32();
+  int c = output->shape[1].as_int32();
+  optim::Simplify(&(output->shape[2]));
+  int h = output->shape[2].as_int32();
+  optim::Simplify(&(output->shape[3]));
+  int w = output->shape[3].as_int32();
+  int rc = kernel_dilation->shape[1].as_int32();
+  int ry = kernel_dilation->shape[2].as_int32();
+  int rx = kernel_dilation->shape[3].as_int32();
+
+  int f_inner = GetInnerSplitter(c, h);
+  int block_z = SplitEven(c / f_inner);
+  int thread_z = c / f_inner / block_z;
+
+  int rc_factor = SplitEven(rc);
+
+  auto OL = stages[output]->CacheWrite2("local", stages, output);
+
+  auto tx = stages[output]->axis(3);
+  auto by = stages[output]->axis(2);
+  auto [tem, fi] = stages[output]->Split(1, f_inner);
+  auto [bz, tz] = stages[output]->Split(1, thread_z);
+  stages[output]->Reorder({bz, by, tz, tx, fi});
+  stages[output]->Bind(1, "blockIdx.z");
+  stages[output]->Bind(2, "blockIdx.y");
+  stages[output]->Bind(3, "threadIdx.z");
+  stages[output]->Bind(4, "threadIdx.x");
+  stages[OL]->ComputeAt3(stages[output], 4);
+  auto on = stages[OL]->axis(0);
+  auto obz = stages[OL]->axis(1);
+  auto oby = stages[OL]->axis(2);
+  auto otz = stages[OL]->axis(3);
+  auto otx = stages[OL]->axis(4);
+  auto ofi = stages[OL]->axis(5);
+  auto orc = stages[OL]->axis(6);
+  auto ory = stages[OL]->axis(7);
+  auto orx = stages[OL]->axis(8);
+  stages[OL]->Reorder({orc, ory, orx, on, obz, oby, otz, otx, ofi});
+  if (rc_factor > 1) {
+    stages[OL]->Split(0, rc_factor);
+    stages[OL]->Bind(5, "blockIdx.z");
+    stages[OL]->Bind(6, "blockIdx.y");
+    stages[OL]->Bind(7, "threadIdx.z");
+    stages[OL]->Bind(8, "threadIdx.x");
+  } else {
+    stages[OL]->Bind(4, "blockIdx.z");
+    stages[OL]->Bind(5, "blockIdx.y");
+    stages[OL]->Bind(6, "threadIdx.z");
+    stages[OL]->Bind(7, "threadIdx.x");
+  }
   return;
 }
diff --git a/cinn/hlir/pe/schedule.h b/cinn/hlir/pe/schedule.h
index f629f4f00d676a..ff25e06d427a19 100644
--- a/cinn/hlir/pe/schedule.h
+++ b/cinn/hlir/pe/schedule.h
@@ -10,6 +10,9 @@ namespace cinn {
 namespace hlir {
 namespace pe {

+int GetInnerSplitter(int origin, int other_axis);
+
+int SplitEven(int origin);

 int GetBasicFactor(const Type &type, const common::Target &target);
@@ -35,9 +38,9 @@ void CudaScheduleMul(poly::StageMap stages,
                      const common::Target &target);

 void CudaScheduleConv(poly::StageMap stages,
-                      ir::Tensor input_pad,
-                      ir::Tensor kernel_dilation,
-                      ir::Tensor output,
+                      ir::Tensor &input_pad,
+                      ir::Tensor &kernel_dilation,
+                      ir::Tensor &output,
                       const common::Target &target);

 void CudaScheduleInjective(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target);
diff --git a/cinn/ir/tensor.cc b/cinn/ir/tensor.cc
index 4caad086bd380e..4fc2957cf8ca57 100755
--- a/cinn/ir/tensor.cc
+++ b/cinn/ir/tensor.cc
@@ -267,6 +267,7 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target)
     init_tensor->new_indices = this->new_indices;
     stages[this]->CtrlDepend(init_tensor);
     stages[init_tensor]->ShareBufferWith(stages[this]);
+    init_tensor->shape = shape;
     return init_tensor;
   }
   //! When reduce axies are reordered to front, ComputeAt is illegal.
@@ -288,6 +289,7 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target)
     init_tensor->new_indices = this->new_indices;
     stages[this]->CtrlDepend(init_tensor);
     stages[init_tensor]->ShareBufferWith(stages[this]);
+    init_tensor->shape = shape;
     return init_tensor;
   }
diff --git a/cinn/optim/replace_var_with_expr.cc b/cinn/optim/replace_var_with_expr.cc
old mode 100644
new mode 100755
diff --git a/cinn/poly/isl_utils.cc b/cinn/poly/isl_utils.cc
index cbaf8326013e71..118687471607dc 100644
--- a/cinn/poly/isl_utils.cc
+++ b/cinn/poly/isl_utils.cc
@@ -320,7 +320,7 @@ isl::set isl_set_dim_name_if_null(isl_set *set, std::function
-isl::map RemoveAxiesByNames(const isl::map &x, const std::vector<std::string> &dim_in_names) {
+isl::map RemoveAxiesByInputNames(const isl::map &x, const std::vector<std::string> &dim_in_names) {
   std::string map_str = isl_map_to_str(x.get());
   isl::ctx this_ctx = x.ctx();
   isl::map temp_transform(this_ctx, map_str);
@@ -338,12 +338,50 @@
   return temp_transform;
 }

-std::vector<std::string> GetRelatedAxies(const isl::map &x, const std::string &dim_out_name) {
+isl::map RemoveAxiesByOutputNames(const isl::map &x, const std::vector<std::string> &dim_out_names) {
   std::string map_str = isl_map_to_str(x.get());
   isl::ctx this_ctx = x.ctx();
   isl::map temp_transform(this_ctx, map_str);
+  if (dim_out_names.empty()) return temp_transform;
   auto dim_in_names = isl_get_dim_names(temp_transform, isl_dim_in);
-  temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, dim_out_name.c_str()));
+  for (auto &i : dim_out_names) {
+    temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str()));
+  }
+  std::string deleted_map = isl_map_to_str(temp_transform.get());
+  for (auto &i : dim_in_names) {
+    if (utils::Count(&map_str, i) != utils::Count(&deleted_map, i)) {
+      temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_in, i.c_str()));
+    }
+  }
+  return temp_transform;
+}
+
+std::vector<std::string> GetRelatedOutputAxies(const isl::map &x, const std::vector<std::string> &dim_in_names) {
+  std::string map_str = isl_map_to_str(x.get());
+  isl::ctx this_ctx = x.ctx();
+  isl::map temp_transform(this_ctx, map_str);
+  auto dim_out_names = isl_get_dim_names(temp_transform, isl_dim_out);
+  for (auto &i : dim_in_names) {
+    temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_in, i.c_str()));
+  }
+  std::string deleted_map = isl_map_to_str(temp_transform.get());
+  std::vector<std::string> res;
+  for (auto &i : dim_out_names) {
+    if (utils::Count(&map_str, i) != utils::Count(&deleted_map, i)) {
+      res.push_back(i);
+    }
+  }
+  return res;
+}
+
+std::vector<std::string> GetRelatedInputAxies(const isl::map &x, const std::vector<std::string> &dim_out_names) {
+  std::string map_str = isl_map_to_str(x.get());
+  isl::ctx this_ctx = x.ctx();
+  isl::map temp_transform(this_ctx, map_str);
+  auto dim_in_names = isl_get_dim_names(temp_transform, isl_dim_in);
+  for (auto &i : dim_out_names) {
+    temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str()));
+  }
   std::string deleted_map = isl_map_to_str(temp_transform.get());
   std::vector<std::string> res;
   for (auto &i : dim_in_names) {
diff --git a/cinn/poly/isl_utils.h b/cinn/poly/isl_utils.h
index e8a78fdd477ca0..991d71619e5f34 100755
--- a/cinn/poly/isl_utils.h
+++ b/cinn/poly/isl_utils.h
@@ -73,16 +73,34 @@ isl::set SetGetDims(isl::set set, const std::vector<int>& dims);
  * @param dim_in_names The names of input dims to remove.
  * @return The edited map.
  */
-isl::map RemoveAxiesByNames(const isl::map& x, const std::vector<std::string>& dim_in_names);
+isl::map RemoveAxiesByInputNames(const isl::map& x, const std::vector<std::string>& dim_in_names);

 /**
- * Given an isl::map and the name of an output dim,
+ * Given an isl::map and a vector of names of dim_out,
+ * remove the output dims in vector and related input dims.
+ * @param x The map to edit.
+ * @param dim_out_names The names of output dims to remove.
+ * @return The edited map.
+ */
+isl::map RemoveAxiesByOutputNames(const isl::map& x, const std::vector<std::string>& dim_out_names);
+
+/**
+ * Given an isl::map and a vector of names of dim_out,
  * get the names of related input dims.
  * @param x The input map.
- * @param dim_in_names The name of an output dim.
+ * @param dim_out_names The names of output dims.
  * @return The vector of names of related input dims.
  */
-std::vector<std::string> GetRelatedAxies(const isl::map& x, const std::string& dim_out_name);
+std::vector<std::string> GetRelatedInputAxies(const isl::map& x, const std::vector<std::string>& dim_out_names);
+
+/**
+ * Given an isl::map and a vector of names of dim_in,
+ * get the names of related output dims.
+ * @param x The input map.
+ * @param dim_in_names The names of input dims.
+ * @return The vector of names of related output dims.
+ */
+std::vector<std::string> GetRelatedOutputAxies(const isl::map& x, const std::vector<std::string>& dim_in_names);

 }  // namespace poly
 }  // namespace cinn
diff --git a/cinn/poly/stage.cc b/cinn/poly/stage.cc
old mode 100644
new mode 100755
index 0b59d8ef20cdc3..d00e0c93539e32
--- a/cinn/poly/stage.cc
+++ b/cinn/poly/stage.cc
@@ -347,7 +347,7 @@ void Stage::EditTempTensor(Stage *other, int level) {
       }
     }
     // Iterators of loop within level will be erased.
-    auto related_dim_in = GetRelatedAxies(this->transform(), transform_domain_names[i]);
+    auto related_dim_in = GetRelatedInputAxies(this->transform(), {transform_domain_names[i]});
     for (auto &j : related_dim_in) {
       erase_var.insert(j);
     }
@@ -358,23 +358,23 @@ void Stage::EditTempTensor(Stage *other, int level) {
     if (bind_info.count(i) != 0) {
       if (bind_info[i].for_type == ir::ForType::GPUBlock &&
           (this->scope() == ScopeKind::kShared || this->scope() == ScopeKind::kLocal)) {
-        auto related_dim_in = GetRelatedAxies(this->transform(), transform_domain_names[i]);
+        auto related_dim_in = GetRelatedInputAxies(this->transform(), {transform_domain_names[i]});
         for (auto &j : related_dim_in) {
           erase_var.insert(j);
         }
       } else if (bind_info[i].for_type == ir::ForType::GPUThread && (this->scope() == ScopeKind::kLocal)) {
-        auto related_dim_in = GetRelatedAxies(this->transform(), transform_domain_names[i]);
+        auto related_dim_in = GetRelatedInputAxies(this->transform(), {transform_domain_names[i]});
         for (auto &j : related_dim_in) {
           erase_var.insert(j);
         }
       } else {
-        auto related_dim_in = GetRelatedAxies(this->transform(), transform_domain_names[i]);
+        auto related_dim_in = GetRelatedInputAxies(this->transform(), {transform_domain_names[i]});
         for (auto &j : related_dim_in) {
           undo_erase_var.insert(j);
         }
       }
     } else {
-      auto related_dim_in = GetRelatedAxies(this->transform(), transform_domain_names[i]);
+      auto related_dim_in = GetRelatedInputAxies(this->transform(), {transform_domain_names[i]});
       for (auto &j : related_dim_in) {
         undo_erase_var.insert(j);
       }
@@ -417,18 +417,29 @@ void Stage::EditTempTensor(Stage *other, int level) {
     optim::Simplify(&i);
   }
   // Set new shape.
+ VLOG(3) << "Tensor is : " << this->tensor()->name; + for (auto &i : new_shape) { + VLOG(3) << "In Temp Buffer, shape is: " << utils::GetStreamCnt(i); + } this->tensor()->shape = new_shape; CHECK(this->tensor()->buffer.defined()); this->tensor()->buffer->shape = new_shape; return; } -void Stage::ComputeAt2(Stage *other, int level, ComputeAtKind kind) { +void Stage::ComputeAt2(Stage *other, int level) { // TODO(Superjomn) Check there are data dependency between `self` and `other`, or the `ComputeAt` is meaningless. this->ChangeDomain(other, level); this->CopyTransform(other, level); this->ChangeIndex(other); CHECK(tensor_); + other->CtrlDepend(ir::Tensor(tensor())); + if (this->tensor()->buffer.defined()) { + std::string t_name = this->tensor()->buffer->name; + if (utils::Endswith(t_name, "_read_cache") || utils::Endswith(t_name, "_cache_write_out")) { + EditTempTensor(other, level); + } + } ComputeAtRelation relation; relation.stage = other; relation.level = level; @@ -436,12 +447,21 @@ void Stage::ComputeAt2(Stage *other, int level, ComputeAtKind kind) { CHECK(relation.IsCompatible(this)); compute_ats_[other->id()] = relation; +} + +void Stage::ComputeAt3(Stage *other, int level) { + this->ChangeDomain(other, level); + this->CopyTransform(other, level); + this->ChangeIndex(other); + CHECK(tensor_); + other->CtrlDepend(ir::Tensor(tensor())); if (this->tensor()->buffer.defined()) { - std::string t_name = this->tensor()->name; + std::string t_name = this->tensor()->buffer->name; if (utils::Endswith(t_name, "_read_cache") || utils::Endswith(t_name, "_cache_write_out")) { EditTempTensor(other, level); } } + return; } void Stage::ComputeAt(Stage *other, int level, Stage::ComputeAtKind kind, const std::string &cached_tensor_name) { @@ -605,6 +625,12 @@ std::vector Stage::compute_ats() const { return xs; } +void Stage::ShowISL() { + LOG(INFO) << "Tensor " << tensor()->name << " domain is: " << isl_set_to_str(domain().get()); + LOG(INFO) << "transformed_domain is: " << isl_set_to_str(transformed_domain().get()); + LOG(INFO) << "transform is: " << isl_map_to_str(transform().get()); +} + bool ComputeAtRelation::IsCompatible(Stage *self) { CHECK_GE(level, 0); CHECK(!self->domain().is_null()); @@ -1013,7 +1039,7 @@ void Stage::AddForloopInfo(int level, const StageForloopInfo &info) { } void Stage::CopyTransform(Stage *other, int level) { - auto target_transform = RemoveAxiesByNames(other->transform(), other->origin_reduce_axis_names()); + auto target_transform = RemoveAxiesByInputNames(other->transform(), other->origin_reduce_axis_names()); std::string str_target_trans = isl_map_to_str(target_transform.get()); std::string this_tensor_name = isl_set_get_tuple_name(domain_.get()); isl::ctx this_ctx = domain_.ctx(); @@ -1121,6 +1147,7 @@ void Stage::CopyTransform(Stage *other, int level) { VLOG(2) << "Target transform is : " << isl_map_to_str(other->transform().get()); VLOG(2) << "CopyTransform Level is : " << level; transform_ = res_map; + return; } void Stage::CopyLoopInfo(std::map target_forloop_infos, const isl::map &target_transform) { diff --git a/cinn/poly/stage.h b/cinn/poly/stage.h index 208ded1532d2f5..ba39386d455256 100755 --- a/cinn/poly/stage.h +++ b/cinn/poly/stage.h @@ -245,7 +245,11 @@ class Stage : public Object { ir::Tensor CacheRead2(const std::string& memory_type, std::vector& readers, poly::StageMap stages); - void ComputeAt2(Stage* other, int level, ComputeAtKind kind = kComputeAtBefore); + void ComputeAt2(Stage* other, int level); + + void ComputeAt3(Stage* other, int level); + + 
void ShowISL(); void AddForLoopInTransform(std::vector>& indices); /** diff --git a/cinn/runtime/cuda/cuda_util.cc b/cinn/runtime/cuda/cuda_util.cc old mode 100644 new mode 100755 diff --git a/python/tests/fake_model/resnet_model.py b/python/tests/fake_model/resnet_model.py index fe962f03dd9bd7..76918b76bde4bd 100644 --- a/python/tests/fake_model/resnet_model.py +++ b/python/tests/fake_model/resnet_model.py @@ -10,12 +10,12 @@ resnet_input = fluid.layers.data( name="resnet_input", append_batch_size=False, - shape=[2, 160, 7, 7], + shape=[1, 160, 7, 7], dtype='float32') label = fluid.layers.data( name="label", append_batch_size=False, - shape=[2, 960, 7, 7], + shape=[1, 960, 7, 7], dtype='float32') d = fluid.layers.relu6(resnet_input) f = fluid.layers.conv2d( diff --git a/python/tests/test_frontend.py b/python/tests/test_frontend.py index 1baec6069caa68..14818f3dcc4106 100755 --- a/python/tests/test_frontend.py +++ b/python/tests/test_frontend.py @@ -50,8 +50,8 @@ def paddle_verify(self, result): exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_startup_program()) - x = np.array(result[0]).reshape((2, 24, 56, 56)).astype("float32") - y = np.array(result[1]).reshape((2, 24, 56, 56)).astype("float32") + x = np.array(result[0]).reshape((1, 24, 56, 56)).astype("float32") + y = np.array(result[1]).reshape((1, 24, 56, 56)).astype("float32") output = exe.run(feed={"A": x, "B": y}, fetch_list=[res]) output = np.array(output).reshape(-1) print("result in paddle_verify: \n") @@ -66,8 +66,8 @@ def paddle_verify(self, result): def test_basic(self): prog = Program() - a = Variable("A").set_type(Float(32)).set_shape([2, 24, 56, 56]) - b = Variable("B").set_type(Float(32)).set_shape([2, 24, 56, 56]) + a = Variable("A").set_type(Float(32)).set_shape([1, 24, 56, 56]) + b = Variable("B").set_type(Float(32)).set_shape([1, 24, 56, 56]) c = prog.add(a, b) d = prog.relu(c) e = Variable("E").set_type(Float(32)).set_shape([144, 24, 1, 1]) @@ -84,11 +84,10 @@ def test_basic(self): for i in range(prog.size()): print(prog[i]) tensor_data = [ - np.random.random([2, 24, 56, 56]).astype("float32"), - np.random.random([2, 24, 56, 56]).astype("float32"), + np.random.random([1, 24, 56, 56]).astype("float32"), + np.random.random([1, 24, 56, 56]).astype("float32"), np.random.random([144, 24, 1, 1]).astype("float32") ] - result = prog.build_and_get_output(self.target, [a, b, e], tensor_data, h) result = result.numpy(self.target).reshape(-1) diff --git a/python/tests/test_op_benchmark.py b/python/tests/test_op_benchmark.py index f528c34d9927d5..190481c3b00bae 100755 --- a/python/tests/test_op_benchmark.py +++ b/python/tests/test_op_benchmark.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import paddle +import paddle.fluid as fluid from cinn.frontend import * from cinn import Target from cinn.framework import * @@ -22,22 +24,236 @@ def setUp(self): else: self.target = DefaultHostTarget() - def atest_conv2d(self): + def paddle_verify(self, result): + paddle.enable_static() + + a = fluid.layers.data(name='A', shape=[128, 28, 28], dtype='float32') + e = fluid.initializer.NumpyArrayInitializer( + np.array(result[1]).reshape((256, 128, 1, 1)).astype("float32")) + res = fluid.layers.conv2d( + input=a, + num_filters=256, + filter_size=1, + stride=2, + padding=0, + dilation=1, + param_attr=e) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + + x = np.array(result[0]).reshape((1, 128, 28, 28)).astype("float32") + output = exe.run(feed={"A": x}, fetch_list=[res]) + output = 
np.array(output).reshape(-1) + print("result in conv2d paddle_verify: \n") + for i in range(0, output.shape[0]): + if np.abs(output[i] - result[len(result) - 1][i]) > 1e-4: + print("Error! ", i, "-th data has diff with target data:\n", + output[i], " vs: ", result[len(result) - 1][i], + ". Diff is: ", output[i] - result[len(result) - 1][i]) + self.assertTrue( + np.allclose(result[len(result) - 1], output, atol=1e-4)) + + def test_conv2d(self): + prog = Program() + a = Variable("A").set_type(Float(32)).set_shape([1, 128, 28, 28]) + b = Variable("E").set_type(Float(32)).set_shape([256, 128, 1, 1]) + c = prog.conv2d(a, b, { + "stride": [2, 2], + "dilation": [1, 1], + "padding": [0, 0] + }) + tensor_data = [ + np.random.random([1, 128, 28, 28]).astype("float32"), + np.random.random([256, 128, 1, 1]).astype("float32") + ] + result = prog.test_benchmark( + self.target, [a, b], tensor_data, c, 20000, + "TESTING [conv2d] time cost with shape [1, 128, 28, 28]...") + result = result.numpy(self.target).reshape(-1) + tensor_data.append(result) + self.paddle_verify(tensor_data) + + def test_conv2d(self): + prog = Program() + a = Variable("A").set_type(Float(32)).set_shape([1, 128, 28, 28]) + b = Variable("E").set_type(Float(32)).set_shape([256, 128, 1, 1]) + c = prog.conv2d(a, b, { + "stride": [2, 2], + "dilation": [1, 1], + "padding": [0, 0] + }) + tensor_data = [ + np.random.random([1, 128, 28, 28]).astype("float32"), + np.random.random([256, 128, 1, 1]).astype("float32") + ] + result = prog.test_benchmark( + self.target, [a, b], tensor_data, c, 20000, + "TESTING [conv2d] time cost with shape [1, 128, 28, 28]...") + result = result.numpy(self.target).reshape(-1) + tensor_data.append(result) + self.paddle_verify(tensor_data) + + def atest_conv2d3(self): prog = Program() - a = Variable("A").set_type(Float(32)).set_shape([2, 512, 7, 7]) - b = Variable("E").set_type(Float(32)).set_shape([512, 512, 3, 3]) + a = Variable("X").set_type(Float(32)).set_shape([1, 128, 28, 28]) + b = Variable("Y").set_type(Float(32)).set_shape([256, 128, 1, 1]) c = prog.conv2d(a, b, { - "stride": [1, 1], + "stride": [2, 2], "dilation": [1, 1], - "padding": [1, 1] + "padding": [0, 0] }) tensor_data = [ - np.random.random([2, 512, 7, 7]).astype("float32"), - np.random.random([512, 512, 3, 3]).astype("float32") + np.random.random([1, 128, 28, 28]).astype("float32"), + np.random.random([256, 128, 1, 1]).astype("float32") ] result = prog.test_benchmark( - self.target, [a, b], tensor_data, c, 2000, - "TESTING [conv2d] time cost with shape [2,512,7,7]...") + self.target, [a, b], tensor_data, c, 10000, + "TESTING [conv2d] time cost with shape [1, 128, 28, 28]...") + result = prog.test_benchmark_with_code( + self.target, [a, b], tensor_data, c, 20000, + "TESTING [conv2d of tvm schedule] time cost with shape [1, 128, 28, 28]...", + """ +extern "C" { + +#include "cinn_cuda_runtime_source.cuh" + +#ifdef __CUDACC_RTC__ +typedef int int32_t; +typedef char int8_t; +#endif + + + +__global__ +void fn_conv2d_0_kernel(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ Conv2d_nchw_out) +{ + float _COD_cache_write_out [ ((1 * (((((1 * 1) * 256) * 1) * 1) / 8)) / 16) ]; + float* COD_cache_write_out = _COD_cache_write_out; + float* COD_cache_write_out__reduce_init = _COD_cache_write_out; + for (int32_t i = 0; i < 1; i += 1) { + if ((blockIdx.z < 8)) { + if ((blockIdx.y < 14)) { + if ((threadIdx.z < 16)) { + if ((threadIdx.x < 14)) { + for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) { + 
COD_cache_write_out__reduce_init[j_inner] = 0; + }; + }; + }; + }; + }; + }; + for (int32_t fc_outer = 0; fc_outer < 16; fc_outer += 1) { + for (int32_t fc_inner = 0; fc_inner < 8; fc_inner += 1) { + for (int32_t fy = 0; fy < 1; fy += 1) { + for (int32_t fx = 0; fx < 1; fx += 1) { + for (int32_t i = 0; i < 1; i += 1) { + if ((blockIdx.z < 8)) { + if ((blockIdx.y < 14)) { + if ((threadIdx.z < 16)) { + if ((threadIdx.x < 14)) { + for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) { + COD_cache_write_out[j_inner] = (COD_cache_write_out[j_inner] + (((((((((blockIdx.y * 2) + fy) >= 0) && (((blockIdx.y * 2) + fy) < 28)) && (((threadIdx.x * 2) + fx) >= 0)) && (((threadIdx.x * 2) + fx) < 28))) ? X[((784 * ((((32 * blockIdx.z) + ((2 * threadIdx.z) + j_inner)) / 256) * 128)) + ((56 * blockIdx.y) + ((784 * fc_inner) + ((6272 * fc_outer) + ((28 * fy) + ((100352 * i) + ((2 * threadIdx.x) + fx)))))))] : 0) * (((((fy % 1) == 0) && ((fx % 1) == 0))) ? Y[((4096 * blockIdx.z) + ((8 * fc_outer) + ((128 * j_inner) + ((256 * threadIdx.z) + fc_inner))))] : 0))); + }; + }; + }; + }; + }; + }; + }; + }; + }; + }; + for (int32_t i = 0; i < 1; i += 1) { + if ((blockIdx.z < 8)) { + if ((blockIdx.y < 14)) { + if ((threadIdx.z < 16)) { + if ((threadIdx.x < 14)) { + for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) { + Conv2d_nchw_out[((14 * blockIdx.y) + ((6272 * blockIdx.z) + ((50176 * i) + ((196 * j_inner) + ((392 * threadIdx.z) + threadIdx.x)))))] = COD_cache_write_out[j_inner]; + }; + }; + }; + }; + }; + }; +} + +} + """) + result = result.numpy(self.target).reshape(-1) + tensor_data.append(result) + self.paddle_verify(tensor_data) + + def atest_conv2d2(self): + prog = Program() + a = Variable("placeholder").set_type(Float(32)).set_shape( + [1, 128, 28, 28]) + b = Variable("placeholder1").set_type(Float(32)).set_shape( + [256, 128, 1, 1]) + c = prog.conv2d(a, b, { + "stride": [2, 2], + "dilation": [1, 1], + "padding": [0, 0] + }) + tensor_data = [ + np.random.random([1, 128, 28, 28]).astype("float32"), + np.random.random([256, 128, 1, 1]).astype("float32") + ] + result = prog.test_benchmark_with_code( + self.target, [a, b], tensor_data, c, 20000, + "TESTING [conv2d of tvm schedule] time cost with shape [1, 128, 28, 28]...", + """ +extern "C" { + +#include "cinn_cuda_runtime_source.cuh" + +#ifdef __CUDACC_RTC__ +typedef int int32_t; +typedef char int8_t; +#endif + + + +__global__ void fn_conv2d_0_kernel(float* __restrict__ placeholder, float* __restrict__ placeholder1, float* __restrict__ Conv2d_nchw_out) { + float compute_local[2]; + __shared__ float pad_temp_shared[216]; + __shared__ float placeholder_shared[256]; + for (int ff_c_init = 0; ff_c_init < 2; ++ff_c_init) { + compute_local[(ff_c_init)] = 0.000000e+00f; + } + for (int rc_outer = 0; rc_outer < 16; ++rc_outer) { + __syncthreads(); + if (((((int)threadIdx.z) * 14) + ((int)threadIdx.x)) < 216) { + pad_temp_shared[(((((int)threadIdx.z) * 14) + ((int)threadIdx.x)))] = placeholder[(((((rc_outer * 6272) + ((((((int)threadIdx.z) * 14) + ((int)threadIdx.x)) / 27) * 784)) + (((int)blockIdx.y) * 56)) + (((((int)threadIdx.z) * 14) + ((int)threadIdx.x)) % 27)))]; + } + for (int ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner = 0; ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner < 2; ++ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) { + if (((((int)threadIdx.z) * 2) + (((((int)threadIdx.x) * 2) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) >> 3)) < 32) { + if ((((((int)threadIdx.z) * 16) + (((int)threadIdx.x) * 2)) + 
ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) < 256) { + if (((((int)threadIdx.x) * 2) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) < 16) { + placeholder_shared[((((((int)threadIdx.z) * 16) + (((int)threadIdx.x) * 2)) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner))] = placeholder1[((((((((int)blockIdx.z) * 4096) + (((int)threadIdx.z) * 256)) + ((((((int)threadIdx.x) * 2) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) >> 3) * 128)) + (rc_outer * 8)) + (((((int)threadIdx.x) * 2) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) & 7)))]; + } + } + } + } + __syncthreads(); + for (int rc_inner = 0; rc_inner < 8; ++rc_inner) { + for (int ff_c = 0; ff_c < 2; ++ff_c) { + compute_local[(ff_c)] = (compute_local[(ff_c)] + (pad_temp_shared[(((rc_inner * 27) + (((int)threadIdx.x) * 2)))] * placeholder_shared[((((((int)threadIdx.z) * 16) + (ff_c * 8)) + rc_inner))])); + } + } + } + for (int ff_inner_inner_inner = 0; ff_inner_inner_inner < 2; ++ff_inner_inner_inner) { + Conv2d_nchw_out[((((((((int)blockIdx.z) * 6272) + (((int)threadIdx.z) * 392)) + (ff_inner_inner_inner * 196)) + (((int)blockIdx.y) * 14)) + ((int)threadIdx.x)))] = compute_local[(ff_inner_inner_inner)]; + } +} + +} + """) + result = result.numpy(self.target).reshape(-1) + tensor_data.append(result) + self.paddle_verify(tensor_data) def atest_softmax(self): prog = Program() @@ -147,7 +363,7 @@ def atest_matmul(self): } }''') - def test_pool2d(self): + def atest_pool2d(self): prog = Program() a = Variable("A").set_type(Float(32)).set_shape([2, 64, 112, 112]) c = prog.pool2d( diff --git a/python/tests/test_resnet.py b/python/tests/test_resnet.py index 29cc88ae3f8af8..7a8156b6f0cb7e 100755 --- a/python/tests/test_resnet.py +++ b/python/tests/test_resnet.py @@ -27,7 +27,7 @@ def setUp(self): self.model_dir = model_dir - self.x_shape = [2, 160, 7, 7] + self.x_shape = [1, 160, 7, 7] def get_paddle_inference_result(self, data): config = fluid.core.AnalysisConfig(self.model_dir)
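
A quick usage sketch for the new switch. Only the gpu_on and cudnn_off handling and the config.cmake lines below are defined by this patch; the trailing arguments stand in for whatever subcommands you already pass to build.sh:

    # CUDA + cuDNN (the usual GPU build): gpu_on now sets cuda_config=ON and
    # cudnn_config=ON, so the generated config.cmake contains
    # set(WITH_CUDA ON) and set(WITH_CUDNN ON).
    ./build.sh gpu_on <other-subcommands>

    # CUDA without cuDNN: cudnn_off (which must come after gpu_on) resets
    # cudnn_config=OFF, so config.cmake gets set(WITH_CUDNN OFF) and
    # CMakeLists.txt never adds -DCINN_WITH_CUDNN. conv2d then takes the
    # CudaScheduleConv path in cinn/hlir/pe/schedule.cc, and the Cudnn tests
    # in codegen_cuda_dev_test.cc are compiled out.
    ./build.sh gpu_on cudnn_off <other-subcommands>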