Add mulbias and change mul's behavior (PaddlePaddle#295)
haozech authored Dec 1, 2020
1 parent 92f72ad commit ad6e41c
Showing 21 changed files with 635 additions and 78 deletions.
Empty file modified cinn/backends/raw_cuda_code_test.cu
100644 → 100755
Empty file.
84 changes: 80 additions & 4 deletions cinn/frontend/paddle_model_to_program.cc
@@ -9,6 +9,28 @@ namespace frontend {
using utils::Join;
using utils::TransValidVarName;

void MoveData(float* data, int i, int M, int N) {
  float temp = data[i];
  int cur    = i;  // current data index
  int pre    = (cur % M) * N + cur / M;
  while (pre != i) {
    data[cur] = data[pre];
    cur       = pre;
    pre       = (cur % M) * N + cur / M;
  }
  data[cur] = temp;
}

void TransposeData(float* data, int M, int N) {
  for (int i = 0; i < M * N; i++) {
    int next = (i % N) * M + i / N;
    // Follow the cycle; landing on next < i means this ring was already processed.
    while (next > i) next = (next % N) * M + next / N;
    // i is the smallest index of its ring, so process the ring starting here.
    if (next == i) MoveData(data, i, M, N);
  }
}
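
An illustrative sketch (not from the commit; it assumes the MoveData/TransposeData helpers above are in scope): TransposeData walks each permutation cycle of the transpose in place, so a row-major M x N buffer ends up holding its N x M transpose.

// Illustrative only: exercising TransposeData on a 2x3 matrix.
#include <cstdio>
// ... MoveData and TransposeData as defined above ...

int main() {
  float w[6] = {1, 2, 3,   // row-major 2x3:  [1 2 3]
                4, 5, 6};  //                 [4 5 6]
  TransposeData(w, 2, 3);  // buffer now holds the row-major 3x2 transpose
  for (int i = 0; i < 6; i++) printf("%g ", w[i]);  // prints: 1 4 2 5 3 6
  return 0;
}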

void PaddleModelToProgram::AddOpMapper_feed() {
op_mappers_["feed"] = [&](const paddle::cpp::OpDesc& op_desc) {
auto outs = op_desc.Output("Out");
@@ -65,11 +87,13 @@ void PaddleModelToProgram::AddOpMapper_mul() {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Input("Y").size(), 1UL);
auto y_name = op_desc.Input("Y").front();
auto x = GetVar(utils::TransValidVarName(x_name));
auto y_name = op_desc.Input("Y").front();
auto x = GetVar(utils::TransValidVarName(x_name));
TransposeVar(TransValidVarName(y_name));
auto y = GetVar(utils::TransValidVarName(y_name));
int x_num_col_dims = op_desc.GetAttr<int>("x_num_col_dims");
int y_num_col_dims = op_desc.GetAttr<int>("y_num_col_dims");
CHECK_EQ(y_num_col_dims, 1) << "The y_num_col_dims of mul is not 1! Please check.";
VLOG(4) << "Mul x_num_col_dims: " << x_num_col_dims;
VLOG(4) << "Mul y_num_col_dims: " << y_num_col_dims;
VLOG(4) << "x shape: " << utils::Join(x->shape, ",");
@@ -364,6 +388,56 @@ void PaddleModelToProgram::AddOp(const paddle::cpp::OpDesc& op_desc) {
LOG(FATAL) << "Not supported op [" << op_desc.Type() << "] found";
}

void PaddleModelToProgram::TransposeVar(const std::string& name) {
  CheckVarNameValid(name);
  auto* var = scope_->FindVar(name);
  if (var) {
    auto& tensor = std::get<hlir::framework::Tensor>(*var);
    if (target_.arch == Target::Arch::X86) {
      float* data = tensor->mutable_data<float>(target_);
      CHECK(tensor->shape().size() == 2) << "The y data's shape size of op [mul] is not equal to 2! Please check.";
      TransposeData(data, tensor->shape().data()[0], tensor->shape().data()[1]);
    } else if (target_.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
      std::vector<float> data(tensor->shape().numel());
      CUDA_CALL(cudaMemcpy(data.data(),
                           reinterpret_cast<void*>(tensor->mutable_data<float>(target_)),
                           tensor->shape().numel() * sizeof(float),
                           cudaMemcpyDeviceToHost));
#else
      LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
#endif
      CHECK(tensor->shape().size() == 2) << "The y data's shape size of op [mul] is not equal to 2! Please check.";

      TransposeData(data.data(), tensor->shape().data()[0], tensor->shape().data()[1]);

#ifdef CINN_WITH_CUDA
      CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(tensor->mutable_data<float>(target_)),
                           data.data(),
                           tensor->shape().numel() * sizeof(float),
                           cudaMemcpyHostToDevice));
#else
      LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
#endif

    } else {
      CINN_NOT_IMPLEMENTED
    }

    Variable var;
    var.set_id(name);
    std::vector<int> reverse_shape = tensor->shape().data();
    std::reverse(reverse_shape.begin(), reverse_shape.end());
    tensor->shape().SetData(reverse_shape);
    var->shape = tensor->shape().data();
    // TODO(Superjomn) Make this determined by model.
    var->type = Float(32);
    AddVar(name, var, true);
  } else {
    LOG(FATAL) << "No var called [" << name << "] exists";
  }
}
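
In effect, TransposeVar transposes the 2-D weight buffer in place and reverses its recorded shape, so a weight the model stores as {K, N} is afterwards handed to mul as {N, K}. A minimal host-side sketch of that effect (illustrative shapes and values; reuses TransposeData from above):

#include <algorithm>
#include <cstdio>
#include <vector>
// ... TransposeData as defined above ...

int main() {
  std::vector<int> shape  = {3, 2};                // original weight: K=3, N=2
  std::vector<float> data = {1, 2, 3, 4, 5, 6};    // row-major 3x2
  TransposeData(data.data(), shape[0], shape[1]);  // transpose the buffer in place
  std::reverse(shape.begin(), shape.end());        // shape becomes {2, 3}, i.e. N x K
  printf("shape = {%d, %d}; data =", shape[0], shape[1]);
  for (float v : data) printf(" %g", v);           // data = 1 3 5 2 4 6
  return 0;
}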

Variable PaddleModelToProgram::GetVar(const std::string& name) {
CheckVarNameValid(name);

@@ -399,9 +473,11 @@ std::unique_ptr<Program> PaddleModelToProgram::operator()(const std::string& mod
return std::move(program_);
}

void PaddleModelToProgram::AddVar(const std::string& name, const Variable& var) {
void PaddleModelToProgram::AddVar(const std::string& name, const Variable& var, bool replace) {
  CheckVarNameValid(name);
  CHECK(!var_map_.count(name)) << "Duplicate variable [" << name << "] found";
  if (replace == false) {
    CHECK(!var_map_.count(name)) << "Duplicate variable [" << name << "] found";
  }
  var_map_[name] = var;
}

5 changes: 4 additions & 1 deletion cinn/frontend/paddle_model_to_program.h
@@ -9,6 +9,7 @@
#include <variant>
#include <vector>

#include "cinn/backends/cuda_util.h"
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/object.h"
@@ -73,10 +74,12 @@ class PaddleModelToProgram {
const std::unordered_map<std::string, std::string>& var_model_to_program_map() { return var_model_to_program_map_; }

protected:
void AddVar(const std::string& name, const Variable& var);
void AddVar(const std::string& name, const Variable& var, bool replace = false);

Variable GetVar(const std::string& name);

void TransposeVar(const std::string& name);

private:
std::unordered_map<std::string, std::function<void(const paddle::cpp::OpDesc&)>> op_mappers_;
std::unique_ptr<Program> program_;
9 changes: 9 additions & 0 deletions cinn/frontend/syntax.cc
100644 → 100755
@@ -210,6 +210,15 @@ Variable Program::mul(const Variable& a, const Variable& b, int x_num_col_dims,
return instr.GetOutput(0);
}

Variable Program::mulbias(
    const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims, int y_num_col_dims) {
  Instruction instr("mulbias", {a, b, c});
  instr.SetAttr("x_num_col_dims", x_num_col_dims);
  instr.SetAttr("y_num_col_dims", y_num_col_dims);
  AppendInstruction(instr);
  return instr.GetOutput(1);
}
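
A hedged usage sketch (names and shapes mirror the fc test in syntax_test.cc further down; the exact semantics of the mulbias op are defined by its kernel elsewhere in this commit): with the weight already stored as {N, K}, a fully connected layer can be expressed as a single mulbias instead of mul followed by an elementwise add.

// Illustrative fragment, not a complete test.
Placeholder a(Float(32), {B, M, K}, "A");  // input
Placeholder w(Float(32), {N, K}, "W");     // weight, kept transposed
Placeholder b(Float(32), {N}, "B");        // bias

Program program;
// x_num_col_dims and y_num_col_dims default to 1; mulbias returns its
// second output (GetOutput(1)), presumably the biased result.
auto out = program.mulbias(a, w, b);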

std::string _Instruction_::debug_string() const {
struct Visit {
std::stringstream& s_;
6 changes: 6 additions & 0 deletions cinn/frontend/syntax.h
@@ -172,6 +172,12 @@ struct Program {
*/
Variable mul(const Variable& a, const Variable& b, int x_num_col_dims = 1, int y_num_col_dims = 1);

  /**
   * Multiply two matrices and add a bias.
   */
  Variable mulbias(
      const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims = 1, int y_num_col_dims = 1);

/**
* Add two tensors element-wise.
*/
2 changes: 1 addition & 1 deletion cinn/frontend/syntax_test.cc
100644 → 100755
@@ -104,7 +104,7 @@ TEST(syntax, program_execute_fc) {
const int N = 24;

Placeholder a(Float(32), {B, M, K}, "A");
Placeholder w(Float(32), {K, N}, "W"); // weight
Placeholder w(Float(32), {N, K}, "W"); // weight
Placeholder b(Float(32), {N}, "B"); // bias

Program program;
4 changes: 2 additions & 2 deletions cinn/hlir/framework/cuda_graph_compiler_test.cc
@@ -43,7 +43,7 @@ std::vector<float> test_mul(const std::vector<float>& A, const std::vector<float
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
for (int k = 0; k < K; k++) {
C_target[i * N + j] += A[i * K + k] * B[k * N + j];
C_target[i * N + j] += A[i * K + k] * B[j * N + k];
}
}
}
@@ -89,7 +89,7 @@ TEST(GraphCompiler, RunModel) {
frontend::Variable b("B");
Type t = Float(32);
a->shape = {M.as_int32(), K.as_int32()};
b->shape = {K.as_int32(), N.as_int32()};
b->shape = {N.as_int32(), K.as_int32()};
a->type = t;
b->type = t;
auto c = prog.mul(a, b);
22 changes: 19 additions & 3 deletions cinn/hlir/framework/graph_compiler.h
100644 → 100755
@@ -6,13 +6,15 @@
#include <vector>

#include "cinn/backends/compiler.h"
#include "cinn/backends/cuda_util.h"
#include "cinn/common/macros.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/instruction.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/framework/scope.h"
#include "cinn/ir/lowered_func.h"
#include "cinn/lang/packed_func.h"
#include "cinn/utils/timer.h"

namespace cinn {
namespace hlir {
@@ -47,14 +49,28 @@ class Program {
VLOG(3) << out << " ";
}
ins->Run();
CUDA_CALL(cudaDeviceSynchronize());
}
}

void ExecuteTest(int repeat_) {
CHECK_EQ(instrs_.size(), 1);
for (auto& ins : instrs_) {
ins->RunTest(repeat_);
cinn::utils::Timer timer1;
for (int i = 0; i < 100; i++) {
for (auto& ins : instrs_) {
ins->RunTest(repeat_);
}
}
timer1.Start();
for (int i = 0; i < repeat_; i++) {
for (auto& ins : instrs_) {
ins->RunTest(repeat_);
}
}

CUDA_CALL(cudaDeviceSynchronize());
double test_op_time = timer1.Stop() / repeat_;

LOG(INFO) << "Repeat times: [" << repeat_ << "], average op time: [" << test_op_time << "] ms";
}
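
The op timing that used to live in Instruction::RunTest (see the instruction.h hunk below) now follows the usual warm-up-then-measure pattern here. The same idea as a standalone helper, for reference (an illustrative sketch, not CINN API):

#include <chrono>
#include <functional>

// Run `warmup` untimed iterations, then return the average wall time of `repeat` runs in ms.
double AverageRunMs(const std::function<void()>& run, int repeat, int warmup = 100) {
  for (int i = 0; i < warmup; ++i) run();
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < repeat; ++i) run();
  auto end = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(end - start).count() / repeat;
}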
/**
* Get the number of instructions.
12 changes: 2 additions & 10 deletions cinn/hlir/framework/instruction.h
100644 → 100755
@@ -3,6 +3,7 @@
#include <string>
#include <vector>

#include "cinn/backends/cuda_util.h"
#include "cinn/common/test_helper.h"
#include "cinn/hlir/framework/scope.h"
#include "cinn/utils/timer.h"
@@ -43,16 +44,7 @@ class Instruction {
void RunTest(int repeat_) {
CHECK(fn_) << "The LoweredFunc address should be set first by calling SetLoweredFunc method";
auto& pod_args = PreparePodArgs();
cinn::utils::Timer timer;
for (int i = 0; i < 100; i++) {
fn_(pod_args.data(), pod_args.size());
}
timer.Start();
for (int i = 0; i < repeat_; i++) {
fn_(pod_args.data(), pod_args.size());
}
double test_op_time = timer.Stop() / repeat_;
LOG(INFO) << "Repeat times: [" << repeat_ << "], average op run time: [" << test_op_time << "] ms";
fn_(pod_args.data(), pod_args.size());
}

/**
2 changes: 1 addition & 1 deletion cinn/hlir/framework/tensor.h
@@ -38,7 +38,7 @@ class _Tensor_ : public Object {
public:
_Tensor_() : buffer_(std::make_shared<Buffer>()) {}

const Shape& shape() const { return shape_; }
Shape& shape() { return shape_; }

void Resize(const Shape& shape) {
shape_ = shape;
6 changes: 3 additions & 3 deletions cinn/hlir/op/nn.cc
@@ -196,6 +196,8 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
if (target.arch == Target::Arch::NVGPU) {
Expr Out = arg_pack[2];
CHECK(Out.as_tensor());
// pe::CudaScheduleConv(stages, input_pad.as_tensor_ref(), weights_dilation.as_tensor_ref(), Out.as_tensor_ref(),
// target);
stages[Out.as_tensor_ref()]->Split(1, 2);
stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
@@ -440,9 +442,7 @@ std::shared_ptr<OpStrategy> StrategyForBatchNorm(const framework::NodeAttr &attr
Expr Out = arg_pack[0];
poly::StageMap stages = arg_pack[1];
CHECK(Out.as_tensor());
pe::CudaSplitSchedule(stages[Out.as_tensor_ref()], output_shapes.back());
stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target);
}
*ret = arg_pack;
});