Skip to content

Commit

Permalink
[PIR] Adapt 0D create_parameter ut (#63253)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Apr 7, 2024
1 parent 6d715a3 commit 81ae815
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions test/legacy_test/test_zero_dim_sundry_static_api_part1.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from decorator_helper import prog_scope

import paddle
from paddle.framework import in_pir_mode
from paddle.pir_utils import test_with_pir_api

# Use to test zero-dim of Sundry API, which is unique and can not be classified
Expand Down Expand Up @@ -125,14 +126,28 @@ def test_trapezoid(self):
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (5,))

@test_with_pir_api
@prog_scope()
def test_create_parameter_var(self):
def test_create_parameter(self):
if not in_pir_mode():
zero_dim_param = paddle.create_parameter(shape=[], dtype='float32')
self.assertShapeEqual(zero_dim_param, [])
prog = paddle.static.default_startup_program()
res = self.exe.run(prog, fetch_list=[zero_dim_param])
self.assertEqual(res[0].shape, ())
return
zero_dim_param = paddle.create_parameter(shape=[], dtype='float32')
self.assertShapeEqual(zero_dim_param, [])
prog = paddle.static.default_startup_program()
res = self.exe.run(prog, fetch_list=[zero_dim_param])
self.assertEqual(res[0].shape, ())
self.assertEqual(zero_dim_param.shape, [])
startup_prog = paddle.static.default_startup_program()
main_prog = paddle.static.default_main_program()
self.exe.run(startup_prog)
(zero_dim_param_res,) = self.exe.run(
main_prog, fetch_list=[zero_dim_param]
)
self.assertEqual(zero_dim_param_res.shape, ())

@prog_scope()
def test_create_global_var(self):
zero_dim_var = paddle.static.create_global_var(
shape=[], value=0.5, dtype='float32'
)
Expand Down

0 comments on commit 81ae815

Please sign in to comment.