From e3cccb8a098a68701759604d577ea2a896a6b026 Mon Sep 17 00:00:00 2001 From: Qian Bao Date: Wed, 31 May 2023 14:49:25 +0800 Subject: [PATCH] [bug] Fix SparseMatrix's dtype; check for dtype in SparseSolver. (#8071) Issue: #8045 ### Brief Summary `SparseMatrix`'s dtype should be determined by the dtype of the `SparseMatrixBuilder` from which it's built. However, it's fixed to `f32` in the current code base. This PR fix this by passing `self.dtype` instead of `dtype` to `SparseMatrix()` in the builder: ```python def build(self, dtype=f32, _format="CSR"): """Create a sparse matrix using the triplets""" taichi_arch = get_runtime().prog.config().arch if taichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]: sm = self.ptr.build() return SparseMatrix(sm=sm, dtype=self.dtype) # Previously it was dtype, which is f32 in the current context. ``` Also, `SparseSolver` should raise an exception to better notify the user if the `dtype` of the `SparseMatrix` is not consistent with the solver's `dtype`. This is implemented in the `sparse_solver.py`. ### Additional comments - This PR should resolve the user's question posted on the forum: https://forum.taichi-lang.cn/t/topic/4316 - This PR might fix part of the CI test failures mentioned in issue #8045 , but #8045 still needs more thorough investigation even with this PR merged to ensure that there is no future issues. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/linalg/sparse_matrix.py | 6 +++--- python/taichi/linalg/sparse_solver.py | 5 +++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/taichi/linalg/sparse_matrix.py b/python/taichi/linalg/sparse_matrix.py index 443afd933aae3..691f76b1cb912 100644 --- a/python/taichi/linalg/sparse_matrix.py +++ b/python/taichi/linalg/sparse_matrix.py @@ -284,12 +284,12 @@ def build(self, dtype=f32, _format="CSR"): taichi_arch = get_runtime().prog.config().arch if taichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]: sm = self.ptr.build() - return SparseMatrix(sm=sm, dtype=dtype) + return SparseMatrix(sm=sm, dtype=self.dtype) if taichi_arch == _ti_core.Arch.cuda: - if dtype != f32: + if self.dtype != f32: raise TaichiRuntimeError("CUDA sparse matrix only supports f32.") sm = self.ptr.build_cuda() - return SparseMatrix(sm=sm, dtype=dtype) + return SparseMatrix(sm=sm, dtype=self.dtype) raise TaichiRuntimeError("Sparse matrix only supports CPU and CUDA backends.") diff --git a/python/taichi/linalg/sparse_solver.py b/python/taichi/linalg/sparse_solver.py index 05775c0dda496..3850cbee227f4 100644 --- a/python/taichi/linalg/sparse_solver.py +++ b/python/taichi/linalg/sparse_solver.py @@ -21,6 +21,7 @@ class SparseSolver: def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): self.matrix = None + self.dtype = dtype solver_type_list = ["LLT", "LDLT", "LU"] solver_ordering = ["AMD", "COLAMD"] if solver_type in solver_type_list and ordering in solver_ordering: @@ -70,6 +71,10 @@ def analyze_pattern(self, sparse_matrix): """ if isinstance(sparse_matrix, SparseMatrix): self.matrix = sparse_matrix + if self.matrix.dtype != self.dtype: + raise TaichiRuntimeError( + f"The SparseSolver's dtype {self.dtype} is not consistent with the SparseMatrix's dtype {self.matrix.dtype}." + ) self.solver.analyze_pattern(sparse_matrix.matrix) else: self._type_assert(sparse_matrix)