Skip to content

Commit

Permalink
Adding lightweight cai_wrapper to reduce boilerplate (#1027)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: #1027
  • Loading branch information
cjnolet authored Nov 18, 2022
1 parent e06b156 commit f6e5226
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 81 deletions.
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
#


from .cai_wrapper import cai_wrapper
from .cuda import Stream
from .device_ndarray import device_ndarray
from .handle import Handle
73 changes: 73 additions & 0 deletions python/pylibraft/pylibraft/common/cai_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np

from pylibraft.common import input_validation


class cai_wrapper:
    """
    Thin convenience layer over an object's ``__cuda_array_interface__``
    dictionary.

    The dictionary is captured once at construction time and the commonly
    needed entries (dtype, shape, layout, data pointer) are exposed as
    read-only properties so call sites do not have to index the raw dict.
    """

    def __init__(self, cai_arr):
        """
        Capture the CUDA array interface dictionary of ``cai_arr``.

        Parameters
        ----------
        cai_arr : CUDA array interface array
            Any object exposing a ``__cuda_array_interface__`` attribute.
        """
        interface = cai_arr.__cuda_array_interface__
        self.cai_ = interface

    @property
    def dtype(self):
        """
        NumPy dtype built from the interface's ``typestr`` entry.
        """
        typestr = self.cai_["typestr"]
        return np.dtype(typestr)

    @property
    def shape(self):
        """
        The ``shape`` entry of the interface, returned as stored.
        """
        dims = self.cai_["shape"]
        return dims

    @property
    def c_contiguous(self):
        """
        Result of ``input_validation.is_c_contiguous`` for the wrapped
        interface, i.e. whether it is considered c-ordered (row-major).
        """
        is_row_major = input_validation.is_c_contiguous(self.cai_)
        return is_row_major

    @property
    def f_contiguous(self):
        """
        Whether the wrapped interface is treated as f-ordered
        (column-major).

        This is defined purely as the negation of the
        ``input_validation.is_c_contiguous`` check; no separate
        column-major stride test is performed.
        """
        is_row_major = input_validation.is_c_contiguous(self.cai_)
        return not is_row_major

    @property
    def data(self):
        """
        First element of the interface's ``data`` entry (the data pointer).
        """
        data_entry = self.cai_["data"]
        return data_entry[0]
39 changes: 17 additions & 22 deletions python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,10 @@ from libcpp cimport bool

from .distance_type cimport DistanceType

from pylibraft.common import Handle, device_ndarray
from pylibraft.common import Handle, cai_wrapper, device_ndarray
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t


def is_c_cont(cai, dt):
return "strides" not in cai or \
cai["strides"] is None or \
cai["strides"][1] == dt.itemsize
from pylibraft.common.handle cimport handle_t


cdef extern from "raft_distance/fused_l2_min_arg.hpp" \
Expand Down Expand Up @@ -135,41 +130,41 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None):
"""

x_cai = X.__cuda_array_interface__
y_cai = Y.__cuda_array_interface__
x_cai = cai_wrapper(X)
y_cai = cai_wrapper(Y)

x_dt = np.dtype(x_cai["typestr"])
y_dt = np.dtype(y_cai["typestr"])
x_dt = x_cai.dtype
y_dt = y_cai.dtype

m = x_cai["shape"][0]
n = y_cai["shape"][0]
m = x_cai.shape[0]
n = y_cai.shape[0]

if out is None:
output = device_ndarray.empty((m,), dtype="int32")
else:
output = out

output_cai = output.__cuda_array_interface__
output_cai = cai_wrapper(output)

x_k = x_cai["shape"][1]
y_k = y_cai["shape"][1]
x_k = x_cai.shape[1]
y_k = y_cai.shape[1]

if x_k != y_k:
raise ValueError("Inputs must have same number of columns. "
"a=%s, b=%s" % (x_k, y_k))

x_ptr = <uintptr_t>x_cai["data"][0]
y_ptr = <uintptr_t>y_cai["data"][0]
x_ptr = <uintptr_t>x_cai.data
y_ptr = <uintptr_t>y_cai.data

d_ptr = <uintptr_t>output_cai["data"][0]
d_ptr = <uintptr_t>output_cai.data

handle = handle if handle is not None else Handle()
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()

d_dt = np.dtype(output_cai["typestr"])
d_dt = output_cai.dtype

x_c_contiguous = is_c_cont(x_cai, x_dt)
y_c_contiguous = is_c_cont(y_cai, y_dt)
x_c_contiguous = x_cai.c_contiguous
y_c_contiguous = y_cai.c_contiguous

if x_c_contiguous != y_c_contiguous:
raise ValueError("Inputs must have matching strides")
Expand Down
38 changes: 16 additions & 22 deletions python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,7 @@ from pylibraft.common.handle import auto_sync_handle

from pylibraft.common.handle cimport handle_t

from pylibraft.common import device_ndarray


def is_c_cont(cai, dt):
return "strides" not in cai or \
cai["strides"] is None or \
cai["strides"][1] == dt.itemsize
from pylibraft.common import cai_wrapper, device_ndarray


cdef extern from "raft_distance/pairwise_distance.hpp" \
Expand Down Expand Up @@ -179,40 +173,40 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
"""

