diff --git a/python/taichi/linalg/matrixfree_cg.py b/python/taichi/linalg/matrixfree_cg.py index 872ff65139044..3daabf3e0ae70 100644 --- a/python/taichi/linalg/matrixfree_cg.py +++ b/python/taichi/linalg/matrixfree_cg.py @@ -64,7 +64,6 @@ def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True): beta = ti.field(dtype=solver_dtype) scalar_builder.place(alpha, beta) scalar_snode_tree = scalar_builder.finalize() - succeeded = True @ti.kernel def init(): @@ -96,6 +95,7 @@ def update_p(): p[I] = r[I] + beta[None] * p[I] def solve(): + succeeded = True A._matvec(x, Ax) init() initial_rTr = reduce(r, r) @@ -129,8 +129,9 @@ def solve(): f">>> Conjugate Gradient method failed to converge in {maxiter} iterations: Residual = {sqrt(new_rTr):e}" ) succeeded = False + return succeeded - solve() + succeeded = solve() vector_fields_snode_tree.destroy() scalar_snode_tree.destroy() return succeeded @@ -252,6 +253,7 @@ def update_r(): r[I] = s[I] - omega[None] * t[I] def solve(): + succeeded = True A._matvec(x, Ax) init() initial_rTr = reduce(r, r) @@ -296,8 +298,9 @@ def solve(): if not quiet: print(f">>> BICGSTAB failed to converge in {maxiter} iterations: Residual = {sqrt(rTr):e}") succeeded = False + return succeeded - solve() + succeeded = solve() vector_fields_snode_tree.destroy() scalar_snode_tree.destroy() return succeeded