From 9ffb94d2ca9149bbd69b8919d14169a942b4702f Mon Sep 17 00:00:00 2001 From: houkensjtu Date: Thu, 4 May 2023 19:42:43 +0800 Subject: [PATCH] Add docstring to SparseCG and MatrixFreeCG. --- python/taichi/linalg/matrixfree_cg.py | 14 ++++++++++++++ python/taichi/linalg/sparse_cg.py | 12 ++++++++++++ 2 files changed, 26 insertions(+) diff --git a/python/taichi/linalg/matrixfree_cg.py b/python/taichi/linalg/matrixfree_cg.py index d997dc2e61bd3..6af5ef177c387 100644 --- a/python/taichi/linalg/matrixfree_cg.py +++ b/python/taichi/linalg/matrixfree_cg.py @@ -17,6 +17,20 @@ def matvec(self, x, Ax): def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True): + """Matrix-free conjugate-gradient solver. + + Use conjugate-gradient method to solve the linear system Ax = b, where A is implicitly + represented as a LinearOperator. + + Args: + A (LinearOperator): The coefficient matrix A of the linear system. + b (Field): The right-hand side of the linear system. + x (Field): The initial guess for the solution. + maxiter (int): Maximum number of iterations. + atol: Tolerance(absolute) for convergence. + quiet (bool): Switch to turn on/off iteration log. + """ + if b.dtype != x.dtype: raise TaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).") if str(b.dtype) == "f32": diff --git a/python/taichi/linalg/sparse_cg.py b/python/taichi/linalg/sparse_cg.py index fc0611fbff660..5ebdf663f5ca9 100644 --- a/python/taichi/linalg/sparse_cg.py +++ b/python/taichi/linalg/sparse_cg.py @@ -7,6 +7,18 @@ class SparseCG: + """Conjugate-gradient solver built for SparseMatrix. + + Use conjugate-gradient method to solve the linear system Ax = b, where A is SparseMatrix. + + Args: + A (SparseMatrix): The coefficient matrix A of the linear system. + b (numpy ndarray, taichi Ndarray): The right-hand side of the linear system. + x0 (numpy ndarray, taichi Ndarray): The initial guess for the solution. + max_iter (int): Maximum number of iterations. + atol: Tolerance(absolute) for convergence. + """ + def __init__(self, A, b, x0=None, max_iter=50, atol=1e-6): self.dtype = A.dtype self.ti_arch = get_runtime().prog.config().arch