x_cai = X.__cuda_array_interface__
y_cai = Y.__cuda_array_interface__
x_cai = cai_wrapper(X)
y_cai = cai_wrapper(Y)

m = x_cai["shape"][0]
n = y_cai["shape"][0]
m = x_cai.shape[0]
n = y_cai.shape[0]

x_dt = np.dtype(x_cai["typestr"])
y_dt = np.dtype(y_cai["typestr"])
x_dt = x_cai.dtype
y_dt = y_cai.dtype

if out is None:
dists = device_ndarray.empty((m, n), dtype=y_dt)
else:
dists = out

x_k = x_cai["shape"][1]
y_k = y_cai["shape"][1]
x_k = x_cai.shape[1]
y_k = y_cai.shape[1]

dists_cai = dists.__cuda_array_interface__
dists_cai = cai_wrapper(dists)

if x_k != y_k:
raise ValueError("Inputs must have same number of columns. "
"a=%s, b=%s" % (x_k, y_k))

x_ptr = <uintptr_t>x_cai["data"][0]
y_ptr = <uintptr_t>y_cai["data"][0]
d_ptr = <uintptr_t>dists_cai["data"][0]
x_ptr = <uintptr_t>x_cai.data
y_ptr = <uintptr_t>y_cai.data
d_ptr = <uintptr_t>dists_cai.data

handle = handle if handle is not None else Handle()
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()

d_dt = np.dtype(dists_cai["typestr"])
d_dt = dists_cai.dtype

x_c_contiguous = is_c_cont(x_cai, x_dt)
y_c_contiguous = is_c_cont(y_cai, y_dt)
x_c_contiguous = x_cai.c_contiguous
y_c_contiguous = y_cai.c_contiguous

if x_c_contiguous != y_c_contiguous:
raise ValueError("Inputs must have matching strides")
Expand Down
56 changes: 28 additions & 28 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ from libcpp cimport bool, nullptr

from pylibraft.distance.distance_type cimport DistanceType

from pylibraft.common import Handle, device_ndarray
from pylibraft.common import Handle, cai_wrapper, device_ndarray
from pylibraft.common.interruptible import cuda_interruptible

from pylibraft.common.handle cimport handle_t
Expand Down Expand Up @@ -88,19 +88,19 @@ cdef _get_dtype_string(dtype):


def _check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None):
if cai["typestr"] not in exp_dt:
if cai.dtype not in exp_dt:
raise TypeError("dtype %s not supported" % cai["typestr"])

if not is_c_contiguous(cai):
if not cai.c_contiguous:
raise ValueError("Row major input is expected")

if exp_cols is not None and cai["shape"][1] != exp_cols:
if exp_cols is not None and cai.shape[1] != exp_cols:
raise ValueError("Incorrect number of columns, expected {} got {}"
.format(exp_cols, cai["shape"][1]))
.format(exp_cols, cai.shape[1]))

if exp_rows is not None and cai["shape"][0] != exp_rows:
if exp_rows is not None and cai.shape[0] != exp_rows:
raise ValueError("Incorrect number of rows, expected {} , got {}"
.format(exp_rows, cai["shape"][0]))
.format(exp_rows, cai.shape[0]))


