From 3332eee22581535ad12cd96ade628658e02c6f96 Mon Sep 17 00:00:00 2001 From: Qian Bao Date: Wed, 31 May 2023 16:22:06 +0800 Subject: [PATCH] [bug] Fix MatrixFreeCG so it can handle multiple input sizes. (#8070) ### Brief Summary Previously, `MatrixFreeCG` inappropriately assume that the intermediate vectors used in CG solver are all two-dimensional `ti.field`. This is incorrect because the dimension of `Ap` `p` and `r` should be consistent with input `b` and `x`. ### Walkthrough Instead of simply put `vector_fields_builder.dense(ti.ij, size).place(p, r, Ap)`, now it's based on the length of `size`. ```python def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True): ... vector_fields_builder = ti.FieldsBuilder() p = ti.field(dtype=solver_dtype) r = ti.field(dtype=solver_dtype) Ap = ti.field(dtype=solver_dtype) if len(size) == 1: # Determine the `axes` argument based on the length of `size` axes = ti.i elif len(size) == 2: axes = ti.ij elif len(size) == 3: axes = ti.ijk else: raise TaichiRuntimeError(f"MatrixFreeCG currently cannot support {len(size)}-D inputs.") vector_fields_builder.dense(axes, size).place(p, r, Ap) ``` --- python/taichi/linalg/matrixfree_cg.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/taichi/linalg/matrixfree_cg.py b/python/taichi/linalg/matrixfree_cg.py index 0d333f4bbcea0..22bb51a4fd70c 100644 --- a/python/taichi/linalg/matrixfree_cg.py +++ b/python/taichi/linalg/matrixfree_cg.py @@ -47,7 +47,15 @@ def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True): p = ti.field(dtype=solver_dtype) r = ti.field(dtype=solver_dtype) Ap = ti.field(dtype=solver_dtype) - vector_fields_builder.dense(ti.ij, size).place(p, r, Ap) + if len(size) == 1: + axes = ti.i + elif len(size) == 2: + axes = ti.ij + elif len(size) == 3: + axes = ti.ijk + else: + raise TaichiRuntimeError(f"MatrixFreeCG only support 1D, 2D, 3D inputs; your inputs is {len(size)}-D.") + vector_fields_builder.dense(axes, size).place(p, r, Ap) vector_fields_snode_tree = vector_fields_builder.finalize() scalar_builder = ti.FieldsBuilder()