Skip to content

Commit

Permalink
Only use functions in the limited API (rapidsai#2282)
Browse files Browse the repository at this point in the history
This PR removes usage of the only method in raft's Cython that is not part of the Python limited API. Contributes to rapidsai/build-planning#42

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#2282
  • Loading branch information
vyasr authored May 7, 2024
1 parent 19842a2 commit ef28628
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 33 deletions.
34 changes: 14 additions & 20 deletions python/pylibraft/pylibraft/common/mdspan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import io

import numpy as np

from cpython.buffer cimport PyBUF_FULL_RO, PyBuffer_Release, PyObject_GetBuffer
from cpython.object cimport PyObject
from cython.operator cimport dereference as deref
from libc.stddef cimport size_t
Expand All @@ -47,10 +48,6 @@ from pylibraft.common.optional cimport make_optional, optional
from pylibraft.common import DeviceResources


cdef extern from "Python.h":
Py_buffer* PyMemoryView_GET_BUFFER(PyObject* mview)


def run_roundtrip_test_for_mdspan(X, fortran_order=False):
if not isinstance(X, np.ndarray) or len(X.shape) != 2:
raise ValueError("Please call this function with a NumPy array with"
Expand All @@ -59,6 +56,9 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
cdef device_resources * handle_ = \
<device_resources *> <size_t> handle.getHandle()
cdef ostringstream oss
cdef Py_buffer buf
PyObject_GetBuffer(X, &buf, PyBUF_FULL_RO)
cdef uintptr_t buf_ptr = <uintptr_t>buf.buf
if X.dtype == np.float32:
if fortran_order:
serialize_mdspan[float, matrix_extent[size_t], col_major](
Expand All @@ -67,8 +67,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
<const host_mdspan[float, matrix_extent[size_t],
col_major] &>
make_host_matrix_view[float, size_t, col_major](
<float *><uintptr_t>PyMemoryView_GET_BUFFER(
<PyObject *> X.data).buf,
<float *>buf_ptr,
X.shape[0], X.shape[1]))
else:
serialize_mdspan[float, matrix_extent[size_t], row_major](
Expand All @@ -77,8 +76,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
<const host_mdspan[float, matrix_extent[size_t],
row_major]&>
make_host_matrix_view[float, size_t, row_major](
<float *><uintptr_t>PyMemoryView_GET_BUFFER(
<PyObject *> X.data).buf,
<float *>buf_ptr,
X.shape[0], X.shape[1]))
elif X.dtype == np.float64:
if fortran_order:
Expand All @@ -88,8 +86,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
<const host_mdspan[double, matrix_extent[size_t],
col_major]&>
make_host_matrix_view[double, size_t, col_major](
<double *><uintptr_t>PyMemoryView_GET_BUFFER(
<PyObject *> X.data).buf,
<double *>buf_ptr,
X.shape[0], X.shape[1]))
else:
serialize_mdspan[double, matrix_extent[size_t], row_major](
Expand All @@ -98,8 +95,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
<const host_mdspan[double, matrix_extent[size_t],
row_major]&>
make_host_matrix_view[double, size_t, row_major](
<double *><uintptr_t>PyMemoryView_GET_BUFFER(
<PyObject *> X.data).buf,
<double *>buf_ptr,
X.shape[0], X.shape[1]))
elif X.dtype == np.int32:
if fortran_order:
Expand All @@ -109,8 +105,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
<const host_mdspan[int32_t, matrix_extent[size_t],
col_major]&>
make_host_matrix_view[int32_t, size_t, col_major](
<int32_t *><uintptr_t>PyMemoryView_GET_BUFFER(
<PyObject *> X.data).buf,
<int32_t *>buf_ptr,
X.shape[0], X.shape[1]))
else:
serialize_mdspan[int32_t, matrix_extent[size_t], row_major](
Expand All @@ -119,8 +114,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
<const host_mdspan[int32_t, matrix_extent[size_t],
row_major]&>
make_host_matrix_view[int32_t, size_t, row_major](
<int32_t *><uintptr_t>PyMemoryView_GET_BUFFER(
<PyObject *> X.data).buf,
<int32_t *>buf_ptr,
X.shape[0], X.shape[1]))
elif X.dtype == np.uint32:
if fortran_order:
Expand All @@ -130,8 +124,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
<const host_mdspan[uint32_t, matrix_extent[size_t],
col_major]&>
make_host_matrix_view[uint32_t, size_t, col_major](
<uint32_t *><uintptr_t>PyMemoryView_GET_BUFFER(
<PyObject *> X.data).buf,
<uint32_t *>buf_ptr,
X.shape[0], X.shape[1]))
else:
serialize_mdspan[uint32_t, matrix_extent[size_t], row_major](
Expand All @@ -140,11 +133,12 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
<const host_mdspan[uint32_t, matrix_extent[size_t],
row_major]&>
make_host_matrix_view[uint32_t, size_t, row_major](
<uint32_t *><uintptr_t>PyMemoryView_GET_BUFFER(
<PyObject *> X.data).buf,
<uint32_t *>buf_ptr,
X.shape[0], X.shape[1]))
else:
PyBuffer_Release(&buf)
raise NotImplementedError()
PyBuffer_Release(&buf)
f = io.BytesIO(oss.str())
X2 = np.load(f)
assert np.all(X.shape == X2.shape)
Expand Down
14 changes: 1 addition & 13 deletions python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,7 @@ cdef extern from "raft_runtime/neighbors/hnsw.hpp" \
host_matrix_view[uint64_t, int64_t, row_major] neighbors,
host_matrix_view[float, int64_t, row_major] distances) except +

cdef unique_ptr[index[float]] deserialize_file[float](
const device_resources& handle,
const string& filename,
int dim,
DistanceType metric) except +

cdef unique_ptr[index[int8_t]] deserialize_file[int8_t](
const device_resources& handle,
const string& filename,
int dim,
DistanceType metric) except +

cdef unique_ptr[index[uint8_t]] deserialize_file[uint8_t](
cdef unique_ptr[index[T]] deserialize_file[T](
const device_resources& handle,
const string& filename,
int dim,
Expand Down

0 comments on commit ef28628

Please sign in to comment.