diff --git a/python/taichi/linalg/sparse_matrix.py b/python/taichi/linalg/sparse_matrix.py index 443afd933aae3..691f76b1cb912 100644 --- a/python/taichi/linalg/sparse_matrix.py +++ b/python/taichi/linalg/sparse_matrix.py @@ -284,12 +284,12 @@ def build(self, dtype=f32, _format="CSR"): taichi_arch = get_runtime().prog.config().arch if taichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]: sm = self.ptr.build() - return SparseMatrix(sm=sm, dtype=dtype) + return SparseMatrix(sm=sm, dtype=self.dtype) if taichi_arch == _ti_core.Arch.cuda: - if dtype != f32: + if self.dtype != f32: raise TaichiRuntimeError("CUDA sparse matrix only supports f32.") sm = self.ptr.build_cuda() - return SparseMatrix(sm=sm, dtype=dtype) + return SparseMatrix(sm=sm, dtype=self.dtype) raise TaichiRuntimeError("Sparse matrix only supports CPU and CUDA backends.") diff --git a/python/taichi/linalg/sparse_solver.py b/python/taichi/linalg/sparse_solver.py index 05775c0dda496..3850cbee227f4 100644 --- a/python/taichi/linalg/sparse_solver.py +++ b/python/taichi/linalg/sparse_solver.py @@ -21,6 +21,7 @@ class SparseSolver: def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): self.matrix = None + self.dtype = dtype solver_type_list = ["LLT", "LDLT", "LU"] solver_ordering = ["AMD", "COLAMD"] if solver_type in solver_type_list and ordering in solver_ordering: @@ -70,6 +71,10 @@ def analyze_pattern(self, sparse_matrix): """ if isinstance(sparse_matrix, SparseMatrix): self.matrix = sparse_matrix + if self.matrix.dtype != self.dtype: + raise TaichiRuntimeError( + f"The SparseSolver's dtype {self.dtype} is not consistent with the SparseMatrix's dtype {self.matrix.dtype}." + ) self.solver.analyze_pattern(sparse_matrix.matrix) else: self._type_assert(sparse_matrix)