Skip to content

Commit

Permalink
[lang] Implement experimental CG(Conjugate Gradient) solver in Taichi…
Browse files Browse the repository at this point in the history
…-lang (#7690)

Issue: #7634 

### Brief Summary
This PR implements a matrix-free CG (Conjugate-Gradient) solver in
Taichi. The solver targets to solve the linear equation system:

$$ Ax = b$$

where $A$ is implicitly represented as a `LinearOperator` instead of a
explicitly stored matrix, hence the name "matrix-free".

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
houkensjtu and pre-commit-ci[bot] authored Apr 7, 2023
1 parent 564a880 commit a7f7051
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/taichi/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from taichi.linalg.cg import CG
from taichi.linalg.sparse_matrix import *
from taichi.linalg.sparse_solver import SparseSolver
from taichi.linalg.taichi_cg import *
105 changes: 105 additions & 0 deletions python/taichi/linalg/taichi_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from math import sqrt

from taichi.lang.exception import TaichiRuntimeError, TaichiTypeError

import taichi as ti


@ti.data_oriented
class LinearOperator:
def __init__(self, matvec_kernel):
self._matvec = matvec_kernel

def matvec(self, x, Ax):
if x.shape != Ax.shape:
raise TaichiRuntimeError(
f"Dimension mismatch x.shape{x.shape} != Ax.shape{Ax.shape}.")
self._matvec(x, Ax)


def taichi_cg_solver(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
if b.dtype != x.dtype:
raise TaichiTypeError(
f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).")
if str(b.dtype) == 'f32':
solver_dtype = ti.f32
elif str(b.dtype) == 'f64':
solver_dtype = ti.f64
else:
raise TaichiTypeError(f"Not supported dtype: {b.dtype}")
if b.shape != x.shape:
raise TaichiRuntimeError(
f"Dimension mismatch b.shape{b.shape} != x.shape{x.shape}.")

size = b.shape
vector_fields_builder = ti.FieldsBuilder()
p = ti.field(dtype=solver_dtype)
r = ti.field(dtype=solver_dtype)
Ap = ti.field(dtype=solver_dtype)
vector_fields_builder.dense(ti.ij, size).place(p, r, Ap)
vector_fields_snode_tree = vector_fields_builder.finalize()

scalar_builder = ti.FieldsBuilder()
alpha = ti.field(dtype=solver_dtype)
beta = ti.field(dtype=solver_dtype)
scalar_builder.place(alpha, beta)
scalar_snode_tree = scalar_builder.finalize()

@ti.kernel
def init():
for I in ti.grouped(x):
r[I] = b[I]
p[I] = 0.0
Ap[I] = 0.0

@ti.kernel
def reduce(p: ti.template(), q: ti.template()) -> solver_dtype:
result = 0.0
for I in ti.grouped(p):
result += p[I] * q[I]
return result

@ti.kernel
def update_x():
for I in ti.grouped(x):
x[I] += alpha[None] * p[I]

@ti.kernel
def update_r():
for I in ti.grouped(r):
r[I] -= alpha[None] * Ap[I]

@ti.kernel
def update_p():
for I in ti.grouped(p):
p[I] = r[I] + beta[None] * p[I]

def solve():
init()
initial_rTr = reduce(r, r)
if not quiet:
print(f'>>> Initial residual = {initial_rTr:e}')
old_rTr = initial_rTr
update_p()
# -- Main loop --
for i in range(maxiter):
A._matvec(p, Ap) # compute Ap = A x p
pAp = reduce(p, Ap)
alpha[None] = old_rTr / pAp
update_x()
update_r()
new_rTr = reduce(r, r)
if sqrt(new_rTr) < tol:
if not quiet:
print('>>> Conjugate Gradient method converged.')
print(f'>>> #iterations {i}')
break
beta[None] = new_rTr / old_rTr
update_p()
old_rTr = new_rTr
if not quiet:
print(f'>>> Iter = {i+1:4}, Residual = {sqrt(new_rTr):e}')

solve()
vector_fields_snode_tree.destroy()
scalar_snode_tree.destroy()
57 changes: 57 additions & 0 deletions tests/python/test_taichi_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import math

import pytest
from taichi.linalg import LinearOperator, taichi_cg_solver

import taichi as ti
from tests import test_utils

vk_on_mac = (ti.vulkan, 'Darwin')


@pytest.mark.parametrize("ti_dtype", [ti.f32, ti.f64])
@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac])
def test_taichi_cg(ti_dtype):
GRID = 32
Ax = ti.field(dtype=ti_dtype, shape=(GRID, GRID))
x = ti.field(dtype=ti_dtype, shape=(GRID, GRID))
b = ti.field(dtype=ti_dtype, shape=(GRID, GRID))

@ti.kernel
def init():
for i, j in ti.ndrange(GRID, GRID):
xl = i / (GRID - 1)
yl = j / (GRID - 1)
b[i, j] = ti.sin(2 * math.pi * xl) * ti.sin(2 * math.pi * yl)
x[i, j] = 0.0

@ti.kernel
def compute_Ax(v: ti.template(), mv: ti.template()):
for i, j in v:
l = v[i - 1, j] if i - 1 >= 0 else 0.0
r = v[i + 1, j] if i + 1 <= GRID - 1 else 0.0
t = v[i, j + 1] if j + 1 <= GRID - 1 else 0.0
b = v[i, j - 1] if j - 1 >= 0 else 0.0
# Avoid ill-conditioned matrix A
mv[i, j] = 20 * v[i, j] - l - r - t - b

@ti.kernel
def check_solution(sol: ti.template(), ans: ti.template(),
tol: ti_dtype) -> bool:
exit_code = True
for i, j in ti.ndrange(GRID, GRID):
if ti.abs(ans[i, j] - sol[i, j]) < tol:
pass
else:
exit_code = False
return exit_code

A = LinearOperator(compute_Ax)
init()
taichi_cg_solver(A, b, x, maxiter=10 * GRID * GRID, tol=1e-18, quiet=True)
compute_Ax(x, Ax)
# `tol` can't be < 1e-6 for ti.f32 because of accumulating round-off error;
# see https://en.wikipedia.org/wiki/Conjugate_gradient_method#cite_note-6
# for more details.
result = check_solution(Ax, b, tol=1e-6)
assert result

0 comments on commit a7f7051

Please sign in to comment.