Skip to content

Commit

Permalink
[Dy2St][PIR] Run test_break_continue in sequential run mode (#63287)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Apr 8, 2024
1 parent 5b4132d commit de0cd61
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
17 changes: 17 additions & 0 deletions test/dygraph_to_static/dygraph_to_static_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def test_case1(self):
ENV_ENABLE_PIR_WITH_PT_IN_DY2ST = BooleanEnvironmentVariable(
"FLAGS_enable_pir_with_pt_in_dy2st", True
)
ENV_EXE_SEQUENTIAL_RUN = BooleanEnvironmentVariable(
"FLAGS_new_executor_sequential_run", False
)


class ToStaticMode(Flag):
Expand Down Expand Up @@ -438,3 +441,17 @@ def enable_to_static_guard(flag: bool):
yield
finally:
program_translator.enable(original_flag_value)


@contextmanager
def exe_sequential_run_guard(value: bool):
exe_sequential_run_flag = ENV_EXE_SEQUENTIAL_RUN.name
original_flag_value = paddle.get_flags(exe_sequential_run_flag)[
exe_sequential_run_flag
]
with EnvironmentVariableGuard(ENV_EXE_SEQUENTIAL_RUN, value):
try:
set_flags({exe_sequential_run_flag: value})
yield
finally:
set_flags({exe_sequential_run_flag: original_flag_value})
12 changes: 10 additions & 2 deletions test/dygraph_to_static/test_break_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
exe_sequential_run_guard,
test_ast_only,
test_legacy_and_pt_and_pir,
)

import paddle
from paddle.framework import use_pir_api
from paddle.jit.dy2static.utils import Dygraph2StaticException

SEED = 2020
Expand Down Expand Up @@ -354,11 +356,17 @@ class TestOptimBreakInWhile(TestContinueInWhile):
def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_while

# TODO: Open PIR test when while_loop dy2st fixed
@test_legacy_and_pt_and_pir
def test_transformed_static_result(self):
self.init_dygraph_func()
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
# NOTE(SigureMo): Temperary run the test in sequential run mode to avoid dependency
# on the execution order of the test cases.
if use_pir_api():
with exe_sequential_run_guard(True):
static_res = self.run_static_mode()
else:
static_res = self.run_static_mode()
np.testing.assert_allclose(
dygraph_res,
static_res,
Expand Down

0 comments on commit de0cd61

Please sign in to comment.