Skip to content

Commit

Permalink
fix naming convention
Browse files Browse the repository at this point in the history
  • Loading branch information
FantasyVR committed Nov 14, 2022
1 parent 97ade0b commit a034a5b
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions python/taichi/linalg/sparse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SparseSolver:
ordering (str): The method for matrices re-ordering.
"""
def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
self.matrix_ = None
self.matrix = None
solver_type_list = ["LLT", "LDLT", "LU"]
solver_ordering = ['AMD', 'COLAMD']
if solver_type in solver_type_list and ordering in solver_ordering:
Expand Down Expand Up @@ -49,13 +49,13 @@ def compute(self, sparse_matrix):
sparse_matrix (SparseMatrix): The sparse matrix to be computed.
"""
if isinstance(sparse_matrix, SparseMatrix):
self.matrix_ = sparse_matrix
self.matrix = sparse_matrix
taichi_arch = taichi.lang.impl.get_runtime().prog.config().arch
if taichi_arch == _ti_core.Arch.x64:
self.solver.compute(sparse_matrix.matrix)
elif taichi_arch == _ti_core.Arch.cuda:
self.analyze_pattern(self.matrix_)
self.factorize(self.matrix_)
self.analyze_pattern(self.matrix)
self.factorize(self.matrix)
else:
self._type_assert(sparse_matrix)

Expand All @@ -66,7 +66,7 @@ def analyze_pattern(self, sparse_matrix):
sparse_matrix (SparseMatrix): The sparse matrix to be analyzed.
"""
if isinstance(sparse_matrix, SparseMatrix):
self.matrix_ = sparse_matrix
self.matrix = sparse_matrix
self.solver.analyze_pattern(sparse_matrix.matrix)
else:
self._type_assert(sparse_matrix)
Expand All @@ -78,7 +78,7 @@ def factorize(self, sparse_matrix):
sparse_matrix (SparseMatrix): The sparse matrix to be factorized.
"""
if isinstance(sparse_matrix, SparseMatrix):
self.matrix_ = sparse_matrix
self.matrix = sparse_matrix
self.solver.factorize(sparse_matrix.matrix)
else:
self._type_assert(sparse_matrix)
Expand All @@ -91,16 +91,16 @@ def solve(self, b): # pylint: disable=R1710
Returns:
numpy.array: The solution of linear systems.
"""
if self.matrix_ is None:
if self.matrix is None:
raise TaichiRuntimeError(
"Please call compute() before calling solve().")
if isinstance(b, Field):
return self.solver.solve(b.to_numpy())
if isinstance(b, np.ndarray):
return self.solver.solve(b)
if isinstance(b, Ndarray):
x = ScalarNdarray(b.dtype, [self.matrix_.m])
self.solve_rf(self.matrix_, b, x)
x = ScalarNdarray(b.dtype, [self.matrix.m])
self.solve_rf(self.matrix, b, x)
return x
raise TaichiRuntimeError(
f"The parameter type: {type(b)} is not supported in linear solvers for now."
Expand Down

0 comments on commit a034a5b

Please sign in to comment.