[Dy2St][PIR] Hold backward program in GradNode #63694

Merged: 25 commits, May 6, 2024
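The core change: the run_program attributes now carry the forward and backward ::pir::Program as std::shared_ptr instead of raw ::pir::Block* pointers, so the grad node keeps the backward program alive until backward actually runs. A minimal standalone sketch of that ownership pattern (illustrative names only, not Paddle's actual classes):

#include <memory>
#include <utility>

struct Program {};  // stand-in for ::pir::Program

// Hypothetical grad node: it shares ownership of the backward program, so the
// program remains valid even if the Python-side objects that built it go away
// before backward is executed.
class GradNodeSketch {
 public:
  void SetBackwardProgram(std::shared_ptr<Program> prog) {
    backward_program_ = std::move(prog);
  }
  const Program* backward_program() const { return backward_program_.get(); }

 private:
  std::shared_ptr<Program> backward_program_;  // released when the node is destroyed
};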
11 changes: 1 addition & 10 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -296,16 +296,7 @@ inline void pir_run_program_ad_func(

grad_node->SetStepScope(step_scope); // just for set useable.

// Set Grad out rank as same as fwd input and set stop gradient to bwd
// NOTE(@xiongkun): Not every tensor in x(list of tensor) is required
// gradient. for example: x[1] is not used for output, the x[1] is ignored.

std::vector<const paddle::Tensor*> x_require_grad;
for (size_t i = 0; i < x.size(); ++i) {
x_require_grad.push_back(&x[i]);
}

grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0);
grad_node->SetGradOutMeta(x, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);

// TODO(@xiongkun): rewrite by new ir representation.
70 changes: 35 additions & 35 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -467,21 +467,16 @@ inline void PirRunProgramAPI(
auto param_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));

auto *forward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block"));
auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));

auto *forward_program =
forward_global_block->GetParentOp()->GetParentProgram();
std::shared_ptr<::pir::Program> forward_program = PADDLE_GET_CONST(
std::shared_ptr<::pir::Program>, attrs.at("forward_program"));
std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST(
std::shared_ptr<::pir::Program>, attrs.at("backward_program"));

if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "ForwardProgram is :\n";
forward_program->Print(print_stream);
if (!is_test) {
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
print_stream << "BackwardProgram is:\n";
backward_program->Print(print_stream);
} else {
@@ -509,12 +504,12 @@ inline void PirRunProgramAPI(
<< program_id;
// Step 1. share input_vars & parameters into scope
details::ShareTensorsIntoScopeByValue(
forward_global_block, x, input_values, global_inner_scope);
forward_program->block(), x, input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
forward_global_block, params, param_values, global_inner_scope);
forward_program->block(), params, param_values, global_inner_scope);
// Step 2. create new interpretercore
auto passed_kernel_program =
paddle::framework::ApplyIrPass(forward_program, place);
paddle::framework::ApplyIrPass(forward_program.get(), place);
if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "LoweredProgram( AfterPass ) is :\n";
@@ -535,22 +530,22 @@ inline void PirRunProgramAPI(

// update interpretercore skip_gc_var
auto skip_names = details::GetNameFromValue(
forward_global_block, middle_values, false, true);
forward_program->block(), middle_values, false, true);
auto skip_names_set =
std::set<std::string>(skip_names.begin(), skip_names.end());
auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
attrs.at("no_need_buffers"));
auto no_need_buffer_names = details::GetNameFromValue(
forward_global_block, no_need_buffer_values, false, true);
forward_program->block(), no_need_buffer_values, false, true);
for (auto &name : no_need_buffer_names) {
VLOG(4) << "Find no need buffer vars with name:" << name;
skip_names_set.erase(name);
}
skip_names = details::GetNameFromValue(
forward_global_block, output_values, false, true);
forward_program->block(), output_values, false, true);
skip_names_set.insert(skip_names.begin(), skip_names.end());
skip_names = details::GetNameFromValue(
forward_global_block, input_values, true, false);
forward_program->block(), input_values, true, false);
skip_names_set.insert(skip_names.begin(), skip_names.end());
details::print_collection(skip_names_set);
interpreter_core->SetSkipGcVars(skip_names_set);
@@ -576,9 +571,9 @@ inline void PirRunProgramAPI(
interpreter_core = cached_value.core_;
// Step 2. update scope for cache interpretercore
details::ShareTensorsIntoScopeByValue(
forward_global_block, x, input_values, global_inner_scope);
forward_program->block(), x, input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
forward_global_block, params, param_values, global_inner_scope);
forward_program->block(), params, param_values, global_inner_scope);
// TODO(xiongkun): new ir how to build scope.
// if (interpreter_core->GetVariableScope()->GetMutableScope() !=
// global_inner_scope) {
@@ -589,7 +584,7 @@ inline void PirRunProgramAPI(
}

// interpretercore run
if (!forward_global_block->empty()) {
if (!forward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
@@ -602,7 +597,7 @@ inline void PirRunProgramAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Get Output, and Middle Outputs
details::ShareTensorsFromScopeByValue(
forward_global_block, out, output_values, global_inner_scope);
forward_program->block(), out, output_values, global_inner_scope);

VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());

@@ -1041,10 +1036,8 @@ inline void PirRunProgramGradAPI(

VLOG(4) << "global_inner_scope:" << global_inner_scope;

auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST(
std::shared_ptr<::pir::Program>, attrs.at("backward_program"));

auto output_grad_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g"));
@@ -1064,8 +1057,10 @@ inline void PirRunProgramGradAPI(
details::Trans2ContiguousTensorsInplace(out_grad);

// share x, param, middles, output_grads, out into scope.
details::ShareTensorsIntoScopeByValue(
backward_global_block, out_grad, output_grad_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
out_grad,
output_grad_values,
global_inner_scope);

auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
@@ -1082,7 +1077,7 @@ inline void PirRunProgramGradAPI(
VLOG(2) << "No interpretercore cache, so create a new interpretercore";
// Step 1. share input_vars & parameters into scope
auto passed_kernel_program =
paddle::framework::ApplyIrPass(backward_program, place);
paddle::framework::ApplyIrPass(backward_program.get(), place);

const auto &new_block = passed_kernel_program->block();
passed_kernel_program = paddle::framework::ApplyRemoveShadowFeedPass(
@@ -1124,10 +1119,10 @@ inline void PirRunProgramGradAPI(
// get all eager gc vars
std::set<std::string> skip_eager_delete_vars;
auto skip_names = details::GetNameFromValue(
backward_global_block, x_grad_values, false, true);
backward_program->block(), x_grad_values, false, true);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
skip_names = details::GetNameFromValue(
backward_global_block, p_grad_values, false, true);
backward_program->block(), p_grad_values, false, true);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
cache.UpdateSkipEagerDeleteVars(program_id,
@@ -1160,7 +1155,7 @@ inline void PirRunProgramGradAPI(
}
}

if (!backward_global_block->empty()) {
if (!backward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
@@ -1175,9 +1170,11 @@ inline void PirRunProgramGradAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Step 4. get outputs
details::ShareTensorsFromScopeByValue(
backward_global_block, x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(
backward_global_block, params_grad, p_grad_values, global_inner_scope);
backward_program->block(), x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(backward_program->block(),
params_grad,
p_grad_values,
global_inner_scope);
VLOG(4) << "after backward gc all vars";
global_inner_scope->SetCanReused(true);
details::GcScope(global_inner_scope);
@@ -1316,8 +1313,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
if (x[i].is_dense_tensor()) {
x_grad->emplace_back(std::make_shared<phi::DenseTensor>());
} else if (x[i].is_selected_rows()) {
auto selected_row = std::make_shared<phi::SelectedRows>();
x_grad->emplace_back(selected_row);
x_grad->emplace_back(std::make_shared<phi::SelectedRows>());
}
x_grad->back().set_name(x_grad_names[i]);
}
@@ -1446,6 +1442,10 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram";

*executed_ = true;
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad,
this->OutputMeta()[0]);
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&params_grad,
this->OutputMeta()[1]);
return {x_grad, params_grad};
}
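The two FillZeroForEmptyOptionalGradOutput calls added above ensure that grad slots the backward program never wrote still come back as defined (zero) tensors. A rough standalone sketch of the idea (simplified types, not the real egr::EagerUtils signature):

#include <vector>

struct Tensor {
  bool initialized = false;
  // shape and dtype meta omitted for brevity
};

// Replace any still-empty grad output with a zero-filled tensor so every
// output slot returned from backward holds a defined value.
inline void FillZeroForEmptyGradOutputs(std::vector<Tensor>* grads) {
  for (auto& g : *grads) {
    if (!g.initialized) {
      g = Tensor{};        // stands in for zeros_like(recorded output meta)
      g.initialized = true;
    }
  }
}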

4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_desc.cc
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"

@@ -977,6 +978,9 @@ struct SetAttrDescVisitor {
void operator()(const std::vector<pir::Block *> &v) const {
// just do nothing.
}
void operator()(const std::shared_ptr<pir::Program> &v) const {
// just do nothing.
}
void operator()(const std::vector<VarDesc *> &v) const {
std::vector<std::string> var_names;
for (auto var : v) {
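The no-op operator() added above is needed because the attribute visitor must cover every alternative the Attribute variant can hold (see the type_defs changes below); a whole program is runtime-only state and is not serialized into the op proto. A generic illustration using std::variant (not Paddle's actual visitor):

#include <iostream>
#include <memory>
#include <string>
#include <variant>

struct Program {};  // stand-in for ::pir::Program

using Attribute = std::variant<int, std::string, std::shared_ptr<Program>>;

struct AttrSerializer {
  void operator()(int v) const { std::cout << "int: " << v << "\n"; }
  void operator()(const std::string& v) const { std::cout << "str: " << v << "\n"; }
  void operator()(const std::shared_ptr<Program>&) const {
    // deliberate no-op: the program is not written into the serialized form
  }
};

int main() {
  Attribute attr = std::make_shared<Program>();
  std::visit(AttrSerializer{}, attr);  // compiles only if every alternative is handled
}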
3 changes: 2 additions & 1 deletion paddle/fluid/framework/type_defs.cc
@@ -39,7 +39,8 @@ template class variant<paddle::blank,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>,
::pir::Block*,
std::vector<::pir::Value>>;
std::vector<::pir::Value>,
std::shared_ptr<::pir::Program>>;
} // namespace paddle
REGISTER_LOG_SIMPLY_STR(paddle::framework::AttributeMap);
REGISTER_LOG_SIMPLY_STR(paddle::framework::Attribute);
4 changes: 3 additions & 1 deletion paddle/fluid/framework/type_defs.h
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"
#include "paddle/utils/small_vector.h"
@@ -67,7 +68,8 @@ using Attribute = paddle::variant<paddle::blank,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>,
::pir::Block*,
std::vector<::pir::Value>>;
std::vector<::pir::Value>,
std::shared_ptr<::pir::Program>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;

using OpCreator =
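A small standalone analogue (std::variant instead of paddle::variant) of what the paired type_defs.h / type_defs.cc edits achieve: the new std::shared_ptr<::pir::Program> alternative joins the Attribute alias (and, in Paddle, the matching explicit instantiation in type_defs.cc), and values are read back with a checked get, which is the role PADDLE_GET_CONST plays in the runtime code above:

#include <memory>
#include <string>
#include <unordered_map>
#include <variant>

struct Program {};  // stand-in for ::pir::Program

using Attribute = std::variant<int, std::string, std::shared_ptr<Program>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;

int main() {
  AttributeMap attrs;
  attrs["backward_program"] = std::make_shared<Program>();

  // analogue of PADDLE_GET_CONST(std::shared_ptr<::pir::Program>, attrs.at("backward_program"))
  auto prog = std::get<std::shared_ptr<Program>>(attrs.at("backward_program"));
  return prog ? 0 : 1;
}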
19 changes: 15 additions & 4 deletions paddle/fluid/pybind/op_function_common.cc
@@ -858,6 +858,17 @@ void CastPyArg2AttrIRBlock(PyObject* obj,
attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]);
}

void CastPyArg2AttrIRProgram(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
VLOG(1) << "After Process pir::Program*";
const std::shared_ptr<::pir::Program> program =
::py::handle(obj).cast<std::shared_ptr<::pir::Program>>();
attrs[key] = program;
}

void CastPyArg2AttrValues(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
@@ -1020,11 +1031,11 @@ void ConstructAttrMapForRunProgram(

if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) {
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"global_block",
"forward_global_block",
"backward_global_block"})
.count(key)) {
} else if (std::set<std::string>({"global_block"}).count(key)) {
CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"forward_program", "backward_program"})
.count(key)) {
CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"is_test", "use_interpretorcore"})
.count(key)) {
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
@@ -286,7 +286,7 @@ void BindProgram(py::module *m) {
)DOC");
program
.def(py::init([]() {
return std::make_unique<Program>(pir::IrContext::Instance());
return std::make_shared<Program>(pir::IrContext::Instance());
}))
.def("__str__",
[](const std::shared_ptr<Program> &self) {
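The pir.cc switch from std::make_unique to std::make_shared, together with the new CastPyArg2AttrIRProgram above, follows pybind11's shared-ownership pattern: when the class is bound with a std::shared_ptr holder, C++ code can cast a Python Program back to std::shared_ptr and share it through the attribute map. A hedged, self-contained sketch (stand-in Program type and hypothetical module name):

#include <pybind11/pybind11.h>

#include <memory>

namespace py = pybind11;

struct Program {};  // stand-in for ::pir::Program

PYBIND11_MODULE(sketch_module, m) {
  // shared_ptr holder, mirroring how BindProgram exposes pir::Program
  py::class_<Program, std::shared_ptr<Program>>(m, "Program")
      .def(py::init([]() { return std::make_shared<Program>(); }));

  // same cast pattern as CastPyArg2AttrIRProgram: Python object -> shared_ptr
  m.def("take_program", [](py::handle obj) {
    auto prog = obj.cast<std::shared_ptr<Program>>();
    return static_cast<bool>(prog);
  });
}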
8 changes: 4 additions & 4 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -914,10 +914,10 @@ def _prune_unused_params(self, program):

def _prepare_attributes(self):
attrs = [
'forward_global_block',
self.program.forward_program.global_block(),
'backward_global_block',
self.program.backward_program.global_block(),
'forward_program',
self.program.forward_program,
'backward_program',
self.program.backward_program,
'is_test',
not self.training,
'program_id',
2 changes: 1 addition & 1 deletion test/cpp/prim/CMakeLists.txt
@@ -23,7 +23,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(init_env_utils SRCS init_env_utils.cc)
target_compile_definitions(init_env_utils PUBLIC PADDLE_DLL_EXPORT)

paddle_test(test_comp_eager SRCS test_eager_prim.cc DEPS init_env_utils)
paddle_test(test_comp_eager SRCS test_eager_prim.cc init_env_utils.cc)
Member:

#56691 It looks like this was done to reduce the unit test binary size. Doesn't this change revert that?

Member (Author):

Yes, but there is currently a duplicate-dependency problem between them that causes a Windows LNK2005 error, so for now we simply stop splitting them.

Member:

What do you mean by a duplicate dependency? If that is the case, is line 24 no longer needed? Or is there a plan to optimize this later?

Member (Author):

For the duplicate dependency, see 13c89cf and its PR-CI-Windows-OPENBLAS run; roughly, both of them depend on phi, which leads to the duplication. I also consulted the paddle_test documentation but still could not split it out. The best outcome is of course to split it and shrink the size; we can optimize that later.

Member:

> The best outcome is of course to split it and shrink the size; we can optimize that later.

ok

endif()

# skip win32 since wget is not installed by default on windows machine.
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_no_gradient.py
@@ -15,7 +15,7 @@
import unittest

import numpy
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir

import paddle

@@ -33,6 +33,7 @@ def main_func(x, index):


class TestNoGradientCase(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_no_gradient(self):
paddle.disable_static()
x = paddle.randn([10, 3])