From ad6e41ce30b9cd41f1ccf71c4a47546edf96e621 Mon Sep 17 00:00:00 2001
From: haozech
Date: Tue, 1 Dec 2020 19:21:07 +0800
Subject: [PATCH] Add mulbias and change mul's behavior (#295)

---
 cinn/backends/raw_cuda_code_test.cu           |   0
 cinn/frontend/paddle_model_to_program.cc      |  84 +++++++-
 cinn/frontend/paddle_model_to_program.h       |   5 +-
 cinn/frontend/syntax.cc                       |   9 +
 cinn/frontend/syntax.h                        |   6 +
 cinn/frontend/syntax_test.cc                  |   2 +-
 .../framework/cuda_graph_compiler_test.cc     |   4 +-
 cinn/hlir/framework/graph_compiler.h          |  22 +-
 cinn/hlir/framework/instruction.h             |  12 +-
 cinn/hlir/framework/tensor.h                  |   2 +-
 cinn/hlir/op/nn.cc                            |   6 +-
 cinn/hlir/op/transform.cc                     | 167 ++++++++++++++-
 cinn/hlir/pe/nn.cc                            |  25 +++
 cinn/hlir/pe/nn.h                             |  11 +
 cinn/hlir/pe/transform.cc                     |  29 ++-
 cinn/hlir/pe/transform.h                      |   8 +
 cinn/pybind/frontend.cc                       |   2 +
 cinn/pybind/poly.cc                           |   2 +-
 cinn/runtime/cuda/cuda_util.cc                |   4 +-
 python/tests/test_op_benchmark.py             | 195 +++++++++++++++---
 .../tvm_benchmark/tvm_graph_with_single_op.py | 118 +++++++++--
 21 files changed, 635 insertions(+), 78 deletions(-)
 mode change 100644 => 100755 cinn/backends/raw_cuda_code_test.cu
 mode change 100644 => 100755 cinn/frontend/syntax.cc
 mode change 100644 => 100755 cinn/frontend/syntax_test.cc
 mode change 100644 => 100755 cinn/hlir/framework/graph_compiler.h
 mode change 100644 => 100755 cinn/hlir/framework/instruction.h
 mode change 100644 => 100755 cinn/hlir/op/transform.cc
 mode change 100644 => 100755 cinn/hlir/pe/transform.cc
 mode change 100644 => 100755 cinn/pybind/frontend.cc
 mode change 100644 => 100755 cinn/pybind/poly.cc
 mode change 100644 => 100755 cinn/runtime/cuda/cuda_util.cc

diff --git a/cinn/backends/raw_cuda_code_test.cu b/cinn/backends/raw_cuda_code_test.cu
old mode 100644
new mode 100755
diff --git a/cinn/frontend/paddle_model_to_program.cc b/cinn/frontend/paddle_model_to_program.cc
index fd1f3d3eeec15..4400f338c9978 100644
--- a/cinn/frontend/paddle_model_to_program.cc
+++ b/cinn/frontend/paddle_model_to_program.cc
@@ -9,6 +9,28 @@ namespace frontend {
 using utils::Join;
 using utils::TransValidVarName;
 
+void MoveData(float* data, int i, int M, int N) {
+  float temp = data[i];
+  int cur = i;  // current data index
+  int pre = (cur % M) * N + cur / M;
+  while (pre != i) {
+    data[cur] = data[pre];
+    cur = pre;
+    pre = (cur % M) * N + cur / M;
+  }
+  data[cur] = temp;
+}
+
+void TransposeData(float* data, int M, int N) {
+  for (int i = 0; i < M * N; i++) {
+    int next = (i % N) * M + i / N;
+    while (next > i)  // next < i implies this ring was already processed
+      next = (next % N) * M + next / N;
+    if (next == i)  // process current ring
+      MoveData(data, i, M, N);
+  }
+}
+
 void PaddleModelToProgram::AddOpMapper_feed() {
   op_mappers_["feed"] = [&](const paddle::cpp::OpDesc& op_desc) {
     auto outs = op_desc.Output("Out");
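For reference, the added TransposeData walks each permutation cycle of the row-major index mapping and rotates the elements in place, so no scratch buffer of the matrix size is needed. A minimal Python sketch of the same cycle-walking idea (function name is illustrative, not part of the patch):

# In-place transpose of a flat row-major M x N buffer, mirroring
# MoveData/TransposeData above.
def transpose_inplace(data, M, N):
    for i in range(M * N):
        # destination index of element i after transposing
        nxt = (i % N) * M + i // N
        while nxt > i:  # nxt < i: this cycle was already rotated
            nxt = (nxt % N) * M + nxt // N
        if nxt == i:    # i is the smallest index in its cycle
            temp, cur = data[i], i
            pre = (cur % M) * N + cur // M
            while pre != i:
                data[cur] = data[pre]
                cur = pre
                pre = (cur % M) * N + cur // M
            data[cur] = temp

data = [0, 1, 2, 3, 4, 5]          # 2 x 3 matrix [[0,1,2],[3,4,5]]
transpose_inplace(data, 2, 3)
assert data == [0, 3, 1, 4, 2, 5]  # 3 x 2 matrix [[0,3],[1,4],[2,5]]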
Please check."; VLOG(4) << "Mul x_num_col_dims: " << x_num_col_dims; VLOG(4) << "Mul y_num_col_dims: " << y_num_col_dims; VLOG(4) << "x shape: " << utils::Join(x->shape, ","); @@ -364,6 +388,56 @@ void PaddleModelToProgram::AddOp(const paddle::cpp::OpDesc& op_desc) { LOG(FATAL) << "Not supported op [" << op_desc.Type() << "] found"; } +void PaddleModelToProgram::TransposeVar(const std::string& name) { + CheckVarNameValid(name); + auto* var = scope_->FindVar(name); + if (var) { + auto& tensor = std::get(*var); + if (target_.arch == Target::Arch::X86) { + float* data = tensor->mutable_data(target_); + CHECK(tensor->shape().size() == 2) << "The y data's shape size of op [mul] is not equal to 2! Please check."; + TransposeData(data, tensor->shape().data()[0], tensor->shape().data()[1]); + } else if (target_.arch == Target::Arch::NVGPU) { +#ifdef CINN_WITH_CUDA + std::vector data(tensor->shape().numel()); + CUDA_CALL(cudaMemcpy(data.data(), + reinterpret_cast(tensor->mutable_data(target_)), + tensor->shape().numel() * sizeof(float), + cudaMemcpyDeviceToHost)); +#else + LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; +#endif + CHECK(tensor->shape().size() == 2) << "The y data's shape size of op [mul] is not equal to 2! Please check."; + + TransposeData(data.data(), tensor->shape().data()[0], tensor->shape().data()[1]); + +#ifdef CINN_WITH_CUDA + CUDA_CALL(cudaMemcpy(reinterpret_cast(tensor->mutable_data(target_)), + data.data(), + tensor->shape().numel() * sizeof(float), + cudaMemcpyHostToDevice)); +#else + LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; +#endif + + } else { + CINN_NOT_IMPLEMENTED + } + + Variable var; + var.set_id(name); + std::vector reverse_shape = tensor->shape().data(); + std::reverse(reverse_shape.begin(), reverse_shape.end()); + tensor->shape().SetData(reverse_shape); + var->shape = tensor->shape().data(); + // TODO(Superjomn) Make this determined by model. 
@@ -399,9 +473,11 @@ std::unique_ptr<Program> PaddleModelToProgram::operator()(const std::string& mod
   return std::move(program_);
 }
 
-void PaddleModelToProgram::AddVar(const std::string& name, const Variable& var) {
+void PaddleModelToProgram::AddVar(const std::string& name, const Variable& var, bool replace) {
   CheckVarNameValid(name);
-  CHECK(!var_map_.count(name)) << "Duplicate variable [" << name << "] found";
+  if (replace == false) {
+    CHECK(!var_map_.count(name)) << "Duplicate variable [" << name << "] found";
+  }
   var_map_[name] = var;
 }
 
diff --git a/cinn/frontend/paddle_model_to_program.h b/cinn/frontend/paddle_model_to_program.h
index d21bad5624db0..e577170839900 100644
--- a/cinn/frontend/paddle_model_to_program.h
+++ b/cinn/frontend/paddle_model_to_program.h
@@ -9,6 +9,7 @@
 #include
 #include
 
+#include "cinn/backends/cuda_util.h"
 #include "cinn/common/common.h"
 #include "cinn/common/context.h"
 #include "cinn/common/object.h"
@@ -73,10 +74,12 @@ class PaddleModelToProgram {
   const std::unordered_map<std::string, std::string>& var_model_to_program_map() { return var_model_to_program_map_; }
 
  protected:
-  void AddVar(const std::string& name, const Variable& var);
+  void AddVar(const std::string& name, const Variable& var, bool replace = false);
 
   Variable GetVar(const std::string& name);
 
+  void TransposeVar(const std::string& name);
+
 private:
  std::unordered_map<std::string, std::function<void(const paddle::cpp::OpDesc&)>> op_mappers_;
  std::unique_ptr<Program> program_;
diff --git a/cinn/frontend/syntax.cc b/cinn/frontend/syntax.cc
old mode 100644
new mode 100755
index e70a82c8bea32..a46f6229400a1
--- a/cinn/frontend/syntax.cc
+++ b/cinn/frontend/syntax.cc
@@ -210,6 +210,15 @@ Variable Program::mul(const Variable& a, const Variable& b, int x_num_col_dims,
   return instr.GetOutput(0);
 }
 
+Variable Program::mulbias(
+    const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims, int y_num_col_dims) {
+  Instruction instr("mulbias", {a, b, c});
+  instr.SetAttr("x_num_col_dims", x_num_col_dims);
+  instr.SetAttr("y_num_col_dims", y_num_col_dims);
+  AppendInstruction(instr);
+  return instr.GetOutput(1);
+}
+
 std::string _Instruction_::debug_string() const {
   struct Visit {
     std::stringstream& s_;
diff --git a/cinn/frontend/syntax.h b/cinn/frontend/syntax.h
index 80d72077a36cc..b27ec9899c8e9 100644
--- a/cinn/frontend/syntax.h
+++ b/cinn/frontend/syntax.h
@@ -172,6 +172,12 @@ struct Program {
    */
   Variable mul(const Variable& a, const Variable& b, int x_num_col_dims = 1, int y_num_col_dims = 1);
 
+  /**
+   * Multiply two matrices and add a bias.
+   */
+  Variable mulbias(
+      const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims = 1, int y_num_col_dims = 1);
+
   /**
    * Add two tensors element-wise.
   */
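Program::mulbias returns GetOutput(1) because the op is registered with two outputs: output 0 is the intermediate product and output 1 is the product plus bias (see StrategyForMulBias further down). In NumPy terms, the contract is roughly (shapes taken from test_matmul1 below):

import numpy as np

a = np.random.rand(128, 512).astype("float32")  # X, [M, K]
b = np.random.rand(256, 512).astype("float32")  # Y, stored transposed as [N, K]
c = np.random.rand(128, 256).astype("float32")  # bias, [M, N]

temp = a @ b.T   # output 0: the intermediate Mul result
out = temp + c   # output 1: what Program::mulbias returns

assert out.shape == (128, 256)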
diff --git a/cinn/frontend/syntax_test.cc b/cinn/frontend/syntax_test.cc
old mode 100644
new mode 100755
index 2473e92e60792..8e69b2a46149a
--- a/cinn/frontend/syntax_test.cc
+++ b/cinn/frontend/syntax_test.cc
@@ -104,7 +104,7 @@ TEST(syntax, program_execute_fc) {
   const int N = 24;
 
   Placeholder a(Float(32), {B, M, K}, "A");
-  Placeholder w(Float(32), {K, N}, "W");  // weight
+  Placeholder w(Float(32), {N, K}, "W");  // weight
   Placeholder b(Float(32), {N}, "B");     // bias
 
   Program program;
diff --git a/cinn/hlir/framework/cuda_graph_compiler_test.cc b/cinn/hlir/framework/cuda_graph_compiler_test.cc
index 6507c946c83fd..1838dc9df2a63 100644
--- a/cinn/hlir/framework/cuda_graph_compiler_test.cc
+++ b/cinn/hlir/framework/cuda_graph_compiler_test.cc
@@ -43,7 +43,7 @@ std::vector<float> test_mul(const std::vector<float>& A, const std::vector<float
   a->shape = {M.as_int32(), K.as_int32()};
-  b->shape = {K.as_int32(), N.as_int32()};
+  b->shape = {N.as_int32(), K.as_int32()};
   a->type = t;
   b->type = t;
   auto c = prog.mul(a, b);
diff --git a/cinn/hlir/framework/graph_compiler.h b/cinn/hlir/framework/graph_compiler.h
old mode 100644
new mode 100755
index 989f2594693d1..c21000501e675
--- a/cinn/hlir/framework/graph_compiler.h
+++ b/cinn/hlir/framework/graph_compiler.h
@@ -6,6 +6,7 @@
 #include
 
 #include "cinn/backends/compiler.h"
+#include "cinn/backends/cuda_util.h"
 #include "cinn/common/macros.h"
 #include "cinn/hlir/framework/graph.h"
 #include "cinn/hlir/framework/instruction.h"
@@ -13,6 +14,7 @@
 #include "cinn/hlir/framework/scope.h"
 #include "cinn/ir/lowered_func.h"
 #include "cinn/lang/packed_func.h"
+#include "cinn/utils/timer.h"
 
 namespace cinn {
 namespace hlir {
@@ -47,14 +49,28 @@ class Program {
         VLOG(3) << out << " ";
       }
       ins->Run();
+      CUDA_CALL(cudaDeviceSynchronize());
     }
   }
 
   void ExecuteTest(int repeat_) {
-    CHECK_EQ(instrs_.size(), 1);
-    for (auto& ins : instrs_) {
-      ins->RunTest(repeat_);
+    cinn::utils::Timer timer1;
+    for (int i = 0; i < 100; i++) {
+      for (auto& ins : instrs_) {
+        ins->RunTest(repeat_);
+      }
+    }
+    timer1.Start();
+    for (int i = 0; i < repeat_; i++) {
+      for (auto& ins : instrs_) {
+        ins->RunTest(repeat_);
+      }
     }
+
+    CUDA_CALL(cudaDeviceSynchronize());
+    double test_op_time = timer1.Stop() / repeat_;
+
+    LOG(INFO) << "Repeat times: [" << repeat_ << "], average op time: [" << test_op_time << "] ms";
   }
   /**
    * Get the number of instructions.
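ExecuteTest now does the warm-up and timing at the program level (100 warm-up passes over all instructions, then repeat_ timed passes averaged), instead of inside each Instruction::RunTest. The same pattern in plain Python, for reference (run() is a hypothetical stand-in for the instruction loop):

import time

def benchmark(run, repeat):
    for _ in range(100):  # warm-up, excluded from timing
        run()
    start = time.perf_counter()
    for _ in range(repeat):
        run()
    # on GPU backends, synchronize the device here before reading the clock
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    print("Repeat times: [%d], average op time: [%f] ms" % (repeat, elapsed_ms / repeat))

benchmark(lambda: sum(range(1000)), repeat=200)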
diff --git a/cinn/hlir/framework/instruction.h b/cinn/hlir/framework/instruction.h
old mode 100644
new mode 100755
index 33f469c01f336..be1107e1af13c
--- a/cinn/hlir/framework/instruction.h
+++ b/cinn/hlir/framework/instruction.h
@@ -3,6 +3,7 @@
 #include
 #include
 
+#include "cinn/backends/cuda_util.h"
 #include "cinn/common/test_helper.h"
 #include "cinn/hlir/framework/scope.h"
 #include "cinn/utils/timer.h"
@@ -43,16 +44,7 @@ class Instruction {
   void RunTest(int repeat_) {
     CHECK(fn_) << "The LoweredFunc address should be set first by calling SetLoweredFunc method";
     auto& pod_args = PreparePodArgs();
-    cinn::utils::Timer timer;
-    for (int i = 0; i < 100; i++) {
-      fn_(pod_args.data(), pod_args.size());
-    }
-    timer.Start();
-    for (int i = 0; i < repeat_; i++) {
-      fn_(pod_args.data(), pod_args.size());
-    }
-    double test_op_time = timer.Stop() / repeat_;
-    LOG(INFO) << "Repeat times: [" << repeat_ << "], average op run time: [" << test_op_time << "] ms";
+    fn_(pod_args.data(), pod_args.size());
   }
 
   /**
diff --git a/cinn/hlir/framework/tensor.h b/cinn/hlir/framework/tensor.h
index 8c76d75f3a523..dd7c413df71ea 100644
--- a/cinn/hlir/framework/tensor.h
+++ b/cinn/hlir/framework/tensor.h
@@ -38,7 +38,7 @@ class _Tensor_ : public Object {
  public:
   _Tensor_() : buffer_(std::make_shared<Buffer>()) {}
 
-  const Shape& shape() const { return shape_; }
+  Shape& shape() { return shape_; }
 
   void Resize(const Shape& shape) {
     shape_ = shape;
diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc
index d3bd93c6ccba5..32706bacf6031 100644
--- a/cinn/hlir/op/nn.cc
+++ b/cinn/hlir/op/nn.cc
@@ -196,6 +196,8 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
     if (target.arch == Target::Arch::NVGPU) {
       Expr Out = arg_pack[2];
       CHECK(Out.as_tensor());
+      // pe::CudaScheduleConv(stages, input_pad.as_tensor_ref(), weights_dilation.as_tensor_ref(), Out.as_tensor_ref(),
+      // target);
       stages[Out.as_tensor_ref()]->Split(1, 2);
       stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
       stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
@@ -440,9 +442,7 @@ std::shared_ptr<OpStrategy> StrategyForBatchNorm(const framework::NodeAttr &attr
       Expr Out = arg_pack[0];
       poly::StageMap stages = arg_pack[1];
       CHECK(Out.as_tensor());
-      pe::CudaSplitSchedule(stages[Out.as_tensor_ref()], output_shapes.back());
-      stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
-      stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
+      pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target);
     }
     *ret = arg_pack;
   });
diff --git a/cinn/hlir/op/transform.cc b/cinn/hlir/op/transform.cc
old mode 100644
new mode 100755
index e90281a777a8f..a1486e8f66743
--- a/cinn/hlir/op/transform.cc
+++ b/cinn/hlir/op/transform.cc
@@ -3,6 +3,7 @@
 #include "cinn/hlir/framework/node.h"
 #include "cinn/hlir/framework/op.h"
 #include "cinn/hlir/framework/op_strategy.h"
+#include "cinn/hlir/pe/nn.h"
 #include "cinn/ir/ir_printer.h"
 
 namespace cinn {
 namespace hlir {
@@ -165,13 +166,14 @@ std::shared_ptr<OpStrategy> StrategyForMul(const framework::NodeAttr &attrs,
       }
     }
     new_xshape.push_back(check_dim);
-    new_yshape.push_back(check_dim);
+
     for (int i = 0; i < B_tensor->shape.size(); i++) {
-      if (i >= y_num_col_dims) {
+      if (i < y_num_col_dims) {
         output_shape.push_back(B_tensor->shape[i]);
         new_yshape.push_back(B_tensor->shape[i]);
       }
     }
+    new_yshape.push_back(check_dim);
     Var axis_k(check_dim, UniqName("axis_k"));
     auto new_A = A_tensor->Reshape(new_xshape, stages);
     auto new_B = B_tensor->Reshape(new_yshape, stages);
@@ -191,9 +193,7 @@ std::shared_ptr<OpStrategy> StrategyForMul(const framework::NodeAttr &attrs,
       Expr Out = arg_pack[0];
       poly::StageMap stages = arg_pack[1];
       CHECK(Out.as_tensor());
-      stages[Out.as_tensor_ref()]->Split(1, 2);
-      stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
-      stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
+      pe::CudaScheduleMul(stages, Out.as_tensor_ref(), output_shapes.back(), target);
     }
     *ret = arg_pack;
   });
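StrategyForMul flattens both operands to 2-D before the matmul: the first x_num_col_dims (resp. y_num_col_dims) axes become output axes, and everything after them is folded into the shared reduction axis check_dim. Roughly, in NumPy terms (illustrative sketch):

import numpy as np

def flatten_for_mul(x, num_col_dims):
    # leading axes stay; trailing axes fold into one reduction axis
    lead = x.shape[:num_col_dims]
    return x.reshape(int(np.prod(lead)), -1), lead

x = np.random.rand(2, 3, 4, 5)          # x_num_col_dims = 2
y = np.random.rand(6, 4, 5)             # y_num_col_dims = 1, stored transposed

x2, x_lead = flatten_for_mul(x, 2)      # (6, 20)
y2, y_lead = flatten_for_mul(y, 1)      # (6, 20)
out = (x2 @ y2.T).reshape(*x_lead, *y_lead)
assert out.shape == (2, 3, 6)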
@@ -204,9 +204,103 @@ std::shared_ptr<OpStrategy> StrategyForMul(const framework::NodeAttr &attrs,
   return strategy;
 }
 
+std::shared_ptr<OpStrategy> StrategyForMulBias(const framework::NodeAttr &attrs,
+                                               const std::vector<ir::Tensor> &inputs,
+                                               const std::vector<Type> &out_type,
+                                               const std::vector<std::vector<int>> &output_shapes,
+                                               const Target &target) {
+  framework::CINNCompute mul_compute([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input arguments of MulBias compute are empty! Please check.\n";
+    CINNValuePack a = args[0];
+    CHECK_GE(a.size(), 3U) << "at least 3 input tensors for MulBias compute\n";
+    Expr A = a[0];
+    Expr B = a[1];
+    Expr C = a[2];
+    CHECK(A.as_tensor());
+    CHECK(B.as_tensor());
+    CHECK(C.as_tensor());
+    auto attr_store = attrs.attr_store;
+    int x_num_col_dims = 1;
+    int y_num_col_dims = 1;
+    for (auto &iter : attrs.attr_store) {
+      if (iter.first == "x_num_col_dims") {
+        x_num_col_dims = std::get<int>(iter.second);
+      } else if (iter.first == "y_num_col_dims") {
+        y_num_col_dims = std::get<int>(iter.second);
+      } else {
+        LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl;
+      }
+    }
+    auto A_tensor = A.as_tensor_ref();
+    auto B_tensor = B.as_tensor_ref();
+    auto C_tensor = C.as_tensor_ref();
+    auto stages = CreateStages({A_tensor, B_tensor, C_tensor});
+    std::vector<Expr> output_shape;
+    std::vector<Expr> new_xshape;
+    std::vector<Expr> new_yshape;
+    Expr check_dim(1);
+    for (int i = 0; i < A_tensor->shape.size(); i++) {
+      if (i < x_num_col_dims) {
+        output_shape.push_back(A_tensor->shape[i]);
+        new_xshape.push_back(A_tensor->shape[i]);
+      } else {
+        check_dim = check_dim * A_tensor->shape[i];
+      }
+    }
+    new_xshape.push_back(check_dim);
+
+    for (int i = 0; i < B_tensor->shape.size(); i++) {
+      if (i < y_num_col_dims) {
+        output_shape.push_back(B_tensor->shape[i]);
+        new_yshape.push_back(B_tensor->shape[i]);
+      }
+    }
+    new_yshape.push_back(check_dim);
+    Var axis_k(check_dim, UniqName("axis_k"));
+    auto new_A = A_tensor->Reshape(new_xshape, stages);
+    auto new_B = B_tensor->Reshape(new_yshape, stages);
+
+    auto out = pe::MulBias(new_A, new_B, C_tensor, x_num_col_dims, output_shape, axis_k, UniqName("MulBias_output"));
+
+    std::vector<CINNValue> res;
+    for (auto &t : out) {
+      stages->InsertLazily(t);
+      res.push_back(CINNValue(t));
+    }
+    res.push_back(CINNValue(stages));
+    CHECK(!out_type.empty()) << "Output type of MulBias is empty! Please check.\n";
+    *ret = CINNValuePack{res};
+  });
+
+  framework::CINNSchedule mul_schedule([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input argument of mulbias schedule is empty! Please check.\n";
+    CINNValuePack arg_pack = args[0];
+    CHECK_EQ(arg_pack.size(), 3UL);
+    Expr Temp = arg_pack[0];
+    Expr Out = arg_pack[1];
+    poly::StageMap stages = arg_pack[2];
+    CHECK(Out.as_tensor());
+    CHECK(Temp.as_tensor());
+    if (target.arch == Target::Arch::NVGPU) {
+      pe::CudaScheduleMul(stages, Temp.as_tensor_ref(), output_shapes.back(), target);
+      pe::CudaScheduleMul(stages, Out.as_tensor_ref(), output_shapes.back(), target);
+      /* stages[Out.as_tensor_ref()]->Split(1, 2);
+         stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
+         stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x"); */
+      // pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target);
+    }
+    *ret = arg_pack;
+  });
+
+  auto strategy = std::make_shared<framework::OpStrategy>();
+  strategy->AddImpl(mul_compute, mul_schedule, "strategy.mulbias.x86", 1);
+
+  return strategy;
+}
+
Please check.\n"; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 3UL); + Expr Temp = arg_pack[0]; + Expr Out = arg_pack[1]; + poly::StageMap stages = arg_pack[2]; + CHECK(Out.as_tensor()); + CHECK(Temp.as_tensor()); + if (target.arch == Target::Arch::NVGPU) { + pe::CudaScheduleMul(stages, Temp.as_tensor_ref(), output_shapes.back(), target); + pe::CudaScheduleMul(stages, Out.as_tensor_ref(), output_shapes.back(), target); + /* stages[Out.as_tensor_ref()]->Split(1, 2); + stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x"); + stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x"); */ + // pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(),target); + } + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(mul_compute, mul_schedule, "strategy.mulbias.x86", 1); + + return strategy; +} + std::vector> InferShapeForMul(const std::vector> &inputs_shape, const framework::NodeAttr &attrs) { - CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; + // CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; CHECK_GE(inputs_shape[0].size(), 2U) << "Input matrix X's dim should be >= 2! Please check."; CHECK_GE(inputs_shape[1].size(), 2U) << "Input matrix Y's dim should be >= 2! Please check."; @@ -234,9 +328,9 @@ std::vector> InferShapeForMul(const std::vector InferDtypeForMul(const std::vector &inputs_type, const f return res; } +std::vector> InferShapeForMulBias(const std::vector> &inputs_shape, + const framework::NodeAttr &attrs) { + // CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; + CHECK_GE(inputs_shape[0].size(), 2U) << "Input matrix X's dim should be >= 2! Please check."; + CHECK_GE(inputs_shape[1].size(), 2U) << "Input matrix Y's dim should be >= 2! Please check."; + + std::vector output_shape; + int x_num_col_dims = 1; + int y_num_col_dims = 1; + for (auto &iter : attrs.attr_store) { + if (iter.first == "x_num_col_dims") { + x_num_col_dims = std::get(iter.second); + } else if (iter.first == "y_num_col_dims") { + y_num_col_dims = std::get(iter.second); + } else { + LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl; + } + } + int check_dim_x = 1; + int check_dim_y = 1; + for (int i = 0; i < inputs_shape[0].size(); i++) { + if (i < x_num_col_dims) { + output_shape.push_back(inputs_shape[0][i]); + } else { + check_dim_x = check_dim_x * inputs_shape[0][i]; + } + } + + for (int i = 0; i < inputs_shape[1].size(); i++) { + if (i < y_num_col_dims) { + output_shape.push_back(inputs_shape[1][i]); + } else { + check_dim_y = check_dim_y * inputs_shape[1][i]; + } + } + CHECK_EQ(check_dim_x, check_dim_y) << "For matrix multiply: X * Y, second dim of X's shape :[" << check_dim_x + << "] should be equal to first dim of Y's shape :[" << check_dim_y + << "]! Please Check!"; + + std::vector> res{output_shape, output_shape}; + return res; +} + +std::vector InferDtypeForMulBias(const std::vector &inputs_type, const framework::NodeAttr &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! 
Please check again."; + std::vector res{inputs_type[0], inputs_type[0]}; + return res; +} + } // namespace op } // namespace hlir } // namespace cinn @@ -278,5 +421,13 @@ CINN_REGISTER_HELPER(transform_ops) { .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForMul)) .set_support_level(4); + CINN_REGISTER_OP(mulbias) + .describe("This operator is used to perform matrix multiplication for input X and Y and add Z.") + .set_num_inputs(3) + .set_num_outputs(2) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForMulBias) + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForMulBias)) + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForMulBias)) + .set_support_level(4); return true; } diff --git a/cinn/hlir/pe/nn.cc b/cinn/hlir/pe/nn.cc index c6341c6b654f7..338d4f193de96 100755 --- a/cinn/hlir/pe/nn.cc +++ b/cinn/hlir/pe/nn.cc @@ -24,6 +24,31 @@ using ir::Min; using ir::Select; using ir::Tensor; +void CudaScheduleMul(poly::StageMap stages, + ir::Tensor output, + const std::vector &output_shape, + const common::Target &target) { + stages[output]->Split(1, 2); + stages[output]->Bind(0, "blockIdx.x"); + stages[output]->Bind(1, "threadIdx.x"); + + return; +} + +void CudaScheduleConv(poly::StageMap stages, + ir::Tensor input_pad, + ir::Tensor kernel_dilation, + ir::Tensor output, + const common::Target &target) { + int num_thread = target.max_num_threads(); + stages[output]->Fuse(0, 1); + auto [Block_x, Thread_x] = stages[output]->Split(0, num_thread); + stages[output]->Bind(0, "blockIdx.x"); + stages[output]->Bind(1, "threadIdx.x"); + + return; +} + void CudaScheduleInjective(poly::Stage *stage, const std::vector &output_shape, const common::Target &target) { CHECK_EQ(stage->n_out_dims(), stage->n_in_dims()) << "The dims of op are not equal"; int dims = stage->n_out_dims(); diff --git a/cinn/hlir/pe/nn.h b/cinn/hlir/pe/nn.h index afcf29d236b52..28b42e4740885 100644 --- a/cinn/hlir/pe/nn.h +++ b/cinn/hlir/pe/nn.h @@ -12,6 +12,17 @@ namespace cinn { namespace hlir { namespace pe { +void CudaScheduleMul(poly::StageMap stages, + ir::Tensor output, + const std::vector &output_shape, + const common::Target &target); + +void CudaScheduleConv(poly::StageMap stages, + ir::Tensor input_pad, + ir::Tensor kernel_dilation, + ir::Tensor output, + const common::Target &target); + void CudaScheduleInjective(poly::Stage *stage, const std::vector &output_shape, const common::Target &target); void CudaSplitSchedule(poly::Stage *stage, const std::vector &output_shape); diff --git a/cinn/hlir/pe/transform.cc b/cinn/hlir/pe/transform.cc old mode 100644 new mode 100755 index 23615ee788a9a..4a6ca572752d5 --- a/cinn/hlir/pe/transform.cc +++ b/cinn/hlir/pe/transform.cc @@ -3,6 +3,7 @@ #include #include "cinn/common/cas.h" +#include "cinn/common/context.h" #include "cinn/common/ir_util.h" #include "cinn/ir/tensor.h" #include "cinn/lang/builtin.h" @@ -132,15 +133,39 @@ Tensor Mul(const Tensor& A, [=](const std::vector& indice) { std::vector A_indice; std::vector B_indice; - B_indice.push_back(axis_k); A_indice.insert(A_indice.begin(), indice.begin(), indice.begin() + x_num_col_dims); - B_indice.insert(B_indice.begin() + 1, indice.begin() + x_num_col_dims, indice.end()); + B_indice.insert(B_indice.begin(), indice.begin() + x_num_col_dims, indice.end()); A_indice.push_back(axis_k); + B_indice.push_back(axis_k); return lang::ReduceSum(A(A_indice) * B(B_indice), {axis_k}); }, name); } +std::vector MulBias(const Tensor& A, + const Tensor& B, + const Tensor& C, + int x_num_col_dims, + 
+std::vector<ir::Tensor> MulBias(const ir::Tensor& A,
+                                const ir::Tensor& B,
+                                const ir::Tensor& C,
+                                int x_num_col_dims,
+                                const std::vector<Expr>& output_shape,
+                                const ir::Var& axis_k,
+                                const std::string& name);
+
 }  // namespace pe
 }  // namespace hlir
 }  // namespace cinn
diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc
old mode 100644
new mode 100755
index 2e43acf65a561..2a5e4cb98c855
--- a/cinn/pybind/frontend.cc
+++ b/cinn/pybind/frontend.cc
@@ -68,9 +68,11 @@ void BindFrontend(pybind11::module *m) {
       .def("__getitem__", [](Program &self, int idx) { return self[idx]; })
      .def("add", &Program::add)
      .def("mul", &Program::mul)
+      .def("mulbias", &Program::mulbias)
      .def("elementwise_add", &Program::elementwise_add)
      .def("relu", &Program::relu)
      .def("relu6", &Program::relu6)
+      .def("sigmoid", &Program::sigmoid)
      .def("scale", &Program::scale)
      .def("conv2d", &Program::conv2d)
      .def("batchnorm", &Program::batchnorm)
diff --git a/cinn/pybind/poly.cc b/cinn/pybind/poly.cc
old mode 100644
new mode 100755
index fcf7eec4a00b1..76950026ee484
--- a/cinn/pybind/poly.cc
+++ b/cinn/pybind/poly.cc
@@ -72,7 +72,7 @@ void BindStage(py::module *m) {
       .def("skew", &Stage::Skew)
      .def("ctrl_depend", &Stage::CtrlDepend)
      .def("cache_read", &Stage::CacheRead)
-      .def("cache_write", &Stage::CacheRead);
+      .def("cache_write", &Stage::CacheWrite);
 }
 
 void BindStageMap(py::module *m) {
diff --git a/cinn/runtime/cuda/cuda_util.cc b/cinn/runtime/cuda/cuda_util.cc
old mode 100644
new mode 100755
index 23c61ee624a0b..5fbda4d7c8ffd
--- a/cinn/runtime/cuda/cuda_util.cc
+++ b/cinn/runtime/cuda/cuda_util.cc
@@ -5,6 +5,7 @@
 #include "cinn/backends/cuda_util.h"
 #include "cinn/backends/extern_func_jit_register.h"
 #include "cinn/common/target.h"
+#include "cinn/utils/timer.h"
 
 namespace cinn {
 namespace runtime {
@@ -30,8 +31,6 @@ void cinn_call_cuda_kernel(void *kernel_fn,
       arr[i] = args[i].data_addr();
     }
   }
-  VLOG(3) << "[CUDA] LaunchKernel grid_xyz is: " << grid_x << "," << grid_y << "," << grid_z;
-  VLOG(3) << "[CUDA] LaunchKernel block_xyz is: " << block_x << "," << block_y << "," << block_z;
   CUDA_DRIVER_CALL(cuLaunchKernel(static_cast<CUfunction>(kernel_fn),
                                   grid_x,
                                   grid_y,
                                   grid_z,
                                   block_x,
                                   block_y,
                                   block_z,
                                   0,
                                   static_cast<CUstream>(stream),
                                   reinterpret_cast<void**>(arr),
                                   nullptr))
+  // CUDA_CALL(cudaDeviceSynchronize());
 }
 
 }  // namespace cuda
diff --git a/python/tests/test_op_benchmark.py b/python/tests/test_op_benchmark.py
index b29e713e0f4cf..669163ea61306 100755
--- a/python/tests/test_op_benchmark.py
+++ b/python/tests/test_op_benchmark.py
@@ -17,12 +17,9 @@
 class TestBenchmark(unittest.TestCase):
     def setUp(self):
-        if enable_gpu == "ON":
-            self.target = DefaultNVGPUTarget()
-        else:
-            self.target = DefaultHostTarget()
+        self.target = DefaultNVGPUTarget()
 
-    def test_conv2d(self):
+    def atest_conv2d(self):
         prog = Program()
         a = Variable("A").set_type(Float(32)).set_shape([2, 512, 7, 7])
         b = Variable("E").set_type(Float(32)).set_shape([512, 512, 3, 3])
@@ -39,7 +36,7 @@ def test_conv2d(self):
             self.target, [a, b], tensor_data, c, 200,
             "TESTING [conv2d] time cost with shape [2,512,7,7]...")
 
-    def test_softmax(self):
+    def atest_softmax(self):
         prog = Program()
         a = Variable("A").set_type(Float(32)).set_shape([1024, 2048])
         c = prog.softmax(a, {})
@@ -48,7 +45,7 @@ def test_softmax(self):
             self.target, [a], tensor_data, c, 200,
             "TESTING [softmax] time cost with shape [1024,2048]...")
 
-    def test_matmul(self):
+    def atest_matmul(self):
         prog = Program()
         a = Variable("A").set_type(Float(32)).set_shape([512, 512])
         b = Variable("B").set_type(Float(32)).set_shape([512, 512])
@@ -61,7 +58,93 @@ def test_matmul(self):
             self.target, [a, b], tensor_data, c, 200,
             "TESTING [matmul] time cost with shape [512,512]...")
 
-    def test_pool2d(self):
+    def test_matmul1(self):
+        prog = Program()
+        a = Variable("A").set_type(Float(32)).set_shape([128, 512])
+        b = Variable("B").set_type(Float(32)).set_shape([256, 512])
+        c = Variable("C").set_type(Float(32)).set_shape([128, 256])
+        d = prog.mulbias(a, b, c, 1, 1)
+        tensor_data = [
+            np.random.random([128, 512]).astype("float32"),
+            np.random.random([256, 512]).astype("float32"),
+            np.random.random([128, 256]).astype("float32")
+        ]
+        result = prog.test_benchmark(
+            self.target, [a, b, c], tensor_data, d, 200,
+            "TESTING [mulbias] time cost with shape [128,512]*[256,512]...")
+
+    def test_matmul2(self):
+        prog = Program()
+        a = Variable("A").set_type(Float(32)).set_shape([128, 512])
+        b = Variable("B").set_type(Float(32)).set_shape([256, 512])
+        c = Variable("C").set_type(Float(32)).set_shape([128, 256])
+        d = prog.mul(a, b, 1, 1)
+        e = prog.add(d, c)
+        tensor_data = [
+            np.random.random([128, 512]).astype("float32"),
+            np.random.random([256, 512]).astype("float32"),
+            np.random.random([128, 256]).astype("float32")
+        ]
+        result = prog.test_benchmark(
+            self.target, [a, b, c], tensor_data, e, 200,
+            "TESTING [mul and add] time cost with shape [128,512]*[256,512]..."
+        )
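test_matmul1 (fused mulbias) and test_matmul2 (mul followed by a separate add) are set up with identical shapes and both should produce X @ Y.T + C; the fused variant just avoids a second kernel launch for the add. The shared NumPy oracle (illustrative):

import numpy as np

a = np.random.rand(128, 512).astype("float32")
b = np.random.rand(256, 512).astype("float32")  # stored transposed: [N, K]
c = np.random.rand(128, 256).astype("float32")

expected = a @ b.T + c      # reference result for both benchmarks
assert expected.shape == (128, 256)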
+    def test_matmul(self):
+        prog = Program()
+        a = Variable("A").set_type(Float(32)).set_shape([512, 512])
+        b = Variable("B").set_type(Float(32)).set_shape([512, 512])
+        c = Variable("C").set_type(Float(32)).set_shape([512, 512])
+        d = prog.mul(a, b, 1, 1)
+        # e = prog.add(d, c)
+        tensor_data = [
+            np.random.random([512, 512]).astype("float32"),
+            np.random.random([512, 512]).astype("float32")
+        ]
+        result = prog.test_benchmark_with_code(
+            self.target, [a, b], tensor_data, d, 200,
+            "TESTING [matmul] time cost with shape [512,512]...", '''
+extern "C" {
+#include "cinn_cuda_runtime_source.cuh"
+#ifdef __CUDACC_RTC__
+typedef int int32_t;
+typedef char int8_t;
+#endif
+
+__global__
+void fn_mul_0_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ Mul_output)
+{
+  const float* A_reshape = A;
+  const float* B_reshape = B;
+  float* Mul_output__reduce_init = Mul_output;
+  if ((blockIdx.x < 512)) {
+    if ((threadIdx.x < 256)) {
+      for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
+        Mul_output__reduce_init[((512 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] = 0;
+      };
+    };
+  };
+  if ((blockIdx.x < 512)) {
+    if ((threadIdx.x < 256)) {
+      for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
+        for (int32_t axis_k = 0; axis_k < 512; axis_k += 1) {
+          Mul_output[((512 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] = (Mul_output[((512 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] + (A_reshape[((512 * blockIdx.x) + axis_k)] * B_reshape[((512 * axis_k) + ((2 * threadIdx.x) + j_inner))]));
+        };
+      };
+    };
+  };
+}
+}''')
+
+    def atest_pool2d(self):
         prog = Program()
         a = Variable("A").set_type(Float(32)).set_shape([2, 64, 112, 112])
         c = prog.pool2d(
@@ -75,20 +158,46 @@ def test_pool2d(self):
             self.target, [a], tensor_data, c, 200,
             "TESTING [pool2d] time cost with shape [2, 64, 112, 112]...")
 
-    def test_elementwise1(self):
+    def atest_elementwise1(self):
         prog = Program()
-        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 7, 7])
-        b = Variable("B").set_type(Float(32)).set_shape([2, 512, 7, 7])
+        a = Variable("A").set_type(Float(32)).set_shape([64, 64])
+        b = Variable("B").set_type(Float(32)).set_shape([64, 64])
         c = prog.add(a, b)
         tensor_data = [
-            np.random.random([2, 512, 7, 7]).astype("float32"),
-            np.random.random([2, 512, 7, 7]).astype("float32")
+            np.random.random([64, 64]).astype("float32"),
+            np.random.random([64, 64]).astype("float32")
         ]
         result = prog.test_benchmark(
             self.target, [a, b], tensor_data, c, 200,
-            "TESTING [elementwise_add] time cost with shape [2,512,7,7]...")
+            "TESTING [elementwise_add] time cost with shape [64, 64]...")
+        result = result.numpy(self.target).reshape(-1)
+        self.assertTrue(
+            np.allclose(
+                (tensor_data[0] + tensor_data[1]).reshape(-1),
+                result,
+                atol=1e-4))
 
     def test_elementwise2(self):
+        prog = Program()
+        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 112, 112])
+        b = Variable("B").set_type(Float(32)).set_shape([2, 512, 112, 112])
+        c = prog.add(a, b)
+        tensor_data = [
+            np.random.random([2, 512, 112, 112]).astype("float32"),
+            np.random.random([2, 512, 112, 112]).astype("float32")
+        ]
+        result = prog.test_benchmark(
+            self.target, [a, b], tensor_data, c, 200,
+            "TESTING [elementwise_add] time cost with shape [2, 512, 112, 112]..."
+        )
+        result = result.numpy(self.target).reshape(-1)
+        self.assertTrue(
+            np.allclose(
+                (tensor_data[0] + tensor_data[1]).reshape(-1),
+                result,
+                atol=1e-4))
+
+    def atest_elementwise2(self):
         prog = Program()
         a = Variable("A").set_type(Float(32)).set_shape([4, 1024])
         b = Variable("B").set_type(Float(32)).set_shape([4, 1024])
         c = prog.add(a, b)
         tensor_data = [
             np.random.random([4, 1024]).astype("float32"),
             np.random.random([4, 1024]).astype("float32")
         ]
@@ -99,52 +208,82 @@ def test_elementwise2(self):
         result = prog.test_benchmark_with_code(
             self.target, [a, b], tensor_data, c, 200,
-            "TESTING [elementwise_add] time cost with shape [4,1024]...",
+            "TESTING [elementwise_add] time cost with input code...",
             '''extern "C" {
-#include "cinn_cuda_runtime_source.cuh"
-
-#ifdef __CUDACC_RTC__
-typedef int int32_t;
-typedef char int8_t;
-#endif
-
 __global__
 void fn_elementwise_add_0_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ EleAdd_Out_0)
 {
-  EleAdd_Out_0[((1024 * ((int)blockIdx.x)) + ((int)threadIdx.x))] = (A[((1024 * ((int)blockIdx.x)) + ((int)threadIdx.x))] + B[((1024 * ((int)blockIdx.x)) + ((int)threadIdx.x))]);
+  EleAdd_Out_0[1024 * blockIdx.x + threadIdx.x] = (A[1024 * blockIdx.x + threadIdx.x] + B[1024 * blockIdx.x + threadIdx.x]);
 }
 }''')
 
     def test_batchnorm(self):
         prog = Program()
-        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 7, 7])
+        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 32, 32])
         b = Variable("B").set_type(Float(32)).set_shape([512])
         c = Variable("C").set_type(Float(32)).set_shape([512])
         d = Variable("D").set_type(Float(32)).set_shape([512])
         e = Variable("E").set_type(Float(32)).set_shape([512])
         f = prog.batchnorm(a, b, c, d, e, {})
         tensor_data = [
-            np.random.random([2, 512, 7, 7]).astype("float32"),
+            np.random.random([2, 512, 32, 32]).astype("float32"),
             np.random.random([512]).astype("float32"),
             np.random.random([512]).astype("float32"),
             np.random.random([512]).astype("float32"),
             np.random.random([512]).astype("float32")
         ]
+        result = prog.test_benchmark(
+            self.target, [a, b, c, d, e], tensor_data, f, 1000,
+            "TESTING [batchnorm] time cost with shape [2, 512, 32, 32]...")
+
+    def atest_batchnorm2(self):
+        prog = Program()
+        a = Variable("A").set_type(Float(32)).set_shape([2, 64, 8, 8])
+        b = Variable("B").set_type(Float(32)).set_shape([64])
+        c = Variable("C").set_type(Float(32)).set_shape([64])
+        d = Variable("D").set_type(Float(32)).set_shape([64])
+        e = Variable("E").set_type(Float(32)).set_shape([64])
+        f = prog.batchnorm(a, b, c, d, e, {})
+        tensor_data = [
+            np.random.random([2, 64, 8, 8]).astype("float32"),
+            np.random.random([64]).astype("float32"),
+            np.random.random([64]).astype("float32"),
+            np.random.random([64]).astype("float32"),
+            np.random.random([64]).astype("float32")
+        ]
         result = prog.test_benchmark(
             self.target, [a, b, c, d, e], tensor_data, f, 200,
-            "TESTING [batchnorm] time cost with shape [2, 512, 7, 7]...")
+            "TESTING [batchnorm] time cost with shape [2, 64, 8, 8]...")
+
+    def test_relu3(self):
+        prog = Program()
+        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 112, 112])
+        c = prog.relu(a)
+        tensor_data = [np.random.random([2, 512, 112, 112]).astype("float32")]
+        result = prog.test_benchmark(
+            self.target, [a], tensor_data, c, 200,
+            "TESTING [relu] time cost with shape [2,512,112,112]...")
 
     def test_relu(self):
         prog = Program()
         a = Variable("A").set_type(Float(32)).set_shape([64, 64])
-        c = prog.relu(a)
+        c = prog.sigmoid(a)
         tensor_data = [np.random.random([64, 64]).astype("float32")]
         result = prog.test_benchmark(
             self.target, [a], tensor_data, c, 200,
-            "TESTING [relu] time cost with shape [64,64]...")
+            "TESTING [sigmoid] time cost with shape [64,64]...")
[64,64]...") + "TESTING [sigmoid] time cost with shape [64,64]...") + + def test_relu2(self): + prog = Program() + a = Variable("A").set_type(Float(32)).set_shape([2, 512, 112, 112]) + c = prog.sigmoid(a) + tensor_data = [np.random.random([2, 512, 112, 112]).astype("float32")] + result = prog.test_benchmark( + self.target, [a], tensor_data, c, 200, + "TESTING [sigmoid] time cost with shape [2,512,112,112]...") if __name__ == "__main__": diff --git a/tools/tvm_benchmark/tvm_graph_with_single_op.py b/tools/tvm_benchmark/tvm_graph_with_single_op.py index c2239c449f91f..b09dd842a6afc 100755 --- a/tools/tvm_benchmark/tvm_graph_with_single_op.py +++ b/tools/tvm_benchmark/tvm_graph_with_single_op.py @@ -33,9 +33,99 @@ def get_network_conv2d(): return mod, params, input_shape, output_shape, input_names +def get_network_conv2d_resnet1(): + input_shape = [(2, 3, 224, 224), (64, 3, 7, 7)] + output_shape = (2, 64, 112, 112) + input_names = ["x", "y"] + x = relay.Var(input_names[0], tvm.relay.TensorType(input_shape[0])) + y = relay.Var(input_names[1], tvm.relay.TensorType(input_shape[1])) + print("[Test]Begin building graph with op relay.nn.conv2d resnet1") + mod = relay.Function([x, y], + relay.nn.conv2d( + x, + y, + kernel_size=(7, 7), + padding=(3, 3), + strides=(2, 2))) + params = [] + return mod, params, input_shape, output_shape, input_names + + +def get_network_conv2d_resnet2(): + input_shape = [(2, 64, 56, 56), (64, 64, 3, 3)] + output_shape = (2, 64, 56, 56) + input_names = ["x", "y"] + x = relay.Var(input_names[0], tvm.relay.TensorType(input_shape[0])) + y = relay.Var(input_names[1], tvm.relay.TensorType(input_shape[1])) + print("[Test]Begin building graph with op relay.nn.conv2d resnet2") + mod = relay.Function([x, y], + relay.nn.conv2d( + x, + y, + kernel_size=(3, 3), + padding=(1, 1), + strides=(1, 1))) + params = [] + return mod, params, input_shape, output_shape, input_names + + +def get_network_conv2d_resnet3(): + input_shape = [(2, 64, 56, 56), (64, 64, 1, 1)] + output_shape = (2, 64, 56, 56) + input_names = ["x", "y"] + x = relay.Var(input_names[0], tvm.relay.TensorType(input_shape[0])) + y = relay.Var(input_names[1], tvm.relay.TensorType(input_shape[1])) + print("[Test]Begin building graph with op relay.nn.conv2d resnet2") + mod = relay.Function([x, y], + relay.nn.conv2d( + x, + y, + kernel_size=(1, 1), + padding=(0, 0), + strides=(1, 1))) + params = [] + return mod, params, input_shape, output_shape, input_names + + +def get_network_conv2d_resnet4(): + input_shape = [(2, 64, 56, 56), (128, 64, 1, 1)] + output_shape = (2, 128, 28, 28) + input_names = ["x", "y"] + x = relay.Var(input_names[0], tvm.relay.TensorType(input_shape[0])) + y = relay.Var(input_names[1], tvm.relay.TensorType(input_shape[1])) + print("[Test]Begin building graph with op relay.nn.conv2d resnet2") + mod = relay.Function([x, y], + relay.nn.conv2d( + x, + y, + kernel_size=(1, 1), + padding=(0, 0), + strides=(2, 2))) + params = [] + return mod, params, input_shape, output_shape, input_names + + +def get_network_conv2d_resnet5(): + input_shape = [(2, 128, 28, 28), (256, 128, 3, 3)] + output_shape = (2, 256, 14, 14) + input_names = ["x", "y"] + x = relay.Var(input_names[0], tvm.relay.TensorType(input_shape[0])) + y = relay.Var(input_names[1], tvm.relay.TensorType(input_shape[1])) + print("[Test]Begin building graph with op relay.nn.conv2d resnet2") + mod = relay.Function([x, y], + relay.nn.conv2d( + x, + y, + kernel_size=(3, 3), + padding=(1, 1), + strides=(2, 2))) + params = [] + return mod, params, input_shape, 
+
+
 def get_network_relu():
-    input_shape = [(1024, 7)]
-    output_shape = (1024, 7)
+    input_shape = [(2, 512, 112, 112)]
+    output_shape = (2, 512, 112, 112)
     input_names = ["x"]
     x = relay.Var(input_names[0], tvm.relay.TensorType(input_shape[0]))
     print("[Test]Begin building graph with op relay.nn.relu")
@@ -45,8 +135,8 @@ def get_network_relu():
 
 def get_network_elementwise():
-    input_shape = [(4, 1024), (4, 1024)]
-    output_shape = (4, 1024)
+    input_shape = [(64, 64), (64, 64)]
+    output_shape = (64, 64)
     input_names = ["x", "y"]
     x = relay.Var(input_names[0], tvm.relay.TensorType(input_shape[0]))
     y = relay.Var(input_names[1], tvm.relay.TensorType(input_shape[1]))
@@ -96,14 +186,14 @@ def get_network_pool2d():
 
 def get_network_batchnorm():
-    data0 = relay.var("data0", relay.TensorType((2, 512, 7, 7), "float32"))
+    data0 = relay.var("data0", relay.TensorType((2, 512, 32, 32), "float32"))
     bn_gamma = relay.var("bn_gamma1", relay.TensorType((512, ), "float32"))
     bn_beta = relay.var("bn_beta1", relay.TensorType((512, ), "float32"))
     bn_mmean = relay.var("bn_mean1", relay.TensorType((512, ), "float32"))
     bn_mvar = relay.var("bn_var1", relay.TensorType((512, ), "float32"))
     bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
-    input_shape = [(2, 512, 7, 7), (512), (512), (512), (512)]
-    output_shape = (2, 512, 7, 7)
+    input_shape = [(2, 512, 32, 32), (512), (512), (512), (512)]
+    output_shape = (2, 512, 32, 32)
     input_names = ["data0", "bn_gamma1", "bn_beta1", "bn_mean1", "bn_var1"]
     print("[Test]Begin building graph with op relay.nn.batch_norm")
     mod = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn)
@@ -136,9 +226,8 @@ def tune_and_evaluate(func):
         module.set_input(input_names[index], data_temp)
     # evaluate
     evaluator_preheat = module.module.time_evaluator(
-        "run", ctx, number=50, repeat=50)
-    evaluator = module.module.time_evaluator(
-        "run", ctx, number=500, repeat=100)
+        "run", ctx, number=10, repeat=10)
+    evaluator = module.module.time_evaluator("run", ctx, number=100, repeat=10)
 
     prof_res1 = np.array(
         evaluator_preheat().results) * 1000  # convert to millisecond
@@ -150,10 +239,15 @@ def tune_and_evaluate(func):
           (np.mean(prof_res2), np.std(prof_res2)))
 
 
-#tune_and_evaluate(get_network_conv2d)
 #tune_and_evaluate(get_network_pool2d)
 #tune_and_evaluate(get_network_softmax)
 #tune_and_evaluate(get_network_matmul)
 #tune_and_evaluate(get_network_batchnorm)
 tune_and_evaluate(get_network_relu)
-tune_and_evaluate(get_network_elementwise)
+#tune_and_evaluate(get_network_elementwise)
+#tune_and_evaluate(get_network_conv2d_resnet1)
+#tune_and_evaluate(get_network_conv2d_resnet2)
+#tune_and_evaluate(get_network_conv2d_resnet3)
+#tune_and_evaluate(get_network_conv2d_resnet4)
+#tune_and_evaluate(get_network_conv2d_resnet5)
+#tune_and_evaluate(get_network_conv2d)