From daab9c3656fbbbc861b85b3549b5cf73954644a3 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Tue, 26 Mar 2024 17:15:25 +0000 Subject: [PATCH] add scipy converstion --- lenskit/util/csmatrix.pyi | 3 +++ lenskit/util/csmatrix.pyx | 6 ++++++ tests/test_csmatrix.py | 15 ++++++++++++--- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/lenskit/util/csmatrix.pyi b/lenskit/util/csmatrix.pyi index 7fbabcc80..33c4b702e 100644 --- a/lenskit/util/csmatrix.pyi +++ b/lenskit/util/csmatrix.pyi @@ -1,5 +1,6 @@ import numpy as np import numpy.typing as npt +from scipy.sparse import csr_array class CSMatrix: nrows: int @@ -19,3 +20,5 @@ class CSMatrix: vs: npt.NDArray[np.float64], ): ... def row_ep(self, row: int) -> tuple[int, int]: ... + @staticmethod + def from_scipy(matrix: csr_array) -> CSMatrix: ... diff --git a/lenskit/util/csmatrix.pyx b/lenskit/util/csmatrix.pyx index 3e236cdf4..a0ac84ad4 100644 --- a/lenskit/util/csmatrix.pyx +++ b/lenskit/util/csmatrix.pyx @@ -14,6 +14,12 @@ cdef class CSMatrix: self.values = vs self.nnz = self.rowptr[nr] + @staticmethod + def from_scipy(m): + nr, nc = m.shape + + return CSMatrix(nr, nc, m.indptr, m.indices, m.data) + cpdef (int,int) row_ep(self, row): if row < 0 or row >= self.nrows: raise IndexError(f"invalid row {row} for {self.nrows}x{self.ncols} matrix") diff --git a/tests/test_csmatrix.py b/tests/test_csmatrix.py index 1a25a124e..3de548382 100644 --- a/tests/test_csmatrix.py +++ b/tests/test_csmatrix.py @@ -42,12 +42,21 @@ def test_init_matrix(m: sps.csr_array): assert m2.nnz == m.nnz +@given(sparse_matrices()) +def test_from_scipy(m: sps.csr_array): + print(m.shape, m.nnz, m.indptr.dtype, m.indices.dtype) + m2 = CSMatrix.from_scipy(m) + + assert m2.nrows == m.shape[0] + assert m2.ncols == m.shape[1] + assert m2.nnz == m.nnz + + @given(sparse_matrices()) def test_csm_row_ep(m: sps.csr_array): - nr, nc = m.shape - m2 = CSMatrix(nr, nc, m.indptr, m.indices, m.data) + m2 = CSMatrix.from_scipy(m) - for i in range(nr): + for i in range(m2.nrows): sp, ep = m2.row_ep(i) assert sp == m2.rowptr[i] assert ep == m2.rowptr[i + 1]