cdef class IndexParams:
Expand Down Expand Up @@ -352,14 +352,14 @@ def build(IndexParams index_params, dataset, handle=None):
handle.sync()
"""
dataset_cai = dataset.__cuda_array_interface__
dataset_dt = np.dtype(dataset_cai["typestr"])
dataset_cai = cai_wrapper(dataset)
dataset_dt = dataset_cai.dtype
_check_input_array(dataset_cai, [np.dtype('float32'), np.dtype('byte'),
np.dtype('ubyte')])
cdef uintptr_t dataset_ptr = dataset_cai["data"][0]
cdef uintptr_t dataset_ptr = dataset_cai.data

cdef uint64_t n_rows = dataset_cai["shape"][0]
cdef uint32_t dim = dataset_cai["shape"][1]
cdef uint64_t n_rows = dataset_cai.shape[0]
cdef uint32_t dim = dataset_cai.shape[1]

if handle is None:
handle = Handle()
Expand Down Expand Up @@ -467,22 +467,22 @@ def extend(Index index, new_vectors, new_indices, handle=None):
handle = Handle()
cdef handle_t* handle_ = <handle_t*><size_t>handle.getHandle()

vecs_cai = new_vectors.__cuda_array_interface__
vecs_dt = np.dtype(vecs_cai["typestr"])
cdef uint64_t n_rows = vecs_cai["shape"][0]
cdef uint32_t dim = vecs_cai["shape"][1]
vecs_cai = cai_wrapper(new_vectors)
vecs_dt = vecs_cai.dtype
cdef uint64_t n_rows = vecs_cai.shape[0]
cdef uint32_t dim = vecs_cai.shape[1]

_check_input_array(vecs_cai, [np.dtype('float32'), np.dtype('byte'),
np.dtype('ubyte')],
exp_cols=index.dim)

idx_cai = new_indices.__cuda_array_interface__
idx_cai = cai_wrapper(new_indices)
_check_input_array(idx_cai, [np.dtype('uint64')], exp_rows=n_rows)
if len(idx_cai["shape"])!=1:
if len(idx_cai.shape)!=1:
raise ValueError("Indices array is expected to be 1D")

cdef uintptr_t vecs_ptr = vecs_cai["data"][0]
cdef uintptr_t idx_ptr = idx_cai["data"][0]
cdef uintptr_t vecs_ptr = vecs_cai.data
cdef uintptr_t idx_ptr = idx_cai.data

if vecs_dt == np.float32:
with cuda_interruptible():
Expand Down Expand Up @@ -656,9 +656,9 @@ def search(SearchParams search_params,
handle = Handle()
cdef handle_t* handle_ = <handle_t*><size_t>handle.getHandle()

queries_cai = queries.__cuda_array_interface__
queries_dt = np.dtype(queries_cai["typestr"])
cdef uint32_t n_queries = queries_cai["shape"][0]
queries_cai = cai_wrapper(queries)
queries_dt = queries_cai.dtype
cdef uint32_t n_queries = queries_cai.shape[0]

_check_input_array(queries_cai, [np.dtype('float32'), np.dtype('byte'),
np.dtype('ubyte')],
Expand All @@ -667,22 +667,22 @@ def search(SearchParams search_params,
if neighbors is None:
neighbors = device_ndarray.empty((n_queries, k), dtype='uint64')

neighbors_cai = neighbors.__cuda_array_interface__
neighbors_cai = cai_wrapper(neighbors)
_check_input_array(neighbors_cai, [np.dtype('uint64')],
exp_rows=n_queries, exp_cols=k)

if distances is None:
distances = device_ndarray.empty((n_queries, k), dtype='float32')

distances_cai = distances.__cuda_array_interface__
distances_cai = cai_wrapper(distances)
_check_input_array(distances_cai, [np.dtype('float32')],
exp_rows=n_queries, exp_cols=k)

cdef c_ivf_pq.search_params params = search_params.params

cdef uintptr_t queries_ptr = queries_cai["data"][0]
cdef uintptr_t neighbors_ptr = neighbors_cai["data"][0]
cdef uintptr_t distances_ptr = distances_cai["data"][0]
cdef uintptr_t queries_ptr = queries_cai.data
cdef uintptr_t neighbors_ptr = neighbors_cai.data
cdef uintptr_t distances_ptr = distances_cai.data
# TODO(tfeher) pass mr_ptr arg
cdef device_memory_resource* mr_ptr = <device_memory_resource*> nullptr
if memory_resource is not None:
Expand Down
16 changes: 8 additions & 8 deletions python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import numpy as np
from cython.operator cimport dereference as deref
from libc.stdint cimport int64_t, uintptr_t

from pylibraft.common import Handle
from pylibraft.common import Handle, cai_wrapper
from pylibraft.common.handle import auto_sync_handle

from libcpp cimport bool
Expand Down Expand Up @@ -129,14 +129,14 @@ def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None):
if out is None:
raise Exception("'out' cannot be None!")

out_cai = out.__cuda_array_interface__
theta_cai = theta.__cuda_array_interface__
out_cai = cai_wrapper(out)
theta_cai = cai_wrapper(theta)

n_edges = out_cai["shape"][0]
out_ptr = <uintptr_t>out_cai["data"][0]
theta_ptr = <uintptr_t>theta_cai["data"][0]
out_dt = np.dtype(out_cai["typestr"])
theta_dt = np.dtype(theta_cai["typestr"])
n_edges = out_cai.shape[0]
out_ptr = <uintptr_t>out_cai.data
theta_ptr = <uintptr_t>theta_cai.data
out_dt = out_cai.dtype
theta_dt = theta_cai.dtype

cdef RngState *rng = new RngState(seed)

Expand Down
Loading

0 comments on commit f6e5226

Please sign in to comment.