Skip to content

Commit

Permalink
[PIR+CINN]Part-2 Pybind IrParser.ParseProgram and Polish UT into check_run (#59449)
Browse files Browse the repository at this point in the history

* [PIR+CINN]Support SubGraph Exporter for Unittest Platform

add unittest

fix UT not take effect

[PIR+CINN]Pybind IrParser.ParseProgram and Polish UT into check_run

fix cmake flags

remove VLOG

fix code comment

* fix conflict

* remove print

* fix UT

* add list.sort to fix random
  • Loading branch information
Aurelius84 authored Nov 30, 2023
1 parent d86f686 commit 09401e6
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 82 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/parser/ir_parser.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/core/type.h"
#include "paddle/pir/core/value.h"
Expand Down Expand Up @@ -80,6 +81,7 @@ using paddle::dialect::SelectedRowsType;
using pir::Attribute;
using pir::Block;
using pir::BlockArgument;
using pir::IrParser;
using pir::Operation;
using pir::OpOperand;
using pir::OpResult;
Expand Down Expand Up @@ -254,6 +256,15 @@ void BindProgram(py::module *m) {
});
}

// Parse a textual PIR program (the format emitted by Program::Print) back
// into an in-memory pir::Program.
//
// @param program_str  Serialized program text.
// @return Shared ownership of the freshly parsed Program.
std::shared_ptr<Program> ParseProgram(const std::string &program_str) {
  // The parser only reads from the stream, so a read-only istringstream is
  // the right tool (the original read/write stringstream over-promised).
  std::istringstream ss(program_str);
  // Singleton IrContext carries the dialect/type registrations the parser
  // needs to resolve op and type names.
  pir::IrContext *ctx = pir::IrContext::Instance();
  // Return the parser result directly; no need for a named temporary.
  return IrParser(ctx, ss).ParseProgram();
}

void BindIrParser(py::module *m) { m->def("parse_program", &ParseProgram); }

