Skip to content

Commit

Permalink
[PIR] Change output to block_arg from copy to a shared for the execut…
Browse files Browse the repository at this point in the history
…ion of while (#60607)

* test

* fix

* fix

* fix
  • Loading branch information
zhangbo9674 authored Jan 10, 2024
1 parent 3bcff9e commit 35d445b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ WhileInstruction::WhileInstruction(
body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value));
}
for (auto value : body_outside_inputs) {
body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value));
body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value));
auto name = body_inter_->GetNameByValue(value);
external_input_names_.insert(name);
body_skip_gc_names_.push_back(name);
body_skip_gc_names_set.insert(name);
}
for (const auto& var_name : skip_gc_vars) {
body_skip_gc_names_.push_back(var_name);
Expand Down Expand Up @@ -172,37 +174,26 @@ void WhileInstruction::ShareInputsToOutputs() {
}
}

void WhileInstruction::CopyOutputsToBlockArgs() {
void WhileInstruction::ShareOutputsToBlockArgs() {
for (size_t i = 0; i < body_block_->args_size(); ++i) {
auto block_arg = body_block_->arg(i);
auto var_name = body_inter_->GetNameByValue(block_arg);
auto* inner_var = body_inter_->local_scope()->GetVar(var_name);

if (outputs_[i]->IsType<phi::DenseTensor>()) {
auto& src_tensor = outputs_[i]->Get<phi::DenseTensor>();
auto* dst_tensor = inner_var->GetMutable<phi::DenseTensor>();
dst_tensor->set_meta(src_tensor.meta());
framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
inner_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
outputs_[i]->Get<phi::DenseTensor>());
} else if (outputs_[i]->IsType<phi::TensorArray>()) {
auto src_tensor_array = outputs_[i]->Get<phi::TensorArray>();
auto* dst_tensor_array = inner_var->GetMutable<phi::TensorArray>();
dst_tensor_array->set_type(src_tensor_array.dtype());
dst_tensor_array->set_layout(src_tensor_array.layout());
while (dst_tensor_array->size() < src_tensor_array.size()) {
dst_tensor_array->emplace_back();
}
for (size_t id = 0; id < dst_tensor_array->size(); id++) {
auto& src_tensor = src_tensor_array[id];
phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id);
tmp_dst_tensor->set_meta(src_tensor.meta());
framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor);
}
const auto& outer_array = outputs_[i]->Get<phi::TensorArray>();
auto* inner_array = inner_var->GetMutable<phi::TensorArray>();
*inner_array = outer_array;
VLOG(10) << inner_var
<< " should be created: " << inner_var->IsInitialized();
} else {
PADDLE_THROW(
phi::errors::Unimplemented("unsupported type %d", inner_var->Type()));
}
}
DeviceContext().Wait();
}

void WhileInstruction::ShareDatasToOutputs() {
Expand All @@ -220,6 +211,14 @@ void WhileInstruction::ShareDatasToOutputs() {
out_var->Get<phi::DenseTensor>());
VLOG(6) << "share data from " << out_var_name << "[" << out_var << "]"
<< " -> " << i << " output[" << outputs_[i] << "]";

// NOTE(zhangbo): Delete the input of the yield operator, except for the
// external vars of the block.
if (external_input_names_.count(out_var_name) == 0) {
VLOG(6) << "clear internel input " << out_var_name;
out_var->GetMutable<phi::DenseTensor>()->clear();
}

} else if (out_var->IsType<phi::TensorArray>()) {
const auto& inner_array = out_var->Get<phi::TensorArray>();
auto* output_array = outputs_[i]->GetMutable<phi::TensorArray>();
Expand All @@ -238,7 +237,7 @@ void WhileInstruction::Run() {
VLOG(6) << "while instruction start loop ...";
while (GetCondData(cond_var_->Get<phi::DenseTensor>())) {
VLOG(6) << "while instruction pass args to body block";
CopyOutputsToBlockArgs();
ShareOutputsToBlockArgs();
VLOG(6) << "while instruction interpretercore run";
body_inter_->Run({}, false);
VLOG(6) << "while instruction get value form body block";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class WhileInstruction : public InstructionBase {
void ShareInputsToOutputs();

// Pass argument to body_block for execution.
void CopyOutputsToBlockArgs();
void ShareOutputsToBlockArgs();

// Get return value from body_block after each execution.
void ShareDatasToOutputs();
Expand All @@ -70,6 +70,7 @@ class WhileInstruction : public InstructionBase {
std::unique_ptr<PirInterpreter> body_inter_;
std::vector<std::string> body_outputs_;
std::vector<std::string> body_skip_gc_names_;
std::set<std::string> external_input_names_;

::pir::Block* body_block_;

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ bool CanDoInplace(const std::unordered_set<pir::Value>& eager_dels,
pir::Value input,
pir::Value output,
const std::string& op_name) {
if (!input.type() || !output.type()) {
if (!input.type() || !output.type() || input.isa<pir::BlockArgument>()) {
return false;
}

Expand Down

0 comments on commit 35d445b

Please sign in to comment.