Skip to content

Commit

Permalink
add checkpointoutput for paddle mlperf (PaddlePaddle#346)
Browse files Browse the repository at this point in the history
  • Loading branch information
gglin001 authored Dec 21, 2021
1 parent 19f0689 commit 04069af
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
9 changes: 7 additions & 2 deletions paddle/fluid/framework/ipu/ipu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ void Compiler::LowerBody(const ir::Graph* graph) {

if (op_type == "popart_constant") {
// pass
} else if (op_type == "popart_optimizer") {
// pass
} else if (op_type == "popart_checkpointoutput") {
auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
auto output_ids = builder_->checkpointOutput(inputs);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_custom_op") {
auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
Expand Down Expand Up @@ -202,8 +209,6 @@ void Compiler::LowerBody(const ir::Graph* graph) {
inputs, print_gradient, debug_context, title);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_optimizer") {
// pass
} else {
auto itr = name_function_.find(op_type);
if (itr != name_function_.end()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@ Node *print_handler(Graph *graph, Node *node) {

Node *popart_optimizer_handler(Graph *graph, Node *node) { return nullptr; }

Node *checkpointoutput_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_checkpointoutput", node->inputs,
node->outputs);
}

REGISTER_HANDLER(custom_op, custom_op_handler);
REGISTER_HANDLER(print, print_handler);
REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler);
REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler);

} // namespace
} // namespace ipu
Expand Down

0 comments on commit 04069af

Please sign in to comment.