Skip to content

Commit

Permalink
Add __binsparse_descriptor__ and __binsparse_dlpack__.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Sep 2, 2024
1 parent fb0affe commit 5aba8c4
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 0 deletions.
58 changes: 58 additions & 0 deletions sparse/numba_backend/_compressed/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,15 @@ def isinf(self):
def isnan(self):
return self.tocoo().isnan().asformat("gcxs", compressed_axes=self.compressed_axes)

# `GCXS` is a reshaped/transposed `CSR`, but it can't (usually)
# be expressed in the `binsparse` 0.1 language.
# We are missing index maps.
def __binsparse_descriptor__(self) -> dict:
return super().__binsparse_descriptor__()

Check warning on line 851 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L851

Added line #L851 was not covered by tests

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return super().__binsparse_dlpack__()

Check warning on line 854 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L854

Added line #L854 was not covered by tests


class _Compressed2d(GCXS):
class_compressed_axes: tuple[int]
Expand Down Expand Up @@ -883,6 +892,34 @@ def from_numpy(cls, x, fill_value=0, idx_dtype=None):
coo = COO.from_numpy(x, fill_value=fill_value, idx_dtype=idx_dtype)
return cls.from_coo(coo, cls.class_compressed_axes, idx_dtype)

def __binsparse_descriptor__(self) -> dict:
from sparse._version import __version__

Check warning on line 896 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L896

Added line #L896 was not covered by tests

data_dt = str(self.data.dtype)
if np.issubdtype(data_dt, np.complexfloating):
data_dt = f"complex[{self.data.dtype.itemsize // 2}]"
return {

Check warning on line 901 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L898-L901

Added lines #L898 - L901 were not covered by tests
"binsparse": {
"version": "0.1",
"format": self.format.upper(),
"shape": list(self.shape),
"number_of_stored_values": self.nnz,
"data_types": {
"pointers_to_1": str(self.indices.dtype),
"indices_1": str(self.indptr.dtype),
"values": data_dt,
},
},
"original_source": f"`sparse`, version {__version__}",
}

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return {

Check warning on line 917 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L917

Added line #L917 was not covered by tests
"pointers_to_1": self.indices,
"indices_1": self.indptr,
"values": self.data,
}


class CSR(_Compressed2d):
"""
Expand Down Expand Up @@ -915,6 +952,27 @@ def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"
return self
return CSC((self.data, self.indices, self.indptr), self.shape[::-1])

def __binsparse_descriptor__(self) -> dict:
from sparse._version import __version__

Check warning on line 956 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L956

Added line #L956 was not covered by tests

data_dt = str(self.data.dtype)
if np.issubdtype(data_dt, np.complexfloating):
data_dt = f"complex[{self.data.dtype.itemsize // 2}]"
return {

Check warning on line 961 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L958-L961

Added lines #L958 - L961 were not covered by tests
"binsparse": {
"version": "0.1",
"format": "CSR",
"shape": list(self.shape),
"number_of_stored_values": self.nnz,
"data_types": {
"pointers_to_1": str(self.indices.dtype),
"indices_1": str(self.indptr.dtype),
"values": data_dt,
},
},
"original_source": f"`sparse`, version {__version__}",
}


class CSC(_Compressed2d):
"""
Expand Down
38 changes: 38 additions & 0 deletions sparse/numba_backend/_coo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,44 @@ def isnan(self):
prune=True,
)

def __binsparse_descriptor__(self) -> dict:
from sparse._version import __version__

Check warning on line 1541 in sparse/numba_backend/_coo/core.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_coo/core.py#L1541

Added line #L1541 was not covered by tests

data_dt = str(self.data.dtype)
if np.issubdtype(data_dt, np.complexfloating):
data_dt = f"complex[{self.data.dtype.itemsize // 2}]"
return {

Check warning on line 1546 in sparse/numba_backend/_coo/core.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_coo/core.py#L1543-L1546

Added lines #L1543 - L1546 were not covered by tests
"binsparse": {
"version": "0.1",
"format": {
"custom": {
"level": {
"level_desc": "sparse",
"rank": self.ndim,
"level": {
"level_desc": "element",
},
}
}
},
"shape": list(self.shape),
"number_of_stored_values": self.nnz,
"data_types": {
"pointers_to_1": "uint8",
"indices_1": str(self.coords.dtype),
"values": data_dt,
},
},
"original_source": f"`sparse`, version {__version__}",
}

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return {

Check warning on line 1572 in sparse/numba_backend/_coo/core.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_coo/core.py#L1572

Added line #L1572 was not covered by tests
"pointers_to_1": np.array([0, self.nnz], dtype=np.uint8),
"indices_1": self.coords,
"values": self.data,
}


def as_coo(x, shape=None, fill_value=None, idx_dtype=None):
"""
Expand Down
6 changes: 6 additions & 0 deletions sparse/numba_backend/_dok.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,12 @@ def reshape(self, shape, order="C"):

return DOK.from_coo(self.to_coo().reshape(shape))

def __binsparse_descriptor__(self) -> dict:
raise RuntimeError("`DOK` doesn't support the `__binsparse_descriptor__` protocol.")

Check warning on line 552 in sparse/numba_backend/_dok.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_dok.py#L552

Added line #L552 was not covered by tests

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
raise RuntimeError("`DOK` doesn't support the `__binsparse_dlpack__` protocol.")

Check warning on line 555 in sparse/numba_backend/_dok.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_dok.py#L555

Added line #L555 was not covered by tests


def to_slice(k):
"""Convert integer indices to one-element slices for consistency"""
Expand Down
25 changes: 25 additions & 0 deletions sparse/numba_backend/_sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,31 @@ def _str_impl(self, summary):
except (ImportError, ValueError):
return summary

@abstractmethod
def __binsparse_descriptor__(self) -> dict:
"""Return a `dict` equivalent to a parsed JSON [`binsparse` descriptor](https://graphblas.org/binsparse-specification/#descriptor)
of this array.
Returns
-------
dict
Parsed `binsparse` descriptor.
"""
raise NotImplementedError

@abstractmethod
def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
"""A `dict` containing the constituent arrays of this sparse array. The keys are compatible with the
[`binsparse`](https://graphblas.org/binsparse-specification/) scheme, and the values are [`__dlpack__`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html)
compatible objects.
Returns
-------
dict[str, np.ndarray]
The constituent arrays.
"""
raise NotImplementedError

@abstractmethod
def asformat(self, format):
"""
Expand Down

0 comments on commit 5aba8c4

Please sign in to comment.