Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Refactor cg solvers #7911

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/taichi/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Taichi support module for sparse matrix operations.
"""
from taichi.linalg.cg import CG
from taichi.linalg.sparse_cg import SparseCG
from taichi.linalg.sparse_matrix import *
from taichi.linalg.sparse_solver import SparseSolver
from taichi.linalg.taichi_cg import *
from taichi.linalg.matrixfree_cg import *
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,21 @@ def matvec(self, x, Ax):
self._matvec(x, Ax)


def taichi_cg_solver(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
"""Matrix-free conjugate-gradient solver.

Use conjugate-gradient method to solve the linear system Ax = b, where A is implicitly
represented as a LinearOperator.

Args:
A (LinearOperator): The coefficient matrix A of the linear system.
b (Field): The right-hand side of the linear system.
x (Field): The initial guess for the solution.
maxiter (int): Maximum number of iterations.
atol: Tolerance(absolute) for convergence.
quiet (bool): Switch to turn on/off iteration log.
"""

if b.dtype != x.dtype:
raise TaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).")
if str(b.dtype) == "f32":
Expand Down
14 changes: 13 additions & 1 deletion python/taichi/linalg/cg.py → python/taichi/linalg/sparse_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,19 @@
from taichi.types import f32, f64


class CG:
class SparseCG:
houkensjtu marked this conversation as resolved.
Show resolved Hide resolved
"""Conjugate-gradient solver built for SparseMatrix.

Use conjugate-gradient method to solve the linear system Ax = b, where A is SparseMatrix.

Args:
A (SparseMatrix): The coefficient matrix A of the linear system.
b (numpy ndarray, taichi Ndarray): The right-hand side of the linear system.
x0 (numpy ndarray, taichi Ndarray): The initial guess for the solution.
max_iter (int): Maximum number of iterations.
atol: Tolerance(absolute) for convergence.
"""

def __init__(self, A, b, x0=None, max_iter=50, atol=1e-6):
self.dtype = A.dtype
self.ti_arch = get_runtime().prog.config().arch
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math

import pytest
from taichi.linalg import LinearOperator, taichi_cg_solver
from taichi.linalg import LinearOperator, MatrixFreeCG

import taichi as ti
from tests import test_utils
Expand All @@ -11,7 +11,7 @@

@pytest.mark.parametrize("ti_dtype", [ti.f32, ti.f64])
@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac])
def test_taichi_cg(ti_dtype):
def test_matrixfree_cg(ti_dtype):
GRID = 32
Ax = ti.field(dtype=ti_dtype, shape=(GRID, GRID))
x = ti.field(dtype=ti_dtype, shape=(GRID, GRID))
Expand Down Expand Up @@ -47,7 +47,7 @@ def check_solution(sol: ti.template(), ans: ti.template(), tol: ti_dtype) -> boo

A = LinearOperator(compute_Ax)
init()
taichi_cg_solver(A, b, x, maxiter=10 * GRID * GRID, tol=1e-18, quiet=True)
MatrixFreeCG(A, b, x, maxiter=10 * GRID * GRID, tol=1e-18, quiet=True)
compute_Ax(x, Ax)
# `tol` can't be < 1e-6 for ti.f32 because of accumulating round-off error;
# see https://en.wikipedia.org/wiki/Conjugate_gradient_method#cite_note-6
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_cg.py → tests/python/test_sparse_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def fill(

fill(Abuilder, A_psd, b)
A = Abuilder.build(dtype=ti_dtype)
cg = ti.linalg.CG(A, b, x0, max_iter=50, atol=1e-6)
cg = ti.linalg.SparseCG(A, b, x0, max_iter=50, atol=1e-6)
x, exit_code = cg.solve()
res = np.linalg.solve(A_psd, b.to_numpy())
assert exit_code == True
Expand Down Expand Up @@ -59,7 +59,7 @@ def fill(

fill(Abuilder, A_psd, b)
A = Abuilder.build(dtype=ti_dtype)
cg = ti.linalg.CG(A, b, x0, max_iter=50, atol=1e-6)
cg = ti.linalg.SparseCG(A, b, x0, max_iter=50, atol=1e-6)
x, exit_code = cg.solve()
res = np.linalg.solve(A_psd, b.to_numpy())
assert exit_code == True
Expand Down