void RefreshOpStopgradients(Operation *op) {
if (op->num_operands() == 0 || op->isa<pir::ParameterOp>() ||
op->isa<paddle::dialect::UniformOp>()) {
Expand Down Expand Up @@ -1612,6 +1623,7 @@ void BindPir(pybind11::module *module) {
BindControlFlowApi(&ir_module);
auto ops_modules = ir_module.def_submodule("ops");
BindOpsAPI(&ops_modules);
BindIrParser(&ir_module);
}

} // namespace pybind
Expand Down
39 changes: 19 additions & 20 deletions python/paddle/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,25 +428,23 @@ def has_fetch_operations(
that match the info contained in fetch_targets.
"""

fetch_count = 0
mismatch_count = 0
fetch_info = [[], []]
for op in block.ops:
if op.name() == fetch_op:
if op.operand_source(0) not in fetch_targets:
mismatch_count += 1
continue
fetch_count += 1
if mismatch_count > 0:
warnings.warn(
"There are {} fetch ops in Program which are not responsible for the fetch targets that you have passed in fetch_list".format(
mismatch_count
)
)
if fetch_count > 0 and fetch_count != len(fetch_targets):
raise Exception(
"Fetch operations in program do not match 'fetch_targets'"
)
return fetch_count > 0
fetch_info[0].append(op.operand_source(0))
fetch_info[1].append(op.attrs()["name"])

need_fetch_info = []
for i, fetch_var in enumerate(fetch_targets):
if isinstance(fetch_var, str):
if fetch_var not in fetch_info[1]:
raise Exception(
f"Found fetch_target[{i}] is type(str) and doesn't have fetch op."
)
elif fetch_var not in fetch_info[0]:
need_fetch_info.append(fetch_var)

return need_fetch_info


def _add_feed_fetch_ops(
Expand Down Expand Up @@ -519,11 +517,12 @@ def _add_pir_fetch_ops(program, fetch_list, fetch_var_name):

global_block = program.global_block()
fetch_op = "pd_op.fetch"
if not has_fetch_operations(
need_fetch_info = has_fetch_operations(
global_block, fetch_list, fetch_var_name, fetch_op
):
)
if need_fetch_info:
with paddle.static.program_guard(program):
for i, fetch_input in enumerate(fetch_list):
for i, fetch_input in enumerate(need_fetch_info):
assert isinstance(
fetch_input, (OpResult, Value)
), f"Wrong type for fetch_list[{i}]: {type(fetch_input)}"
Expand Down
49 changes: 18 additions & 31 deletions python/paddle/jit/dy2static/export_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from paddle import pir
from paddle.base import core
from paddle.base.dygraph.base import switch_to_static_graph
from paddle.base.framework import Variable, get_flags
from paddle.base.framework import get_flags

__all__ = []

Expand All @@ -42,6 +42,7 @@ def __init__(self, partial_program_layer, program, role):
self.program = program
self.role = role
self.root_dir = get_saving_dir()
self.fetch_col = 0

def save(self):
# step 1: Create subgraph saving path.
Expand All @@ -59,6 +60,9 @@ def _save(self, pir_program, path):
f.write(content)

def parse_inout(self):
"""
Return feed/fetch/intermediate var name list.
"""
raise NotImplementedError("Need to implement parse_inout method")

def translate_into_pir(self):
Expand Down Expand Up @@ -90,9 +94,8 @@ def verify_saving_dir(self, dir_path):

def insert_feed_op(self, intputs, rename_prefix):
global_block = self.program.block(0)

for i, var in enumerate(intputs):
old_name = var.name
intputs.sort()
for i, old_name in enumerate(intputs):
new_name = rename_prefix + str(i)
global_block._rename_var(old_name, new_name)
out = global_block.var(new_name)
Expand All @@ -116,32 +119,20 @@ def insert_fetch_op(self, outputs, rename_prefix):
type=core.VarDesc.VarType.FETCH_LIST,
persistable=False,
)
for i, out in enumerate(outputs):
var = self.get_var(out)
old_name = var.name
outputs.sort()
for i, old_name in enumerate(outputs):
new_name = rename_prefix + str(i)
global_block._rename_var(old_name, new_name)
new_var = global_block.var(new_name)
global_block.append_op(
type="fetch",
inputs={'X': [new_var]},
outputs={'Out': [fetch_var]},
attrs={'col': i},
attrs={'col': self.fetch_col},
)
self.fetch_col += 1
global_block._sync_with_cpp()

def rename_ops(self, ops, new_name, old_name):
for op in ops:
op._rename_input(old_name, new_name)
op._rename_output(old_name, new_name)

def get_var(self, name_or_var):
if isinstance(name_or_var, Variable):
return name_or_var
assert isinstance(name_or_var, str)
global_block = self.program.block(0)
return global_block.var(name_or_var)


class InferExporter(BaseExporter):
def __init__(self, *args, **kwargs):
Expand All @@ -153,12 +144,10 @@ def parse_inout(self):
raw_inputs = self.pp_layer._inputs.tolist() + self.pp_layer._params
raw_outputs = self.pp_layer._outputs.tolist()
for var in raw_inputs:
new_var = global_block.var(var.name)
inputs.append(new_var)
inputs.append(var.name)

for var in raw_outputs:
new_var = global_block.var(var.name)
outputs.append(new_var)
outputs.append(var.name)

return inputs, outputs, []

Expand All @@ -180,14 +169,12 @@ def parse_inout(self):
if self.program.block(0).has_var(name)
}
for var in raw_inputs:
new_var = global_block.var(var.name)
inputs.append(new_var)
inputs.append(var.name)
if var.name in inter_outs:
inter_outs.remove(var.name)

for var in raw_outputs:
new_var = global_block.var(var.name)
outputs.append(new_var)
outputs.append(var.name)
if var.name in inter_outs:
inter_outs.remove(var.name)

Expand All @@ -206,22 +193,22 @@ def parse_inout(self):

for var_name in self.raw_inputs:
if global_block.has_var(var_name):
inputs.append(global_block.var(var_name))
inputs.append(var_name)

# add fill_constant grad_var as input
for var in self.pp_layer._outputs.tolist():
init_grad_name = var.name + "@GRAD"
if init_grad_name not in self.raw_inputs and global_block.has_var(
init_grad_name
):
inputs.append(global_block.var(init_grad_name))
inputs.append(init_grad_name)

for var_name in self.raw_outputs:
if (
global_block.has_var(var_name)
and var_name not in self.raw_inputs
):
outputs.append(global_block.var(var_name))
outputs.append(var_name)

return inputs, outputs, []

Expand Down
1 change: 1 addition & 0 deletions python/paddle/pir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Program,
Type,
Value,
parse_program,
check_unregistered_ops,
fake_op_result,
is_fake_op_result,
Expand Down
7 changes: 6 additions & 1 deletion test/ir/pir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ set(TEST_IR_SYSTEM_CASES
list(REMOVE_ITEM TEST_INTERP_CASES ${TEST_IR_SYSTEM_CASES})
list(REMOVE_ITEM TEST_INTERP_CASES test_subgraph_exporter)
py_test_modules(
test_subgraph_exporter MODULES test_subgraph_exporter ENVS MIN_GRAPH_SIZE=0
test_subgraph_exporter
MODULES
test_subgraph_exporter
ENVS
MIN_GRAPH_SIZE=0
FLAGS_enable_pir_in_executor=1
FLAGS_pir_subgraph_saving_dir=${CMAKE_CURRENT_SOURCE_DIR})

foreach(target ${TEST_INTERP_CASES})
Expand Down
114 changes: 84 additions & 30 deletions test/ir/pir/test_subgraph_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import shutil
import unittest

import numpy as np

import paddle
from paddle.jit.dy2static.export_subgraph import get_saving_dir

Expand Down Expand Up @@ -50,53 +52,105 @@ def test_export(self):
out = self.net(x)
self.check_export()

def run_program(self, program, feed, fetch_list):
    """Execute a PIR ``program`` on CPU and return its fetched outputs.

    Args:
        program: A PIR program (e.g. from ``paddle.pir.parse_program``).
        feed (dict): Mapping of feed variable names to numpy arrays.
        fetch_list (list): Names of the variables to fetch as outputs.

    Returns:
        The fetch results as numpy arrays (``return_numpy=True``).
    """
    # Executor requires static-graph mode; dygraph mode is restored below so
    # the rest of the test keeps running eagerly.
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    # NOTE(review): _run_pir_impl is a private Executor API — presumably used
    # because the public run() path does not accept a parsed PIR program yet;
    # confirm against Executor before relying on it elsewhere.
    outs = exe._run_pir_impl(
        program,
        feed=feed,
        fetch_list=fetch_list,
        feed_var_name="feed",
        fetch_var_name='fetch',
        scope=None,
        return_numpy=True,
    )
    paddle.disable_static()
    return outs

def check_export(self):
for prog_file in os.listdir(self.root_dir):
if "forward" in prog_file:
self.check_fwd(prog_file)
return
elif "backward" in prog_file:
self.check_bwd(prog_file)
else:
raise RuntimeError("Not Support.")

def check_fwd(self, prog_file):
prog_info = [
"pt_input_0",
"pt_output_0",
"pt_output_1",
"pt_intermediate_0",
"pt_intermediate_1",
"pt_intermediate_2",
]
path = os.path.join(self.root_dir, prog_file)
with open(path, 'r') as f:
content = f.readlines()
index = 0
for op_str in content:
if "pd_op.data" in op_str or "pd_op.fetch" in op_str:
self.assertIn(prog_info[index], op_str)
index += 1
content = f.read()
program = paddle.pir.parse_program(content)

def check_bwd(self, prog_file):
prog_info = [
"pt_input_6",
"pt_input_5",
"pt_input_4",
"pt_input_3",
"pt_input_2",
"pt_input_1",
"pt_input_0",
pt_input_0 = np.random.random([4, 4]).astype(np.float32)
feed = {"pt_input_0": pt_input_0}
fetch_list = [
'pt_output_0',
'pt_output_1',
'pt_intermediate_0',
'pt_intermediate_1',
'pt_intermediate_2',
]
outs = self.run_program(program, feed, fetch_list)

self.assertEqual(len(outs), 5)
out_shapes = [[4, 4], [], [4, 4], [4, 4], [4, 4]]
for i, out in enumerate(outs):
self.assertListEqual(list(out.shape), out_shapes[i])

def check_bwd(self, prog_file):
path = os.path.join(self.root_dir, prog_file)
with open(path, 'r') as f:
content = f.readlines()
index = 0
for op_str in content:
if "pd_op.data" in op_str or "pd_op.fetch" in op_str:
self.assertIn(prog_info[index], op_str)
index += 1
content = f.read()

program = paddle.pir.parse_program(content)
data = np.random.random([4, 4]).astype(np.float32)
feed = {
"pt_input_6": data,
"pt_input_5": data,
"pt_input_4": data,
"pt_input_3": np.array(0.1).astype(np.float32),
"pt_input_2": data,
"pt_input_1": data,
"pt_input_0": data,
}
fetch_list = []
outs = self.run_program(program, feed, fetch_list)

self.assertEqual(len(outs), 0)


# class TestSaveInferProg(TestSaveFwdBwdProg):

# def test_export(self):
# x = paddle.randn([4, 4])
# self.net.eval()
# out = self.net(x)
# self.check_export()

# def check_export(self):
# for prog_file in os.listdir(self.root_dir):
# breakpoint()
# if "infer" in prog_file:
# self.check_infer(prog_file)
# else:
# raise RuntimeError("Not Support.")

# def check_infer(self, prog_file):
# path = os.path.join(self.root_dir, prog_file)
# with open(path, 'r') as f:
# content = f.read()
# program = paddle.pir.parse_program(content)

# pt_input_0 = np.random.random([4,4]).astype(np.float32)
# feed = {"pt_input_0": pt_input_0}
# fetch_list = ['pt_output_0', 'pt_output_1']
# outs = self.run_program(program, feed, fetch_list)

# self.assertEqual(len(outs), 2)
# out_shapes = [[], [4,4]]
# for i, out in enumerate(outs):
# self.assertListEqual(list(out.shape), out_shapes[i])

if __name__ == "__main__":
unittest.main()

0 comments on commit 09401e6

Please sign in to comment.