diff --git a/ucp/_libs/arr.pxd b/ucp/_libs/arr.pxd index baa6a98e..16367cb8 100644 --- a/ucp/_libs/arr.pxd +++ b/ucp/_libs/arr.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. # cython: language_level=3 @@ -24,3 +24,6 @@ cdef class Array: cpdef bint _f_contiguous(self) cpdef bint _contiguous(self) cpdef Py_ssize_t _nbytes(self) + + +cpdef Array asarray(obj) diff --git a/ucp/_libs/arr.pyi b/ucp/_libs/arr.pyi index 2053179e..d2ba4798 100644 --- a/ucp/_libs/arr.pyi +++ b/ucp/_libs/arr.pyi @@ -1,7 +1,12 @@ -from typing import Tuple +# Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. -class Array: - def __init__(self, obj: object): ... +from typing import Generic, Tuple, TypeVar + +T = TypeVar("T") + +class Array(Generic[T]): + def __init__(self, obj: T): ... @property def c_contiguous(self) -> bool: ... @property @@ -14,3 +19,9 @@ class Array: def shape(self) -> Tuple[int]: ... @property def strides(self) -> Tuple[int]: ... + @property + def cuda(self) -> bool: ... + @property + def obj(self) -> T: ... + +def asarray(obj) -> Array: ... diff --git a/ucp/_libs/arr.pyx b/ucp/_libs/arr.pyx index 1937e7bb..0c332ef3 100644 --- a/ucp/_libs/arr.pyx +++ b/ucp/_libs/arr.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. # cython: language_level=3 @@ -297,3 +297,22 @@ cdef inline Py_ssize_t _nbytes(Py_ssize_t itemsize, for i in range(ndim): nbytes *= shape_mv[i] return nbytes + + +cpdef Array asarray(obj): + """Coerce other objects to ``Array``. No-op for existing ``Array``s. + + Parameters + ---------- + obj: object + Object exposing the Python buffer protocol or ``__cuda_array_interface__``. + + Returns + ------- + array: Array + An instance of the ``Array`` class. + """ + if isinstance(obj, Array): + return obj + else: + return Array(obj)