cinn(backends): generate infer shape kernel to infer shape of output tensor #60519

Merged · 8 commits · Jan 4, 2024
10 changes: 10 additions & 0 deletions paddle/cinn/backends/codegen_cuda_host.cc
@@ -198,13 +198,23 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
[](auto& arg) { return std::addressof(arg); });
// @}

// Set local scope table
CHECK_EQ(ll_function_args.size(), func->args.size());
for (int i = 0; i < ll_function_args.size(); ++i) {
SetVar(func->args[i].name(), ll_function_args[i]);
}
llvm::BasicBlock* entry = llvm::BasicBlock::Create(
/*Context=*/b_->getContext(),
/*Name=*/"entry",
/*Parent=*/f_,
/*InsertBefore=*/nullptr);
b_->SetInsertPoint(entry);
CodeGenLLVM::Visit(&func->body);

// Reset local scope table
for (const ir::Argument& func_arg : func->args) {
symbol_table_->Erase(func_arg.name());
}
RetVoid();

return f_;
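The hunk above brackets codegen of the host-function body with symbol-table updates: every argument is registered under its IR name before `Visit`, and erased afterwards so names cannot leak into the next lowered function. A minimal sketch of that pattern as an RAII guard, using a hypothetical stand-in `SymbolTable` rather than CINN's actual class:

```cpp
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Stand-in for the codegen symbol table's Set/Erase operations.
struct SymbolTable {
  std::unordered_map<std::string, void*> vars;
  void Set(const std::string& name, void* value) { vars[name] = value; }
  void Erase(const std::string& name) { vars.erase(name); }
};

// Registers names on construction and erases them on destruction, mirroring
// the "SetVar before Visit, Erase after Visit" sequence in LowerHostFunc.
class ScopedSymbols {
 public:
  ScopedSymbols(SymbolTable* table,
                std::vector<std::pair<std::string, void*>> args)
      : table_(table), args_(std::move(args)) {
    for (const auto& a : args_) table_->Set(a.first, a.second);
  }
  ~ScopedSymbols() {
    for (const auto& a : args_) table_->Erase(a.first);
  }

 private:
  SymbolTable* table_;
  std::vector<std::pair<std::string, void*>> args_;
};
```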
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_host.h
@@ -53,7 +53,7 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
} else if (op->name == runtime::intrinsic::call_cuda_kernel) {
return LowerCUDAKernelCall(op);
} else {
CINN_NOT_IMPLEMENTED;
return CodeGenLLVM::Visit(op);
}
}

24 changes: 23 additions & 1 deletion paddle/cinn/backends/codegen_cuda_util.h
@@ -31,6 +31,7 @@ namespace backends {
#define KERNEL_ARGS "kernel_args"
#define KERNEL_ARGS_NUM "kernel_args_num"
#define KERNEL_STREAM "kernel_stream"
#define TENSOR_SHAPE_ARGS "tensor_shape_args"

/**
* Split a CINN Module into two separate modules, one contains the host
@@ -150,7 +151,8 @@ struct CollectBucketStrategyHostFunctionVisitor
: CollectHostFunctionVisitor(module_name),
kernel_args_(KERNEL_ARGS, type_of<void*>()),
kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
kernel_stream_(KERNEL_STREAM, type_of<void*>()) {}
kernel_stream_(KERNEL_STREAM, type_of<void*>()),
tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t**>()) {}

std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
ir::IRMutator<>::Visit(expr, expr);
@@ -181,6 +183,25 @@ struct CollectBucketStrategyHostFunctionVisitor
{});
host_module_builder.AddFunctionWithoutOptim(
host_func.as_lowered_func_ref());

// Build a host function that infers the output tensors' shapes
std::vector<ir::Expr> infer_shape_func_body_stmts(arg_defs_);
infer_shape_func_body_stmts.insert(
infer_shape_func_body_stmts.end(),
op->infer_shape_func.as_lowered_func()->body);

std::vector<ir::Argument> infer_shape_arguments = {
ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
ir::Argument(kernel_args_num_, ir::Argument::IO::kInput),
ir::Argument(tensor_shape_args_, ir::Argument::IO::kOutput)};

ir::Expr host_infer_shape_func =
ir::_LoweredFunc_::Make(op->infer_shape_func.as_lowered_func()->name,
infer_shape_arguments,
ir::Block::Make(infer_shape_func_body_stmts),
{});
host_module_builder.AddFunctionWithoutOptim(
host_infer_shape_func.as_lowered_func_ref());
}

void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate);
@@ -199,6 +220,7 @@ struct CollectBucketStrategyHostFunctionVisitor
ir::Var kernel_args_;
ir::Var kernel_args_num_;
ir::Var kernel_stream_;
ir::Var tensor_shape_args_;
};

} // namespace detail
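Putting this file's pieces together: for every kernel, the visitor now emits a second host function whose body is the spliced-in `infer_shape_func` and whose arguments are the three vars declared above. As a sketch for intuition (not literal generated code, with `fn` standing for the kernel name), the emitted entry behaves like:

```cpp
// Hypothetical C-style view of a generated infer-shape host function.
// Types follow kernel_args_ (void*), kernel_args_num_ (int), and
// tensor_shape_args_ (int32_t**) declared in this header.
extern "C" void fn_infer_shape(void* kernel_args,     // packed runtime args
                               int kernel_args_num,   // count of packed args
                               int32_t** tensor_shape_args) {
  // Body spliced from op->infer_shape_func: computes each output tensor's
  // dimensions and stores them into tensor_shape_args[i].
}
```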
3 changes: 2 additions & 1 deletion paddle/cinn/backends/llvm/codegen_llvm.cc
@@ -818,7 +818,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) {
// TODO(fc500110) hard coding
if (LLVM_WillVarLowerAsPointer(op->name)) {
result = value;
} else if (value->getType()->isPointerTy()) {
} else if (value->getType()->isPointerTy() &&
!value->getType()->getPointerElementType()->isPointerTy()) {
result = Load(value, op->name + "_load");
} else {
result = value;
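The extra condition keeps pointer-to-pointer values from being dereferenced when a variable is materialized, so an `int32_t**` argument such as `tensor_shape_args` reaches its uses as a raw handle instead of being collapsed by a load. A sketch of the resulting rule, assuming an LLVM release with typed pointers where `getPointerElementType()` is still available:

```cpp
#include <string>

#include "llvm/IR/IRBuilder.h"

// Load a scalar variable from its slot; pass handles (T**) through untouched.
llvm::Value* MaterializeVar(llvm::IRBuilder<>* b,
                            llvm::Value* value,
                            const std::string& name) {
  llvm::Type* ty = value->getType();
  if (ty->isPointerTy() && !ty->getPointerElementType()->isPointerTy()) {
    return b->CreateLoad(ty->getPointerElementType(), value, name + "_load");
  }
  return value;  // e.g. tensor_shape_args (int32_t**) stays a pointer
}
```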
12 changes: 12 additions & 0 deletions paddle/cinn/common/type.h
@@ -251,6 +251,18 @@ inline Type type_of<uint8_t*>() {
return x;
}
template <>
inline Type type_of<int32_t*>() {
Type x = Int(32);
x.set_cpp_handle();
return x;
}
template <>
inline Type type_of<int32_t**>() {
Type x = Int(32);
x.set_cpp_handle2();
return x;
}
template <>
inline Type type_of<void*>() {
Type x = type_of<void>();
x.set_cpp_handle();
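These specializations extend the existing pattern by one indirection level: `type_of<int32_t*>()` marks an `Int(32)` as a single pointer via `set_cpp_handle()`, and `type_of<int32_t**>()` as a double pointer via `set_cpp_handle2()`. A short usage sketch (namespace assumed to be `cinn::common`, matching this header's location):

```cpp
#include "paddle/cinn/common/type.h"

// One output tensor's dims vs. the table of all outputs' dims.
cinn::common::Type one_shape = cinn::common::type_of<int32_t*>();    // int32_t*
cinn::common::Type all_shapes = cinn::common::type_of<int32_t**>();  // int32_t**
// This is what tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t**>())
// in codegen_cuda_util.h relies on.
```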
11 changes: 6 additions & 5 deletions paddle/cinn/hlir/framework/op_lowering.h
@@ -47,11 +47,12 @@ class OpLowerer {
group, apply_op_schedule, apply_group_schedule, apply_pass);
}

std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>> BucketLower(
const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
std::vector<
std::pair<ir::SymbolicPredicate, pir::OpLowererImpl::WrapLoweredFunc>>
BucketLower(const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
bool apply_pass = true) {
return impl_->BucketLower(
group, apply_op_schedule, apply_group_schedule, apply_pass);
}
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/op_lowering_impl.h
@@ -60,7 +60,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
bool apply_group_schedule = true,
bool apply_pass = true);

std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>> BucketLower(
std::vector<std::pair<ir::SymbolicPredicate, WrapLoweredFunc>> BucketLower(
const GroupPtr& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
9 changes: 8 additions & 1 deletion paddle/cinn/hlir/framework/op_lowering_impl_base.h
@@ -30,6 +30,13 @@ namespace framework {
template <typename T>
class OpLowererImplBase {
public:
struct WrapLoweredFunc {
ir::LoweredFunc kernel_func;
ir::LoweredFunc infer_shape_func;
WrapLoweredFunc(ir::LoweredFunc kernel_func,
ir::LoweredFunc infer_shape_func = ir::LoweredFunc())
: kernel_func(kernel_func), infer_shape_func(infer_shape_func) {}
};
OpLowererImplBase() = default;
~OpLowererImplBase() = default;

@@ -38,7 +45,7 @@ class OpLowererImplBase {
bool apply_group_schedule = true,
bool apply_pass = true) = 0;

virtual std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>>
virtual std::vector<std::pair<ir::SymbolicPredicate, WrapLoweredFunc>>
BucketLower(const T& group,
bool apply_op_schedule = false,
bool apply_group_schedule = true,
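`WrapLoweredFunc` is a small pair-like struct that keeps a compute kernel together with its shape-inference companion so that `BucketLower` can return both per predicate. A hedged construction sketch, with the lowered functions and predicate left as placeholders:

```cpp
// Sketch: assembling one BucketLower result entry.
ir::LoweredFunc kernel_fn;        // lowered compute kernel (placeholder)
ir::LoweredFunc infer_shape_fn;   // lowered infer-shape function (placeholder)
ir::SymbolicPredicate predicate;  // bucket guard condition (placeholder)

std::pair<ir::SymbolicPredicate, pir::OpLowererImpl::WrapLoweredFunc> entry{
    predicate, pir::OpLowererImpl::WrapLoweredFunc(kernel_fn, infer_shape_fn)};

// infer_shape_func defaults to an empty ir::LoweredFunc() for groups that
// need no dynamic shape inference.
```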
22 changes: 17 additions & 5 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
@@ -15,6 +15,7 @@
#pragma once

#include "paddle/cinn/hlir/framework/pir/compilation_task.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/ir/module.h"

@@ -23,11 +24,14 @@ namespace hlir {
namespace framework {

void GroupCompilationContext::SetLoweredFuncs(
std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>>&& funcs) {
for (std::pair<ir::SymbolicPredicate, ir::LoweredFunc>& predicate2func :
funcs) {
std::vector<std::pair<ir::SymbolicPredicate,
pir::OpLowererImpl::WrapLoweredFunc>>&& funcs) {
for (std::pair<ir::SymbolicPredicate, pir::OpLowererImpl::WrapLoweredFunc>&
predicate2func : funcs) {
predicates_.push_back(predicate2func.first);
lowered_funcs_.push_back(predicate2func.second);
lowered_funcs_.push_back(predicate2func.second.kernel_func);
infer_shape_lowered_funcs_.push_back(
predicate2func.second.infer_shape_func);
++func_size_;
}
}
@@ -67,12 +71,13 @@ void CompilationTask::CodegenAndJit() {
ir::Module::Builder builder(cinn::common::UniqName("module"),
context_->target_);
CHECK_EQ(context_->predicates_.size(), context_->lowered_funcs_.size());
for (const ir::Expr predicate : context_->predicates_) {
for (const ir::Expr& predicate : context_->predicates_) {
builder.AddPredicate(predicate);
}
for (const ir::LoweredFunc& func : context_->lowered_funcs_) {
builder.AddFunction(func);
}
builder.AddInferShapeFunc(context_->infer_shape_lowered_funcs_[0]);
ir::Module ir_module = builder.Build();

context_->backend_compiler_ = backends::Compiler::Create(context_->target_);
@@ -90,6 +95,9 @@ std::unique_ptr<Instruction> CompilationTask::BuildInstruction() {
VLOG(4) << "Lookup kernel name: " << fn_name;
auto* fn_ptr = context_->backend_compiler_->Lookup(fn_name);
CHECK(fn_ptr);
auto* infer_shape_fn_ptr =
context_->backend_compiler_->Lookup(fn_name + "_infer_shape");
CHECK(infer_shape_fn_ptr);
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), fn_name);
instr->Finalize();
return instr;
@@ -100,8 +108,12 @@ pir::CINNKernelInfo CompilationTask::BuildPirCINNKernelInfo() {
VLOG(4) << "Lookup kernel name: " << fn_name;
auto* fn_ptr = context_->backend_compiler_->Lookup(fn_name);
CHECK(fn_ptr);
auto* infer_shape_fn_ptr =
context_->backend_compiler_->Lookup(fn_name + "_infer_shape");
CHECK(infer_shape_fn_ptr);
pir::CINNKernelInfo cinn_kernel_info;
cinn_kernel_info.fn_ptr = fn_ptr;
cinn_kernel_info.infer_shape_fn_ptr = infer_shape_fn_ptr;
cinn_kernel_info.int_args_map = context_->group_->int_args_map;
return cinn_kernel_info;
}
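At runtime the two symbols looked up above become a pair of function pointers: the compute kernel under `fn_name` and its shape resolver under `fn_name + "_infer_shape"`. A sketch of how a caller might drive them; the signatures are assumptions based on the arguments assembled in `CollectBucketStrategyHostFunctionVisitor`, not CINN's published runtime API:

```cpp
#include <cstdint>

// Assumed signatures matching the generated host functions' argument lists.
using KernelFn = void (*)(void* args, int num_args, void* stream);
using InferShapeFn = void (*)(void* args, int num_args, int32_t** shapes);

void RunWithDynamicShapes(void* fn_ptr, void* infer_shape_fn_ptr,
                          void* packed_args, int num_args,
                          int32_t** shape_out) {
  // 1. Resolve output shapes first so output buffers can be allocated.
  reinterpret_cast<InferShapeFn>(infer_shape_fn_ptr)(packed_args, num_args,
                                                     shape_out);
  // 2. Launch the compute kernel (stream handling omitted in this sketch).
  reinterpret_cast<KernelFn>(fn_ptr)(packed_args, num_args,
                                     /*stream=*/nullptr);
}
```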
4 changes: 3 additions & 1 deletion paddle/cinn/hlir/framework/pir/compilation_task.h
@@ -32,7 +32,8 @@ class GroupCompilationContext {
: target_(target), group_(group), scope_(scope) {}

void SetLoweredFuncs(
std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>>&& funcs);
std::vector<std::pair<ir::SymbolicPredicate,
pir::OpLowererImpl::WrapLoweredFunc>>&& funcs);
std::string PrintPredicate2Funcs() const;
void* FuncPtr();
std::shared_ptr<backends::Compiler> BackendCompiler();
@@ -47,6 +48,7 @@ class GroupCompilationContext {
size_t func_size_ = 0;
std::vector<ir::SymbolicPredicate> predicates_;
std::vector<ir::LoweredFunc> lowered_funcs_;
std::vector<ir::LoweredFunc> infer_shape_lowered_funcs_;
std::string host_func_name_;
std::string host_code_;
std::vector<std::string> device_code_;