Skip to content

Commit

Permalink
cholesky and cholesky_solve tests (#60726)
Browse files Browse the repository at this point in the history
  • Loading branch information
changeyoung98 authored Jan 11, 2024
1 parent 0ac9c29 commit aef9d6d
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 20 deletions.
6 changes: 3 additions & 3 deletions test/legacy_test/gradient_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def _compute_analytical_jacobian_pir(
def grad_check(
x,
y,
x_init=None,
fetch_list=None,
feeds=None,
place=None,
program=None,
Expand Down Expand Up @@ -403,12 +403,12 @@ def fail_test(msg):
for i in range(len(y)):
analytical.append(
_compute_analytical_jacobian_pir(
program, x, i, y, x_init, feeds, place
program, x, i, y, fetch_list, feeds, place
)
)
numerical = [
_compute_numerical_jacobian_pir(
program, xi, y, x_init, feeds, place, eps
program, xi, y, fetch_list, feeds, place, eps
)
for xi in x
]
Expand Down
57 changes: 44 additions & 13 deletions test/legacy_test/test_cholesky_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle import base
from paddle.base import core
from paddle.base.backward import _as_list
from paddle.pir_utils import test_with_pir_api


@skip_check_grad_ci(
Expand Down Expand Up @@ -68,33 +69,38 @@ def test_check_grad(self):
for p in places:
self.func(p)

@test_with_pir_api
@prog_scope()
def func(self, place):
# use small size since Jacobian gradients is time consuming
root_data = self.root_data[..., :3, :3]
prog = base.Program()
with base.program_guard(prog):
root = paddle.create_parameter(
dtype=root_data.dtype, shape=root_data.shape
)
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
if paddle.framework.in_pir_mode():
root = paddle.static.data(
dtype=root_data.dtype, shape=root_data.shape, name="root"
)
else:
root = paddle.create_parameter(
dtype=root_data.dtype, shape=root_data.shape
)
root.stop_gradient = False
root.persistable = True
root_t = paddle.transpose(root, self.trans_dims)
x = paddle.matmul(x=root, y=root_t) + 1e-05
out = paddle.cholesky(x, upper=self.attrs["upper"])
# check input arguments
root = _as_list(root)
out = _as_list(out)

for v in root:
v.stop_gradient = False
v.persistable = True
for u in out:
u.stop_gradient = False
u.persistable = True

# init variable in startup program
scope = base.executor.global_scope()
exe = base.Executor(place)
exe.run(base.default_startup_program())
exe.run(paddle.static.default_startup_program())

x_init = _as_list(root_data)
# init inputs if x_init is not None
Expand All @@ -106,10 +112,33 @@ def func(self, place):
)
# init variable in main program
for var, arr in zip(root, x_init):
assert var.shape == arr.shape
assert tuple(var.shape) == tuple(arr.shape)
feeds = {k.name: v for k, v in zip(root, x_init)}
exe.run(prog, feed=feeds, scope=scope)
grad_check(x=root, y=out, x_init=x_init, place=place, program=prog)
fetch_list = None
if paddle.framework.in_pir_mode():
dys = []
for i in range(len(out)):
yi = out[i]
dy = paddle.static.data(
name='dys_%s' % i,
shape=yi.shape,
dtype=root_data.dtype,
)
dy.stop_gradient = False
dy.persistable = True
value = np.zeros(yi.shape, dtype=root_data.dtype)
feeds.update({'dys_%s' % i: value})
dys.append(dy)
fetch_list = base.gradients(out, root, dys)
grad_check(
x=root,
y=out,
fetch_list=fetch_list,
feeds=feeds,
place=place,
program=prog,
)

def init_config(self):
self._upper = True
Expand Down Expand Up @@ -144,8 +173,11 @@ def setUp(self):
if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
self.places.append(base.CUDAPlace(0))

@test_with_pir_api
def check_static_result(self, place, with_out=False):
with base.program_guard(base.Program(), base.Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = paddle.static.data(
name="input", shape=[4, 4], dtype="float64"
)
Expand All @@ -156,7 +188,6 @@ def check_static_result(self, place, with_out=False):
exe = base.Executor(place)
try:
fetches = exe.run(
base.default_main_program(),
feed={"input": input_np},
fetch_list=[result],
)
Expand Down
15 changes: 11 additions & 4 deletions test/legacy_test/test_cholesky_solve_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import paddle
from paddle import base
from paddle.base import Program, core, program_guard
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand Down Expand Up @@ -143,7 +144,7 @@ def test_check_output(self):

# check Op grad
def test_check_grad_normal(self):
self.check_grad(['Y'], 'Out', max_relative_error=0.01)
self.check_grad(['Y'], 'Out', max_relative_error=0.01, check_pir=True)


# test condition: 3D(broadcast) + 3D, upper=True
Expand All @@ -169,9 +170,12 @@ def setUp(self):
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

@test_with_pir_api
def check_static_result(self, place):
paddle.enable_static()
with base.program_guard(base.Program(), base.Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name="x", shape=[10, 2], dtype=self.dtype)
y = paddle.static.data(name="y", shape=[10, 10], dtype=self.dtype)
z = paddle.linalg.cholesky_solve(x, y, upper=self.upper)
Expand All @@ -187,7 +191,6 @@ def check_static_result(self, place):

exe = base.Executor(place)
fetches = exe.run(
base.default_main_program(),
feed={"x": x_np, "y": umat},
fetch_list=[z],
)
Expand Down Expand Up @@ -239,7 +242,7 @@ def run(place):

# test condition out of bounds
class TestCholeskySolveOpError(unittest.TestCase):
def test_errors(self):
def test_errors_1(self):
paddle.enable_static()
with program_guard(Program(), Program()):
# The input type of solve_op must be Variable.
Expand All @@ -251,6 +254,10 @@ def test_errors(self):
)
self.assertRaises(TypeError, paddle.linalg.cholesky_solve, x1, y1)

@test_with_pir_api
def test_errors_2(self):
paddle.enable_static()
with program_guard(Program(), Program()):
# The data type of input must be float32 or float64.
x2 = paddle.static.data(name="x2", shape=[30, 30], dtype="bool")
y2 = paddle.static.data(name="y2", shape=[30, 10], dtype="bool")
Expand Down

0 comments on commit aef9d6d

Please sign in to comment.