Skip to content

Commit

Permalink
unify spmv api
Browse files Browse the repository at this point in the history
  • Loading branch information
FantasyVR committed Nov 10, 2022
1 parent e371291 commit 8bdb24b
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 64 deletions.
2 changes: 1 addition & 1 deletion misc/test_build_cusm_from_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
A = ti.linalg.SparseMatrix(n=4, m=4, dtype=ti.float32)
A.build_coo(d_coo_row, d_coo_col, d_coo_val)

A.spmv(x, y)
y = A @ x

# Check if the results are correct
equal = True
Expand Down
36 changes: 0 additions & 36 deletions misc/test_sm.py

This file was deleted.

34 changes: 8 additions & 26 deletions python/taichi/linalg/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ def __matmul__(self, other):
assert self.m == other.shape[
0], f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
return self.matrix.mat_vec_mul(other)
if isinstance(other, Ndarray):
if self.m != other.shape[0]:
raise TaichiRuntimeError(
f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
)
res = ScalarNdarray(dtype=other.dtype, arr_shape=(self.n, ))
self.matrix.spmv(get_runtime().prog, other.arr, res.arr)
return res
raise TaichiRuntimeError(
f"Sparse matrix-matrix/vector multiplication does not support {type(other)} for now. Supported types are SparseMatrix, ti.field, and numpy ndarray."
)
Expand Down Expand Up @@ -222,32 +230,6 @@ def build_coo(self, row_coo, col_coo, value_coo):
get_runtime().prog.make_sparse_matrix_from_ndarray_cusparse(
self.matrix, row_coo.arr, col_coo.arr, value_coo.arr)

def spmv(self, x):
"""Sparse matrix-vector multiplication using cuSparse.
Args:
x (ti.ndarray): the vector to be multiplied.
y (ti.ndarray): the result of matrix-vector multiplication.
Example::
>>> x = ti.ndarray(shape=4, dtype=val_dt)
>>> y = ti.ndarray(shape=4, dtype=val_dt)
>>> A = ti.linalg.SparseMatrix(n=4, m=4, dtype=ti.f32)
>>> A.build_from_ndarray_cusparse(row_csr, col_csr, value_csr)
>>> A.spmv(x, y)
"""
if not isinstance(x, Ndarray):
raise TaichiRuntimeError(
'Sparse matrix only supports building from [ti.ndarray, ti.Vector.ndarray, ti.Matrix.ndarray]'
)
if self.m != x.shape[0]:
raise TaichiRuntimeError(
f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({x.shape})"
)
res = ScalarNdarray(dtype=x.dtype, arr_shape=(self.n, ))
self.matrix.spmv(get_runtime().prog, x.arr, res.arr)
return res


class SparseMatrixBuilder:
"""A python wrap around sparse matrix builder.
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def test_gpu_sparse_matrix():
A.build_coo(d_coo_row, d_coo_col, d_coo_val)

# Compute Y = A @ X
A.spmv(X, Y)
Y = A @ X
for i in range(4):
assert Y[i] == h_Y[i]

Expand Down

0 comments on commit 8bdb24b

Please sign in to comment.