Skip to content

Commit

Permalink
[Dy2St] pir dy2st unittest verification - Part 1 (PaddlePaddle#58630)
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil authored Nov 3, 2023
1 parent 63370ab commit d8e3a15
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions test/dygraph_to_static/dygraph_to_static_utils_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np

import paddle
from paddle import set_flags, static
from paddle.base import core
from paddle.jit.api import sot_mode_guard
Expand All @@ -29,9 +30,9 @@
# Usage:
class MyTest(Dy2StTestBase):
@set_to_static_mode(
ToStaticMode.LEGACY_AST | ToStaticMode.SOT | ToStaticMode.PIR_AST
ToStaticMode.AST | ToStaticMode.SOT
)
@set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)
@set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE | IrMode.PIR_API)
def test_case1(self):
raise ValueError("MyTest 1")
Expand All @@ -49,8 +50,7 @@ def test_case1(self):


class ToStaticMode(Flag):
LEGACY_AST = auto()
PIR_AST = auto()
AST = auto()
SOT = auto()

def lower_case_name(self):
Expand All @@ -59,13 +59,16 @@ def lower_case_name(self):

class IrMode(Flag):
LEGACY_IR = auto()
PIR = auto()
# pir translator mode, Reference link: https://github.com/PaddlePaddle/community/blob/master/pfcc/paddle-code-reading/IR_Dialect/program_translator.md
PIR_EXE = auto()
# using native pir api mode
PIR_API = auto()

def lower_case_name(self):
return self.name.lower()


DEFAULT_TO_STATIC_MODE = ToStaticMode.LEGACY_AST | ToStaticMode.SOT
DEFAULT_TO_STATIC_MODE = ToStaticMode.AST | ToStaticMode.SOT
DEFAULT_IR_MODE = IrMode.LEGACY_IR


Expand Down Expand Up @@ -98,13 +101,24 @@ def impl(*args, **kwargs):


def to_pir_ast_test(fn):
raise TypeError("Don't enable PIR AST mode now!")
@wraps(fn)
def impl(*args, **kwargs):
logger.info("[PIR][AST] running pir api")
ir_outs = None
try:
with paddle.pir_utils.IrGuard():
paddle.disable_static()
ir_outs = fn(*args, **kwargs)
finally:
paddle.enable_static()
return ir_outs

return impl


def to_legacy_ir_test(fn):
def impl(*args, **kwargs):
logger.info("[Program] running legacy ir")
# breakpoint()
return fn(*args, **kwargs)

return impl
Expand Down Expand Up @@ -136,13 +150,13 @@ def impl(*args, **kwargs):
class Dy2StTestMeta(type):
TO_STATIC_HANDLER_MAP = {
ToStaticMode.SOT: to_sot_test,
ToStaticMode.LEGACY_AST: to_legacy_ast_test,
ToStaticMode.PIR_AST: to_pir_ast_test,
ToStaticMode.AST: to_legacy_ast_test,
}

IR_HANDLER_MAP = {
IrMode.LEGACY_IR: to_legacy_ir_test,
IrMode.PIR: to_pir_test,
IrMode.PIR_EXE: to_pir_test,
IrMode.PIR_API: to_pir_ast_test,
}

def __new__(cls, name, bases, attrs):
Expand Down Expand Up @@ -191,11 +205,11 @@ def __new__(cls, name, bases, attrs):
)
# Generate all test cases
for to_static_mode, ir_mode in to_static_with_ir_modes:
# NOTE(gouzil): Temporarily not supported SOT + PIR, link: https://github.com/PaddlePaddle/Paddle/pull/58630
if (
to_static_mode == ToStaticMode.PIR_AST
and ir_mode == IrMode.LEGACY_IR
to_static_mode == ToStaticMode.SOT
and ir_mode == IrMode.PIR_API
):
# PIR with LEGACY_IR is not a valid combination
continue
new_attrs[
Dy2StTestMeta.test_case_name(
Expand Down Expand Up @@ -250,7 +264,7 @@ def decorator(fn):
# Suger decorators
# These decorators can be simply composed by base decorators
def test_ast_only(fn):
fn = set_to_static_mode(ToStaticMode.LEGACY_AST)(fn)
fn = set_to_static_mode(ToStaticMode.AST)(fn)
return fn


Expand All @@ -260,12 +274,22 @@ def test_sot_only(fn):


def test_pir_only(fn):
fn = set_ir_mode(IrMode.PIR)(fn)
fn = set_ir_mode(IrMode.PIR_EXE)(fn)
return fn


def test_legacy_and_pir(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)(fn)
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE)(fn)
return fn


def test_legacy_and_pir_api(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_API)
return fn


def test_legacy_and_pir_api_and_pir_exe(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_API | IrMode.PIR_EXE)
return fn


Expand Down

0 comments on commit d8e3a15

Please sign in to comment.