Skip to content

Commit

Permalink
Add support for SciPy CSC and CSR sparse types to Numba
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Dec 7, 2022
1 parent 0d28d0b commit 75f09bf
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 0 deletions.
1 change: 1 addition & 0 deletions aesara/link/numba/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
import aesara.link.numba.dispatch.random
import aesara.link.numba.dispatch.elemwise
import aesara.link.numba.dispatch.scan
import aesara.link.numba.dispatch.sparse

# isort: on
142 changes: 142 additions & 0 deletions aesara/link/numba/dispatch/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import scipy as sp
import scipy.sparse
from numba.core import cgutils, types
from numba.extending import (
NativeValue,
box,
make_attribute_wrapper,
models,
register_model,
typeof_impl,
unbox,
)


class CSMatrixType(types.Type):
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""

name: str
instance_class: type

def __init__(self, dtype):
self.dtype = dtype
self.data = types.Array(dtype, 1, "A")
self.indices = types.Array(types.int32, 1, "A")
self.indptr = types.Array(types.int32, 1, "A")
self.shape = types.UniTuple(types.int64, 2)
super().__init__(self.name)


make_attribute_wrapper(CSMatrixType, "data", "data")
make_attribute_wrapper(CSMatrixType, "indices", "indices")
make_attribute_wrapper(CSMatrixType, "indptr", "indptr")
make_attribute_wrapper(CSMatrixType, "shape", "shape")


class CSRMatrixType(CSMatrixType):
name = "csr_matrix"

@staticmethod
def instance_class(data, indices, indptr, shape):
return sp.sparse.csr_matrix((data, indices, indptr), shape, copy=False)


class CSCMatrixType(CSMatrixType):
name = "csc_matrix"

@staticmethod
def instance_class(data, indices, indptr, shape):
return sp.sparse.csc_matrix((data, indices, indptr), shape, copy=False)


@typeof_impl.register(sp.sparse.csc_matrix)
def typeof_csc_matrix(val, c):
data = typeof_impl(val.data, c)
return CSCMatrixType(data.dtype)


@typeof_impl.register(sp.sparse.csr_matrix)
def typeof_csr_matrix(val, c):
data = typeof_impl(val.data, c)
return CSRMatrixType(data.dtype)


@register_model(CSRMatrixType)
class CSRMatrixModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
("data", fe_type.data),
("indices", fe_type.indices),
("indptr", fe_type.indptr),
("shape", fe_type.shape),
]
super().__init__(dmm, fe_type, members)


@register_model(CSCMatrixType)
class CSCMatrixModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
("data", fe_type.data),
("indices", fe_type.indices),
("indptr", fe_type.indptr),
("shape", fe_type.shape),
]
super().__init__(dmm, fe_type, members)


@unbox(CSCMatrixType)
@unbox(CSRMatrixType)
def unbox_matrix(typ, obj, c):

struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder)

data = c.pyapi.object_getattr_string(obj, "data")
indices = c.pyapi.object_getattr_string(obj, "indices")
indptr = c.pyapi.object_getattr_string(obj, "indptr")
shape = c.pyapi.object_getattr_string(obj, "shape")

struct_ptr.data = c.unbox(typ.data, data).value
struct_ptr.indices = c.unbox(typ.indices, indices).value
struct_ptr.indptr = c.unbox(typ.indptr, indptr).value
struct_ptr.shape = c.unbox(typ.shape, shape).value

c.pyapi.decref(data)
c.pyapi.decref(indices)
c.pyapi.decref(indptr)
c.pyapi.decref(shape)

is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
is_error = c.builder.load(is_error_ptr)

res = NativeValue(struct_ptr._getvalue(), is_error=is_error)

return res


@box(CSCMatrixType)
@box(CSRMatrixType)
def box_matrix(typ, val, c):
struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)

data_obj = c.box(typ.data, struct_ptr.data)
indices_obj = c.box(typ.indices, struct_ptr.indices)
indptr_obj = c.box(typ.indptr, struct_ptr.indptr)
shape_obj = c.box(typ.shape, struct_ptr.shape)

c.pyapi.incref(data_obj)
c.pyapi.incref(indices_obj)
c.pyapi.incref(indptr_obj)
c.pyapi.incref(shape_obj)

cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class))
obj = c.pyapi.call_function_objargs(
cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj)
)

c.pyapi.decref(data_obj)
c.pyapi.decref(indices_obj)
c.pyapi.decref(indptr_obj)
c.pyapi.decref(shape_obj)

return obj
40 changes: 40 additions & 0 deletions tests/link/numba/test_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import numba
import numpy as np
import scipy as sp

# Load Numba customizations
import aesara.link.numba.dispatch.sparse # noqa: F401


def test_sparse_unboxing():
@numba.njit
def test_unboxing(x, y):
return x.shape, y.shape

x_val = sp.sparse.csr_matrix(np.eye(100))
y_val = sp.sparse.csc_matrix(np.eye(101))

res = test_unboxing(x_val, y_val)

assert res == (x_val.shape, y_val.shape)


def test_sparse_boxing():
@numba.njit
def test_boxing(x, y):
return x, y

x_val = sp.sparse.csr_matrix(np.eye(100))
y_val = sp.sparse.csc_matrix(np.eye(101))

res_x_val, res_y_val = test_boxing(x_val, y_val)

assert np.array_equal(res_x_val.data, x_val.data)
assert np.array_equal(res_x_val.indices, x_val.indices)
assert np.array_equal(res_x_val.indptr, x_val.indptr)
assert res_x_val.shape == x_val.shape

assert np.array_equal(res_y_val.data, y_val.data)
assert np.array_equal(res_y_val.indices, y_val.indices)
assert np.array_equal(res_y_val.indptr, y_val.indptr)
assert res_y_val.shape == y_val.shape

0 comments on commit 75f09bf

Please sign in to comment.