From 02d7ba27f77941ec8b457fc4e81edaefc3444b4a Mon Sep 17 00:00:00 2001 From: liblaf <30631553+liblaf@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:59:22 +0800 Subject: [PATCH] [bug] Ensure `succeeded` variable is properly initialized in matrix-free solvers The `succeeded` variable was not properly initialized in the `MatrixFreeCG` and `MatrixFreeBICGSTAB` functions, leading to potential issues with the convergence check. By initializing the `succeeded` variable at the beginning of the `solve` function, we ensure that the variable is correctly set and returned at the end of the function, improving the reliability of the solvers. --- python/taichi/linalg/matrixfree_cg.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/taichi/linalg/matrixfree_cg.py b/python/taichi/linalg/matrixfree_cg.py index 872ff65139044..3daabf3e0ae70 100644 --- a/python/taichi/linalg/matrixfree_cg.py +++ b/python/taichi/linalg/matrixfree_cg.py @@ -64,7 +64,6 @@ def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True): beta = ti.field(dtype=solver_dtype) scalar_builder.place(alpha, beta) scalar_snode_tree = scalar_builder.finalize() - succeeded = True @ti.kernel def init(): @@ -96,6 +95,7 @@ def update_p(): p[I] = r[I] + beta[None] * p[I] def solve(): + succeeded = True A._matvec(x, Ax) init() initial_rTr = reduce(r, r) @@ -129,8 +129,9 @@ def solve(): f">>> Conjugate Gradient method failed to converge in {maxiter} iterations: Residual = {sqrt(new_rTr):e}" ) succeeded = False + return succeeded - solve() + succeeded = solve() vector_fields_snode_tree.destroy() scalar_snode_tree.destroy() return succeeded @@ -252,6 +253,7 @@ def update_r(): r[I] = s[I] - omega[None] * t[I] def solve(): + succeeded = True A._matvec(x, Ax) init() initial_rTr = reduce(r, r) @@ -296,8 +298,9 @@ def solve(): if not quiet: print(f">>> BICGSTAB failed to converge in {maxiter} iterations: Residual = {sqrt(rTr):e}") succeeded = False + return succeeded - solve() + succeeded = solve() vector_fields_snode_tree.destroy() scalar_snode_tree.destroy() return succeeded