Skip to content

Commit

Permalink
add scipy converstion
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Mar 26, 2024
1 parent e51e26e commit daab9c3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
3 changes: 3 additions & 0 deletions lenskit/util/csmatrix.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_array

class CSMatrix:
nrows: int
Expand All @@ -19,3 +20,5 @@ class CSMatrix:
vs: npt.NDArray[np.float64],
): ...
def row_ep(self, row: int) -> tuple[int, int]: ...
@staticmethod
def from_scipy(matrix: csr_array) -> CSMatrix: ...
6 changes: 6 additions & 0 deletions lenskit/util/csmatrix.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ cdef class CSMatrix:
self.values = vs
self.nnz = self.rowptr[nr]

@staticmethod
def from_scipy(m):
nr, nc = m.shape

return CSMatrix(nr, nc, m.indptr, m.indices, m.data)

cpdef (int,int) row_ep(self, row):
if row < 0 or row >= self.nrows:
raise IndexError(f"invalid row {row} for {self.nrows}x{self.ncols} matrix")

Check warning on line 25 in lenskit/util/csmatrix.pyx

View check run for this annotation

Codecov / codecov/patch

lenskit/util/csmatrix.pyx#L25

Added line #L25 was not covered by tests
Expand Down
15 changes: 12 additions & 3 deletions tests/test_csmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,21 @@ def test_init_matrix(m: sps.csr_array):
assert m2.nnz == m.nnz


@given(sparse_matrices())
def test_from_scipy(m: sps.csr_array):
print(m.shape, m.nnz, m.indptr.dtype, m.indices.dtype)
m2 = CSMatrix.from_scipy(m)

assert m2.nrows == m.shape[0]
assert m2.ncols == m.shape[1]
assert m2.nnz == m.nnz


@given(sparse_matrices())
def test_csm_row_ep(m: sps.csr_array):
nr, nc = m.shape
m2 = CSMatrix(nr, nc, m.indptr, m.indices, m.data)
m2 = CSMatrix.from_scipy(m)

for i in range(nr):
for i in range(m2.nrows):
sp, ep = m2.row_ep(i)
assert sp == m2.rowptr[i]
assert ep == m2.rowptr[i + 1]

0 comments on commit daab9c3

Please sign in to comment.