Skip to content

Commit

Permalink
[Dy2St] Enable test_lstm in PIR mode (#60343)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Dec 26, 2023
1 parent 23808ae commit a726569
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def tearDown(self):

def run_lstm(self, to_static):
with enable_to_static_guard(to_static):
paddle.static.default_main_program().random_seed = 1001
paddle.static.default_startup_program().random_seed = 1001
paddle.seed(1001)

net = paddle.jit.to_static(Net(12, 2))
x = paddle.zeros((2, 10, 12))
y = net(x)
return y.numpy()

@test_legacy_and_pt_and_pir
def test_lstm_to_static(self):
dygraph_out = self.run_lstm(to_static=False)
static_out = self.run_lstm(to_static=True)
Expand Down

0 comments on commit a726569

Please sign in to comment.