Commit 98d62e4

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fix_varmap_save

Your Name committed Jan 8, 2024
2 parents: 1e6a7d3 + 5bb661d
Showing 767 changed files with 30,256 additions and 8,253 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -108,6 +108,8 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.*
 paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
 paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
 paddle/fluid/pir/dialect/operator/ir/pd_op.*
+paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
+paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
 paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
 paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
 paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.*
4 changes: 2 additions & 2 deletions cmake/external/xpu.cmake
@@ -26,10 +26,10 @@ set(XPU_XBLAS_LIB_NAME "libxpu_blas.so")
 set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")

 if(NOT DEFINED XPU_BASE_DATE)
-  set(XPU_BASE_DATE "20231203")
+  set(XPU_BASE_DATE "20231218")
 endif()
 if(NOT DEFINED XPU_XHPC_BASE_DATE)
-  set(XPU_XHPC_BASE_DATE "20231226")
+  set(XPU_XHPC_BASE_DATE "20231229")
 endif()
 set(XPU_XCCL_BASE_VERSION "1.1.8.1")
 if(NOT DEFINED XPU_XFT_BASE_VERSION)
13 changes: 11 additions & 2 deletions cmake/inference_lib.cmake
@@ -328,10 +328,19 @@ copy(
   inference_lib_dist
   SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h
   DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)

 copy(
   inference_lib_dist
-  SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h
-  DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
+  SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/type_defs.h
+  DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/
+)
+
+copy(
+  inference_lib_dist
+  SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/auto_parallel/*.h
+  DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/auto_parallel/
+)

 copy(
   inference_lib_dist
   SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h
10 changes: 10 additions & 0 deletions paddle/cinn/backends/codegen_cuda_host.cc
@@ -198,13 +198,23 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
                  [](auto& arg) { return std::addressof(arg); });
   // @}
+
+  // Set local scope table
+  CHECK_EQ(ll_function_args.size(), func->args.size());
+  for (int i = 0; i < ll_function_args.size(); ++i) {
+    SetVar(func->args[i].name(), ll_function_args[i]);
+  }
   llvm::BasicBlock* entry = llvm::BasicBlock::Create(
       /*Context=*/b_->getContext(),
       /*Name=*/"entry",
       /*Parent=*/f_,
       /*InsertBefore=*/nullptr);
   b_->SetInsertPoint(entry);
   CodeGenLLVM::Visit(&func->body);

+  // Reset local scope table
+  for (const ir::Argument& func_arg : func->args) {
+    symbol_table_->Erase(func_arg.name());
+  }
   RetVoid();

   return f_;
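Note: the added lines register every lowered LLVM argument in the symbol table before the function body is visited, then erase the entries afterwards, so the bindings stay local to the function being lowered. Below is a minimal sketch of that set-then-erase pattern wrapped in an RAII guard; SymbolTable, Insert, and ScopedSymbols are illustrative stand-ins for CINN's symbol_table_ interface, not its real API.

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for the code generator's symbol table.
struct SymbolTable {
  std::unordered_map<std::string, void*> vars;
  void Insert(const std::string& name, void* value) { vars[name] = value; }
  void Erase(const std::string& name) { vars.erase(name); }
};

// Registers a batch of names on construction and erases them on
// destruction, mirroring the manual SetVar / Erase pair in LowerHostFunc.
class ScopedSymbols {
 public:
  ScopedSymbols(SymbolTable* table,
                std::vector<std::pair<std::string, void*>> syms)
      : table_(table), syms_(std::move(syms)) {
    for (const auto& [name, value] : syms_) table_->Insert(name, value);
  }
  ~ScopedSymbols() {
    for (const auto& [name, value] : syms_) table_->Erase(name);
  }

 private:
  SymbolTable* table_;
  std::vector<std::pair<std::string, void*>> syms_;
};

int main() {
  SymbolTable table;
  int x = 0;
  {
    ScopedSymbols guard(&table, {{"x", &x}});
    // ... lowering would resolve "x" here ...
  }
  return table.vars.count("x");  // 0: erased again, as after the reset loop
}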
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_host.h
@@ -53,7 +53,7 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
     } else if (op->name == runtime::intrinsic::call_cuda_kernel) {
       return LowerCUDAKernelCall(op);
     } else {
-      CINN_NOT_IMPLEMENTED;
+      return CodeGenLLVM::Visit(op);
     }
   }

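Note: returning CodeGenLLVM::Visit(op) instead of hitting CINN_NOT_IMPLEMENTED turns an unrecognized intrinsic into a graceful fallback to the base visitor rather than an abort. A toy sketch of that dispatch-with-fallback shape (the class and method names are illustrative, not CINN's):

#include <iostream>
#include <string>

struct BaseCodeGen {
  // Generic handler that every subclass can defer to.
  virtual void Visit(const std::string& op) {
    std::cout << "base visitor handles " << op << "\n";
  }
  virtual ~BaseCodeGen() = default;
};

struct HostCodeGen : BaseCodeGen {
  void Visit(const std::string& op) override {
    if (op == "call_cuda_kernel") {
      LowerKernelCall(op);
    } else {
      // Previously: abort (CINN_NOT_IMPLEMENTED). Now: fall back.
      BaseCodeGen::Visit(op);
    }
  }
  void LowerKernelCall(const std::string& op) {
    std::cout << "lowering kernel call " << op << "\n";
  }
};

int main() {
  HostCodeGen cg;
  cg.Visit("call_cuda_kernel");  // handled by the subclass
  cg.Visit("memset");            // falls through to the base visitor
}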
24 changes: 23 additions & 1 deletion paddle/cinn/backends/codegen_cuda_util.h
@@ -31,6 +31,7 @@ namespace backends {
 #define KERNEL_ARGS "kernel_args"
 #define KERNEL_ARGS_NUM "kernel_args_num"
 #define KERNEL_STREAM "kernel_stream"
+#define TENSOR_SHAPE_ARGS "tensor_shape_args"

 /**
  * Split a CINN Module into two separate modules, one contains the host
@@ -150,7 +151,8 @@ struct CollectBucketStrategyHostFunctionVisitor
       : CollectHostFunctionVisitor(module_name),
         kernel_args_(KERNEL_ARGS, type_of<void*>()),
         kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
-        kernel_stream_(KERNEL_STREAM, type_of<void*>()) {}
+        kernel_stream_(KERNEL_STREAM, type_of<void*>()),
+        tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t**>()) {}

   std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
     ir::IRMutator<>::Visit(expr, expr);
@@ -181,6 +183,25 @@
         {});
     host_module_builder.AddFunctionWithoutOptim(
         host_func.as_lowered_func_ref());
+
+    // Parse LoweredFunc to infer output tensor's shape
+    std::vector<ir::Expr> infer_shape_func_body_stmts(arg_defs_);
+    infer_shape_func_body_stmts.insert(
+        infer_shape_func_body_stmts.end(),
+        op->infer_shape_func.as_lowered_func()->body);
+
+    std::vector<ir::Argument> infer_shape_arguments = {
+        ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
+        ir::Argument(kernel_args_num_, ir::Argument::IO::kInput),
+        ir::Argument(tensor_shape_args_, ir::Argument::IO::kOutput)};
+
+    ir::Expr host_infer_shape_func =
+        ir::_LoweredFunc_::Make(op->infer_shape_func.as_lowered_func()->name,
+                                infer_shape_arguments,
+                                ir::Block::Make(infer_shape_func_body_stmts),
+                                {});
+    host_module_builder.AddFunctionWithoutOptim(
+        host_infer_shape_func.as_lowered_func_ref());
   }

   void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate);
@@ -199,6 +220,7 @@ struct CollectBucketStrategyHostFunctionVisitor
   ir::Var kernel_args_;
   ir::Var kernel_args_num_;
   ir::Var kernel_stream_;
+  ir::Var tensor_shape_args_;
 };

 }  // namespace detail
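Note: for each kernel, the visitor now also emits a host-side infer-shape function whose three arguments mirror the member variables added above: the packed kernel arguments (output), their count (input), and an int32_t** buffer (output) that receives each output tensor's dimensions. Below is a hedged sketch of how a caller might drive such a generated entry point; the signature and names are assumptions for illustration, not the actual generated ABI.

#include <cstdint>
#include <cstdio>

// Assumed shape of a generated infer-shape host function: it inspects the
// packed kernel arguments and writes each output tensor's dims into
// tensor_shape_args[i].
using InferShapeFn = void (*)(void* kernel_args,
                              int kernel_args_num,
                              int32_t** tensor_shape_args);

void RunInferShape(InferShapeFn fn, void* kernel_args, int kernel_args_num) {
  int32_t shape0[4] = {0, 0, 0, 0};  // caller-owned buffer for output 0
  int32_t* shapes[] = {shape0};      // one slot per output tensor
  fn(kernel_args, kernel_args_num, shapes);
  std::printf("output 0 dims: [%d, %d, %d, %d]\n",
              shape0[0], shape0[1], shape0[2], shape0[3]);
}

// A stand-in "generated" function, just for demonstration.
static void DummyInferShape(void*, int, int32_t** shapes) {
  shapes[0][0] = 2;
  shapes[0][1] = 3;
  shapes[0][2] = 1;
  shapes[0][3] = 1;
}

int main() { RunInferShape(DummyInferShape, nullptr, 0); }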
3 changes: 2 additions & 1 deletion paddle/cinn/backends/llvm/codegen_llvm.cc
@@ -818,7 +818,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) {
   // TODO(fc500110) hard coding
   if (LLVM_WillVarLowerAsPointer(op->name)) {
     result = value;
-  } else if (value->getType()->isPointerTy()) {
+  } else if (value->getType()->isPointerTy() &&
+             !value->getType()->getPointerElementType()->isPointerTy()) {
     result = Load(value, op->name + "_load");
   } else {
     result = value;
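Note: the extra check keeps double pointers from being dereferenced when a variable is materialized: a T* value is loaded to yield T, but a T** (such as the new tensor_shape_args) is passed through unchanged. A plain-C++ analogy of that rule; the real code operates on llvm::Type, so this is illustrative only:

#include <cassert>
#include <type_traits>

// Mirrors the codegen rule: load through a single-level pointer, but keep
// a pointer-to-pointer as-is.
template <typename T>
auto MaterializeVar(T* value) {
  if constexpr (std::is_pointer_v<T>) {
    return value;   // value is U**: keep the address itself (no Load)
  } else {
    return *value;  // single-level pointer: load the pointee
  }
}

int main() {
  int x = 7;
  int* p = &x;
  int** pp = &p;
  assert(MaterializeVar(p) == 7);    // loaded through the pointer
  assert(MaterializeVar(pp) == pp);  // passed through untouched
}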
6 changes: 5 additions & 1 deletion paddle/cinn/common/CMakeLists.txt
@@ -23,7 +23,9 @@ gather_srcs(
   nvgpu_dev_info.cc
   integer_set.cc
   dim_expr_simplify.cc
-  dim_expr_converter.cc)
+  dim_expr_converter.cc
+  broadcast_tree.cc
+  dim_expr_util.cc)

 cinn_cc_test(test_equation_graph_topo_walker SRCS
              equation_graph_topo_walker_test.cc DEPS gtest glog)
@@ -48,8 +50,10 @@ if(WITH_CUDA)
                gtest glog)
 endif()
 if(NOT CINN_ONLY)
+  cinn_cc_test(dim_expr_util_test SRCS dim_expr_util_test.cc DEPS cinncore)
   cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS
                cinncore)
   cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS
                cinncore)
+  cinn_cc_test(broadcast_tree_test SRCS broadcast_tree_test.cc DEPS cinncore)
 endif()
(Diff truncated: only the first few of the 767 changed files loaded.)