[pir] [run program op] Ir run program node prune (PaddlePaddle#57566)
* support build model in python

* fix ci bugs

* fix ci bugs

* fix compile bugs

* fix ci bugs

* add infermeta for data

* fix ci bugs

* fix ci bugs

* fix ci bugs

* fix bugs when running ir program multiple times

* perfect code

* frontend demo debugging

* support program split and go into run program node.

* simply run the dy2static test in newir_api mode.

* remove frame.proto changes

* merge

* fix ir-run-program-node

* fix some code

* fix output error

* fix some errors

* fix

* fix

* fix

* fix conflict

* fix files

* fix some errors

* support backward inputs prune

* support new ir prune in run_program_ad_func

* fix conflict

* tmp commit

* support backward of run program node in nn.Layer

* fix unittest problem

* filter inputs and outputs

* fix grad changes

* fix error when merge

* fix

* successfully run gpt forward and backward.

* consider no-need-buffer vars to reduce gpu memory usage

* merge

* fix concat case.py

* fix concat op

* fix

* fix

* fix mac compile problem

* fix compile error

---------

Co-authored-by: YuanRisheng <[email protected]>
2742195759 and YuanRisheng authored Sep 26, 2023
1 parent 522bc3a commit f025068
Showing 16 changed files with 517 additions and 191 deletions.
52 changes: 36 additions & 16 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -21,6 +21,8 @@
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/value.h"

// Filter params without grads in global block. In this case, we will
// tag its AutogradMeta with stop_gradient = True to avoid fault from
@@ -90,6 +92,23 @@ static std::vector<paddle::Tensor> filter_unused_input_var_in_backward(
return filter_x;
}

static std::vector<paddle::Tensor> newir_filter_unused_input_var_in_backward(
const std::vector<paddle::Tensor>& x,
const std::string x_key_name,
const paddle::framework::AttributeMap& attrs) {
auto values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at(x_key_name));
auto filter_x = std::vector<paddle::Tensor>(x);
for (size_t i = 0; i < x.size(); i++) {
if (values[i].impl() == nullptr) {
auto fake = paddle::Tensor(std::make_shared<phi::DenseTensor>());
fake.set_name(paddle::framework::kFakeVarName);
filter_x[i] = fake;
}
}
return filter_x;
}

static std::vector<paddle::Tensor> Trans2ContiguousTensors(
const std::vector<paddle::Tensor>& tensors) {
std::vector<paddle::Tensor> res;
@@ -243,8 +262,17 @@ inline void newir_run_program_ad_func(
paddle::Tensor(std::make_shared<phi::DenseTensor>());
middles.push_back(&grad_node->GetMiddle()[i]);
}

auto backward_outs =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo"));
for (size_t i = 0; i < output_size; ++i) {
grad_node->GetOutputs()[i] = *out[i];
if (backward_outs[i] != nullptr) {
grad_node->GetOutputs()[i] = *out[i];
} else { // not used by backward program
auto fake = paddle::Tensor(std::make_shared<phi::DenseTensor>());
fake.set_name(paddle::framework::kFakeVarName);
grad_node->GetOutputs()[i] = fake;
}
}
}

@@ -253,35 +281,26 @@ inline void newir_run_program_ad_func(
NewIRRunProgramAPI(
x, params, out, middles, step_scope, dout, require_any_grad, attrs);
if (require_any_grad) {
// auto x_names =
// PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));

egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);

// Set Attributes
grad_node->SetAttrMap(attrs);

// auto* forward_global_block = PADDLE_GET_CONST(
// paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
// auto* backward_global_block = PADDLE_GET_CONST(
// paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
// Clear unused x vars
// auto filter_x =
// filter_unused_input_var_in_backward(x, x_names, backward_global_block);
auto filter_x = newir_filter_unused_input_var_in_backward(x, "bx", attrs);
// Set TensorWrappers
grad_node->SetFwdX(x);
// Clear unused out vars
// clear_unused_out_var_in_backward(out, backward_global_block,
// step_scope[0]);
grad_node->SetFwdX(filter_x);

auto filter_params =
newir_filter_unused_input_var_in_backward(params, "bp", attrs);
grad_node->SetFwdParams(filter_params);

grad_node->SetFwdParams(params);
grad_node->SetStepScope(step_scope); // just for set useable.

// Set Grad out rank as same as fwd input and set stop gradient to bwd
// NOTE(@xiongkun): Not every tensor in x (a list of tensors) requires a
// gradient. For example, if x[1] is not used for the output, x[1] is ignored.

// TODO(@xiongkun): rewrite by new ir representation.
std::vector<const paddle::Tensor*> x_require_grad;
for (size_t i = 0; i < x.size(); ++i) {
x_require_grad.push_back(&x[i]);
@@ -290,6 +309,7 @@ inline void newir_run_program_ad_func(
grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);

// TODO(@xiongkun): rewrite by new ir representation.
// VLOG(2) << "clear_no_grad_edges.";
// clear_no_grad_edges_with_partial_block(params,
// forward_global_block,
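
A minimal, self-contained sketch of the pruning pattern added in run_program_op_func.h above: any forward input or output whose corresponding backward-program value is null is replaced by an empty, specially named placeholder before it is saved on the grad node, so the real buffer is not kept alive for backward. The Value and Tensor types below are illustrative stand-ins, not Paddle's actual classes, and FilterUnusedForBackward is a hypothetical helper name.

#include <memory>
#include <string>
#include <vector>

struct Value {                  // stands in for ::pir::Value
  const void* impl = nullptr;   // null means "not consumed by the backward program"
};

struct Tensor {                 // stands in for paddle::Tensor
  std::shared_ptr<std::vector<float>> data;
  std::string name;
};

// Keep tensor i only if backward_values[i] is non-null; otherwise substitute an
// empty placeholder (the real code names it paddle::framework::kFakeVarName).
std::vector<Tensor> FilterUnusedForBackward(
    const std::vector<Tensor>& x, const std::vector<Value>& backward_values) {
  std::vector<Tensor> filtered(x);
  for (size_t i = 0; i < x.size(); ++i) {
    if (backward_values[i].impl == nullptr) {
      filtered[i] = Tensor{std::make_shared<std::vector<float>>(), "@FAKE_VAR@"};
    }
  }
  return filtered;
}
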
113 changes: 55 additions & 58 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -89,15 +89,6 @@ static void CheckInputVarStatus(const Tensor &tensor) {
"RunProgram(Grad)Op holds "
"wrong type. Expect type is DenseTensor.",
tensor.name()));

PADDLE_ENFORCE_EQ(
static_cast<phi::DenseTensor *>(tensor.impl().get())->IsInitialized(),
true,
paddle::platform::errors::InvalidArgument(
"The tensor in input tensor %s of "
"RunProgram(Grad)Op "
"is not initialized.",
tensor.name()));
}

static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
Expand All @@ -117,13 +108,6 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
"RunProgram(Grad)Op's internal scope holds "
"wrong type. Expect type is DenseTensor",
name));
PADDLE_ENFORCE_EQ(src_tensor.IsInitialized(),
true,
paddle::platform::errors::InvalidArgument(
"The tensor in output tensor %s get from "
"RunProgram(Grad)Op's internal "
"scope is not initialized.",
name));
} else if (dst_tensor.is_selected_rows()) {
auto &src_tensor = src_var.Get<phi::SelectedRows>();
PADDLE_ENFORCE_EQ(phi::SelectedRows::classof(&src_tensor),
@@ -133,14 +117,6 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
"RunProgram(Grad)Op's internal scope holds "
"wrong type. Expect type is SelectedRows",
name));
PADDLE_ENFORCE_EQ(src_tensor.initialized(),
true,
paddle::platform::errors::InvalidArgument(
"The tensor in output tensor %s get from "
"RunProgram(Grad)Op's "
"internal scope is not initialized.",
name));

} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The RunProgram(Grad)Op only support output "
@@ -214,14 +190,23 @@ static auto GetNameFromValue(const ::pir::Block *block,
.dyn_cast<pir::StrAttribute>()
.AsString();
value2name[op->operand(0).source()] = name;
} else if (op->name() == "builtin.get_parameter") {
name = op->attributes()
.at("parameter_name")
.dyn_cast<pir::StrAttribute>()
.AsString();
value2name[op->result(0).Value::impl()] = name;
}
}
std::vector<std::string> names;
std::transform(
values.begin(),
values.end(),
std::back_inserter(names),
[&value2name](const ::pir::Value &v) { return value2name[v]; });
std::transform(values.begin(),
values.end(),
std::back_inserter(names),
[&value2name](const ::pir::Value &v) {
if (!value2name.count(v))
return std::string(paddle::framework::kFakeVarName);
return value2name.at(v);
});
return names;
}
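
The lookup above now falls back to a fake variable name whenever a value was never bound to a name in the block (for example because it was pruned). A standalone sketch of that fallback, with an unordered_map standing in for the block walk and NamesWithFallback as a hypothetical helper:

#include <string>
#include <unordered_map>
#include <vector>

using Value = const void*;  // stand-in key type for ::pir::Value

std::vector<std::string> NamesWithFallback(
    const std::vector<Value>& values,
    const std::unordered_map<Value, std::string>& value2name,
    const std::string& fake_name = "@FAKE_VAR@") {
  std::vector<std::string> names;
  names.reserve(values.size());
  for (const Value& v : values) {
    auto it = value2name.find(v);
    // Unnamed values map to the fake name instead of inserting an empty
    // entry into the map, which is what the count() guard above avoids.
    names.push_back(it == value2name.end() ? fake_name : it->second);
  }
  return names;
}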

@@ -255,7 +240,7 @@ static void ShareTensorsFromScope(
auto &src_tensor = var->Get<phi::DenseTensor>();
auto *dst_tensor = const_cast<phi::DenseTensor *>(
dynamic_cast<const phi::DenseTensor *>(tensors[i]->impl().get()));
VLOG(2) << "share " << name << " from scope";
VLOG(4) << "share " << name << " from scope";
*dst_tensor = src_tensor;
} else if (var->IsType<phi::SelectedRows>()) {
auto &src_tensor = var->Get<phi::SelectedRows>();
@@ -272,6 +257,11 @@ static void ShareTensorsIntoScopeByValue(
const std::vector<::pir::Value> &values,
paddle::framework::Scope *scope) {
auto names = GetNameFromValue(block, values);
if (VLOG_IS_ON(4)) {
for (auto &s : names) {
VLOG(4) << "ShareTensorIntoScopeByValue name: " << s;
}
}
ShareTensorsIntoScopeWithName(tensors, names, scope);
}

@@ -461,8 +451,6 @@ inline void NewIRRunProgramAPI(
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm"));
auto param_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));
// auto dout_names =
// PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));

auto *forward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block"));
@@ -523,6 +511,15 @@ inline void NewIRRunProgramAPI(
std::set<std::string>(skip_names.begin(), skip_names.end());
skip_names = details::GetNameFromValue(forward_global_block, output_values);
skip_names_set.insert(skip_names.begin(), skip_names.end());
auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
attrs.at("no_need_buffers"));
auto no_need_buffer_names =
details::GetNameFromValue(forward_global_block, no_need_buffer_values);
VLOG(4) << "start skip no need buffer vars with name:";
for (auto &name : no_need_buffer_names) {
VLOG(4) << "Skip no need buffer vars with name:" << name;
skip_names_set.erase(name);
}
details::print_collection(skip_names_set);
interpreter_core->SetSkipGcVars(skip_names_set);
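
The hunk above is where the "consider no-need-buffer" part of the commit lands: every name the backward pass may read is protected from garbage collection, and then the no-need-buffer names are erased from that set so their storage can be released early. A condensed sketch of that set construction (BuildSkipGcSet is a hypothetical helper, not a Paddle API):

#include <set>
#include <string>
#include <vector>

std::set<std::string> BuildSkipGcSet(
    const std::vector<std::string>& needed_names,
    const std::vector<std::string>& no_need_buffer_names) {
  // Start with everything backward may read (inputs, params, middles, outputs).
  std::set<std::string> skip(needed_names.begin(), needed_names.end());
  for (const std::string& name : no_need_buffer_names) {
    skip.erase(name);  // meta/shape may still be used, but the buffer is not
  }
  return skip;
}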

@@ -997,13 +994,29 @@ inline void NewIRRunProgramGradAPI(
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bx"));
auto forward_middle_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bm"));
auto parameter_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bp"));
auto forward_output_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo"));
auto x_grad_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bx_g"));
auto p_grad_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bp_g"));

// share x, param, middles, output_grads, out into scope.
details::ShareTensorsIntoScopeByValue(
backward_global_block, out_grad, output_grad_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, x, forward_input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_global_block,
middles,
forward_middle_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, out, forward_output_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, params, parameter_values, global_inner_scope);

auto &interpretercore_info_cache =
paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
@@ -1016,19 +1029,6 @@
1);
VLOG(2) << "No interpretercore cache, so create a new interpretercore";
// Step 1. share input_vars & parameters into scope
// x, param, middles, output_grads
details::ShareTensorsIntoScopeByValue(backward_global_block,
out_grad,
output_grad_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, x, forward_input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_global_block,
middles,
forward_middle_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, out, forward_output_values, global_inner_scope);
auto kernel_backward_program =
paddle::dialect::PdOpLowerToKernelPass(backward_program, place);
interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache(
@@ -1076,14 +1076,13 @@
program_id, global_inner_scope, /*is_grad=*/true);
interpreter_core = cached_value.core_;

// update scope (TODO: why share again)
// details::ShareTensorsIntoScope(out_grad, global_inner_scope);
// if (interpreter_core->GetVariableScope()->GetMutableScope() !=
// global_inner_scope) {
// details::BuildScopeByBlock(
// *interpreter_core.get(), *backward_global_block, global_inner_scope);
// interpreter_core->reset_scope(global_inner_scope);
//}
if (interpreter_core->GetVariableScope()->GetMutableScope() !=
global_inner_scope) {
// update scope (TODO(xiongkun): do we need this??)
// details::BuildScopeByBlock(
// *interpreter_core.get(), *backward_global_block, global_inner_scope);
interpreter_core->reset_scope(global_inner_scope);
}
}

if (!backward_global_block->empty()) {
@@ -1287,7 +1286,7 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase {
~NewIRGradNodeRunProgram() override {
if (!executed_) {
auto *out_scope_vec = &step_scope_;
VLOG(4) << "~GradNodeRunProgram";
VLOG(4) << "~NewIRGradNodeRunProgram";
// Normally out_scope_vec.size() == 1. For safety, we add a for-loop here.
for (size_t i = 0; i < out_scope_vec->size(); ++i) {
paddle::framework::Scope *global_inner_scope = out_scope_vec->at(i);
@@ -1306,7 +1305,7 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase {
egr::kSlotSmallVectorSize> &grads, // NOLINT
bool create_graph UNUSED,
bool is_new_grad UNUSED) override {
VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram";
VLOG(3) << "Running Eager Backward Node: NewIRGradNodeRunProgram";
paddle::small_vector<std::vector<paddle::Tensor>, egr::kSlotSmallVectorSize>
hooked_grads = NewIRGradNodeRunProgram::ApplyGradientHooks(grads);
PADDLE_ENFORCE_EQ(hooked_grads.size(),
Expand Down Expand Up @@ -1348,7 +1347,6 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase {
"The hooked_grads[0].size() and "
"out_grad_values.size() should be equal."));

VLOG(1) << "Run Program Grad API start.";
NewIRRunProgramGradAPI(x_,
params_,
hooked_grads[0],
Expand All @@ -1358,8 +1356,7 @@ class NewIRGradNodeRunProgram : public egr::GradNodeBase {
attrs_,
x_grad_ptr,
params_grad_ptr);
VLOG(1) << "Run Program Grad API end.";
VLOG(3) << "End Eager Backward Node: GradNodeRunProgram";
VLOG(3) << "End Eager Backward Node: NewIRGradNodeRunProgram";

executed_ = true;
return {x_grad, params_grad};
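
A condensed sketch of the grad-side control flow the changes above settle on: the forward inputs, params, middles, outputs, and output grads are shared into the inner scope unconditionally, a new interpreter core is built only on a cache miss, and a cached core is re-pointed at the current inner scope when the two differ. InterpreterCore and RunBackwardProgram below are simplified stand-ins, not Paddle's types:

struct Scope {};
struct InterpreterCore {
  Scope* scope = nullptr;
  void Run() { /* execute the lowered backward program */ }
};

void RunBackwardProgram(InterpreterCore* cached, Scope* inner_scope) {
  if (cached == nullptr) {
    // Cache miss: lower the backward program, build a core bound to
    // inner_scope (omitted here), and cache it by program id.
    return;
  }
  if (cached->scope != inner_scope) {
    cached->scope = inner_scope;  // mirrors interpreter_core->reset_scope(...)
  }
  cached->Run();
}
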
1 change: 1 addition & 0 deletions paddle/fluid/pybind/op_function_common.cc
@@ -1027,6 +1027,7 @@ void ConstructAttrMapForRunProgram(
"fm",
"fo",
"bx",
"no_need_buffers",
"bp",
"bm",
"bo_g",
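
The key added above ("no_need_buffers") has to match the string used when the attribute is read back in run_program_op_node.h via attrs.at("no_need_buffers"). A toy illustration of that contract, with a plain std::map standing in for paddle::framework::AttributeMap:

#include <cassert>
#include <map>
#include <string>
#include <vector>

struct Value {};  // stand-in for ::pir::Value

int main() {
  std::map<std::string, std::vector<Value>> attrs;
  attrs["no_need_buffers"] = {Value{}, Value{}};  // packed on the pybind side
  // ...and looked up later under exactly the same key on the C++ side:
  assert(attrs.count("no_need_buffers") == 1);
  return 0;
}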