diff --git a/python/taichi/linalg/__init__.py b/python/taichi/linalg/__init__.py index 2bb594ffe273f..c4e28878b1b59 100644 --- a/python/taichi/linalg/__init__.py +++ b/python/taichi/linalg/__init__.py @@ -1,6 +1,6 @@ """Taichi support module for sparse matrix operations. """ -from taichi.linalg.cg import CG +from taichi.linalg.sparse_cg import SparseCG from taichi.linalg.sparse_matrix import * from taichi.linalg.sparse_solver import SparseSolver -from taichi.linalg.taichi_cg import * +from taichi.linalg.matrixfree_cg import * diff --git a/python/taichi/linalg/taichi_cg.py b/python/taichi/linalg/matrixfree_cg.py similarity index 82% rename from python/taichi/linalg/taichi_cg.py rename to python/taichi/linalg/matrixfree_cg.py index bb39c0af734f0..0d333f4bbcea0 100644 --- a/python/taichi/linalg/taichi_cg.py +++ b/python/taichi/linalg/matrixfree_cg.py @@ -16,7 +16,21 @@ def matvec(self, x, Ax): self._matvec(x, Ax) -def taichi_cg_solver(A, b, x, tol=1e-6, maxiter=5000, quiet=True): +def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True): + """Matrix-free conjugate-gradient solver. + + Use conjugate-gradient method to solve the linear system Ax = b, where A is implicitly + represented as a LinearOperator. + + Args: + A (LinearOperator): The coefficient matrix A of the linear system. + b (Field): The right-hand side of the linear system. + x (Field): The initial guess for the solution. + maxiter (int): Maximum number of iterations. + atol: Tolerance(absolute) for convergence. + quiet (bool): Switch to turn on/off iteration log. + """ + if b.dtype != x.dtype: raise TaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).") if str(b.dtype) == "f32": diff --git a/python/taichi/linalg/cg.py b/python/taichi/linalg/sparse_cg.py similarity index 78% rename from python/taichi/linalg/cg.py rename to python/taichi/linalg/sparse_cg.py index 21973641e4db3..e0aba8112c690 100644 --- a/python/taichi/linalg/cg.py +++ b/python/taichi/linalg/sparse_cg.py @@ -6,7 +6,19 @@ from taichi.types import f32, f64 -class CG: +class SparseCG: + """Conjugate-gradient solver built for SparseMatrix. + + Use conjugate-gradient method to solve the linear system Ax = b, where A is SparseMatrix. + + Args: + A (SparseMatrix): The coefficient matrix A of the linear system. + b (numpy ndarray, taichi Ndarray): The right-hand side of the linear system. + x0 (numpy ndarray, taichi Ndarray): The initial guess for the solution. + max_iter (int): Maximum number of iterations. + atol: Tolerance(absolute) for convergence. + """ + def __init__(self, A, b, x0=None, max_iter=50, atol=1e-6): self.dtype = A.dtype self.ti_arch = get_runtime().prog.config().arch diff --git a/tests/python/test_taichi_cg.py b/tests/python/test_matrixfree_cg.py similarity index 91% rename from tests/python/test_taichi_cg.py rename to tests/python/test_matrixfree_cg.py index 122669931aa62..733a0ea24c0d2 100644 --- a/tests/python/test_taichi_cg.py +++ b/tests/python/test_matrixfree_cg.py @@ -1,7 +1,7 @@ import math import pytest -from taichi.linalg import LinearOperator, taichi_cg_solver +from taichi.linalg import LinearOperator, MatrixFreeCG import taichi as ti from tests import test_utils @@ -11,7 +11,7 @@ @pytest.mark.parametrize("ti_dtype", [ti.f32, ti.f64]) @test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac]) -def test_taichi_cg(ti_dtype): +def test_matrixfree_cg(ti_dtype): GRID = 32 Ax = ti.field(dtype=ti_dtype, shape=(GRID, GRID)) x = ti.field(dtype=ti_dtype, shape=(GRID, GRID)) @@ -47,7 +47,7 @@ def check_solution(sol: ti.template(), ans: ti.template(), tol: ti_dtype) -> boo A = LinearOperator(compute_Ax) init() - taichi_cg_solver(A, b, x, maxiter=10 * GRID * GRID, tol=1e-18, quiet=True) + MatrixFreeCG(A, b, x, maxiter=10 * GRID * GRID, tol=1e-18, quiet=True) compute_Ax(x, Ax) # `tol` can't be < 1e-6 for ti.f32 because of accumulating round-off error; # see https://en.wikipedia.org/wiki/Conjugate_gradient_method#cite_note-6 diff --git a/tests/python/test_cg.py b/tests/python/test_sparse_cg.py similarity index 93% rename from tests/python/test_cg.py rename to tests/python/test_sparse_cg.py index bd0eb38445f20..ad6afa4c071f4 100644 --- a/tests/python/test_cg.py +++ b/tests/python/test_sparse_cg.py @@ -28,7 +28,7 @@ def fill( fill(Abuilder, A_psd, b) A = Abuilder.build(dtype=ti_dtype) - cg = ti.linalg.CG(A, b, x0, max_iter=50, atol=1e-6) + cg = ti.linalg.SparseCG(A, b, x0, max_iter=50, atol=1e-6) x, exit_code = cg.solve() res = np.linalg.solve(A_psd, b.to_numpy()) assert exit_code == True @@ -59,7 +59,7 @@ def fill( fill(Abuilder, A_psd, b) A = Abuilder.build(dtype=ti_dtype) - cg = ti.linalg.CG(A, b, x0, max_iter=50, atol=1e-6) + cg = ti.linalg.SparseCG(A, b, x0, max_iter=50, atol=1e-6) x, exit_code = cg.solve() res = np.linalg.solve(A_psd, b.to_numpy()) assert exit_code == True