From 04069af9cfb3a0e595fcc35482801a13305ddac7 Mon Sep 17 00:00:00 2001 From: Allen Guo Date: Tue, 21 Dec 2021 15:19:26 +0800 Subject: [PATCH] add checkpointoutput for paddle mlperf (#346) --- paddle/fluid/framework/ipu/ipu_compiler.cc | 9 +++++++-- .../framework/ipu/popart_canonicalization/other_ops.cc | 6 ++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ipu/ipu_compiler.cc b/paddle/fluid/framework/ipu/ipu_compiler.cc index 09853cbdd5ce4..400f181ef25c9 100644 --- a/paddle/fluid/framework/ipu/ipu_compiler.cc +++ b/paddle/fluid/framework/ipu/ipu_compiler.cc @@ -173,6 +173,13 @@ void Compiler::LowerBody(const ir::Graph* graph) { if (op_type == "popart_constant") { // pass + } else if (op_type == "popart_optimizer") { + // pass + } else if (op_type == "popart_checkpointoutput") { + auto inputs = GetOpInputs(op_desc); + auto outputs = GetOpOutputs(op_desc); + auto output_ids = builder_->checkpointOutput(inputs); + InsertTensors(outputs, output_ids); } else if (op_type == "popart_custom_op") { auto inputs = GetOpInputs(op_desc); auto outputs = GetOpOutputs(op_desc); @@ -202,8 +209,6 @@ void Compiler::LowerBody(const ir::Graph* graph) { inputs, print_gradient, debug_context, title); SetIpuIndexStage(output_ids, op_desc); InsertTensors(outputs, output_ids); - } else if (op_type == "popart_optimizer") { - // pass } else { auto itr = name_function_.find(op_type); if (itr != name_function_.end()) { diff --git a/paddle/fluid/framework/ipu/popart_canonicalization/other_ops.cc b/paddle/fluid/framework/ipu/popart_canonicalization/other_ops.cc index 6e5181e6138e1..125985fb07896 100644 --- a/paddle/fluid/framework/ipu/popart_canonicalization/other_ops.cc +++ b/paddle/fluid/framework/ipu/popart_canonicalization/other_ops.cc @@ -49,9 +49,15 @@ Node *print_handler(Graph *graph, Node *node) { Node *popart_optimizer_handler(Graph *graph, Node *node) { return nullptr; } +Node *checkpointoutput_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_checkpointoutput", node->inputs, + node->outputs); +} + REGISTER_HANDLER(custom_op, custom_op_handler); REGISTER_HANDLER(print, print_handler); REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler); +REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler); } // namespace } // namespace ipu