Skip to content

Commit

Permalink
Part of Bears-R-Us#3708: array_api to call functions from arkouda.pda…
Browse files Browse the repository at this point in the history
…rray_creation
  • Loading branch information
ajpotts committed Sep 10, 2024
1 parent 4db3469 commit acd624f
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 134 deletions.
78 changes: 15 additions & 63 deletions arkouda/array_api/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast

from arkouda.client import generic_msg
import numpy as np
from arkouda.pdarrayclass import create_pdarray, pdarray, _to_pdarray
from arkouda.pdarraycreation import scalar_array

from arkouda.client import generic_msg
from arkouda.numpy.dtypes import dtype as akdtype
from arkouda.numpy.dtypes import resolve_scalar_dtype
from arkouda.pdarrayclass import _to_pdarray, create_pdarray, pdarray

if TYPE_CHECKING:
from ._typing import (
Expand All @@ -17,6 +17,7 @@
NestedSequence,
SupportsBufferProtocol,
)

import arkouda as ak


Expand Down Expand Up @@ -83,9 +84,7 @@ def asarray(
elif isinstance(obj, np.ndarray):
return Array._new(_to_pdarray(obj, dt=dtype))
else:
raise ValueError(
"asarray not implemented for 'NestedSequence' or 'SupportsBufferProtocol'"
)
raise ValueError("asarray not implemented for 'NestedSequence' or 'SupportsBufferProtocol'")


def arange(
Expand Down Expand Up @@ -155,9 +154,7 @@ def empty(
)


def empty_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, without initializing entries.
"""
Expand Down Expand Up @@ -217,17 +214,7 @@ def eye(
if n_cols is not None:
cols = n_cols

repMsg = generic_msg(
cmd="eye",
args={
"dtype": np.dtype(dtype).name,
"rows": n_rows,
"cols": cols,
"diag": k,
},
)

return Array._new(create_pdarray(repMsg))
return Array._new(ak.eye(rows=n_rows, cols=cols, diag=k, dt=np.dtype(dtype).name))


def from_dlpack(x: object, /) -> Array:
Expand Down Expand Up @@ -312,9 +299,7 @@ def ones(
return a


def ones_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, filled with ones.
"""
Expand All @@ -328,33 +313,18 @@ def tril(x: Array, /, *, k: int = 0) -> Array:
"""
from .array_object import Array

repMsg = generic_msg(
cmd=f"tril{x._array.ndim}D",
args={
"array": x._array.name,
"diag": k,
},
)

return Array._new(create_pdarray(repMsg))
return Array._new(ak.tril(x, diag=k))


def triu(x: Array, /, *, k: int = 0) -> Array:
"""
Create a new array with the values from `x` above the `k`-th diagonal, and
all other elements zero.
"""
from .array_object import Array

repMsg = generic_msg(
cmd=f"triu{x._array.ndim}D",
args={
"array": x._array.name,
"diag": k,
},
)
from .array_object import Array

return Array._new(create_pdarray(repMsg))
return Array._new(ak.triu(x, k))


def zeros(
Expand All @@ -372,31 +342,13 @@ def zeros(
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")

if isinstance(shape, tuple):
if shape == ():
return Array._new(scalar_array(0, dtype=dtype))
else:
ndim = len(shape)
else:
if shape == 0:
return Array._new(scalar_array(0, dtype=dtype))
else:
ndim = 1

dtype = akdtype(dtype) # normalize dtype
dtype_name = cast(np.dtype, dtype).name
if dtype is None:
dtype = ak.float64

repMsg = generic_msg(
cmd=f"create<{dtype_name},{ndim}>",
args={"shape": shape},
)
return Array._new(ak.zeros(shape, dtype))

return Array._new(create_pdarray(repMsg))


def zeros_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, filled with zeros.
"""
Expand Down
24 changes: 4 additions & 20 deletions arkouda/array_api/elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise AND of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand Down Expand Up @@ -141,10 +138,7 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise OR of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand All @@ -169,10 +163,7 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise XOR of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand Down Expand Up @@ -410,14 +401,7 @@ def logical_not(x: Array, /) -> Array:
"""
Compute the element-wise logical NOT of a boolean array.
"""
repMsg = ak.generic_msg(
cmd=f"efunc{x._array.ndim}D",
args={
"func": "not",
"array": x._array,
},
)
return Array._new(ak.create_pdarray(repMsg))
return ~x


def logical_or(x1: Array, x2: Array, /) -> Array:
Expand Down
56 changes: 6 additions & 50 deletions arkouda/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,9 @@ def matmul(x1: Array, x2: Array, /) -> Array:
"""
from .array_object import Array

if x1._array.ndim < 2 and x2._array.ndim < 2:
raise ValueError(
"matmul requires at least one array argument to have more than two dimensions"
)
from arkouda import matmul as ak_matmul

x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1._array, x2._array)

repMsg = generic_msg(
cmd=f"matMul{len(x1b.shape)}D",
args={
"x1": x1b.name,
"x2": x2b.name,
},
)

if tmp_x1:
del x1b
if tmp_x2:
del x2b

return Array._new(create_pdarray(repMsg))
return Array._new(ak_matmul(x1._array, x2._array))


def tensordot():
Expand All @@ -46,40 +28,14 @@ def matrix_transpose(x: Array) -> Array:
"""
from .array_object import Array

if x._array.ndim < 2:
raise ValueError(
"matrix_transpose requires the array to have more than two dimensions"
)

repMsg = generic_msg(
cmd=f"transpose{x._array.ndim}D",
args={
"array": x._array.name,
},
)
from arkouda import transpose as ak_transpose

return Array._new(create_pdarray(repMsg))
return Array._new(ak_transpose(x._array))


def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
from .array_object import Array

x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1._array, x2._array)

repMsg = generic_msg(
cmd=f"vecdot{len(x1b.shape)}D",
args={
"x1": x1b.name,
"x2": x2b.name,
"bcShape": x1b.shape,
"axis": axis,
},
)

if tmp_x1:
del x1b

if tmp_x2:
del x2b
from arkouda import vecdot as ak_vecdot

return Array._new(create_pdarray(repMsg))
return Array._new(ak_vecdot(x1._array, x2._array))
8 changes: 7 additions & 1 deletion arkouda/pdarraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from arkouda.pdarrayclass import create_pdarray, pdarray
from arkouda.strings import Strings


__all__ = [
"array",
"zeros",
Expand Down Expand Up @@ -284,6 +283,7 @@ def array(
raise RuntimeError(f"Unhandled dtype {a.dtype}")
else:
from arkouda.util import _infer_shape_from_size

shape, ndim, full_size = _infer_shape_from_size(a.shape)

# Do not allow arrays that are too large
Expand Down Expand Up @@ -478,11 +478,15 @@ def zeros(
if dtype_name not in NumericDTypes:
raise TypeError(f"unsupported dtype {dtype}")
from arkouda.util import _infer_shape_from_size

shape, ndim, full_size = _infer_shape_from_size(size)

if ndim > get_max_array_rank():
raise ValueError(f"array rank {ndim} exceeds maximum of {get_max_array_rank()}")

if shape == ():
return scalar_array(0, dtype=dtype)

repMsg = generic_msg(cmd=f"create<{dtype_name},{ndim}>", args={"shape": shape})

return create_pdarray(repMsg, max_bits=max_bits)
Expand Down Expand Up @@ -538,6 +542,7 @@ def ones(
if dtype_name not in NumericDTypes:
raise TypeError(f"unsupported dtype {dtype}")
from arkouda.util import _infer_shape_from_size

shape, ndim, full_size = _infer_shape_from_size(size)

if ndim > get_max_array_rank():
Expand Down Expand Up @@ -607,6 +612,7 @@ def full(
if dtype_name not in NumericDTypes:
raise TypeError(f"unsupported dtype {dtype}")
from arkouda.util import _infer_shape_from_size

shape, ndim, full_size = _infer_shape_from_size(size)

if ndim > get_max_array_rank():
Expand Down
58 changes: 58 additions & 0 deletions tests/array_api/array_creation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from math import sqrt

import numpy as np
import pytest

import arkouda as ak
import arkouda.array_api as xp
from arkouda.testing import assert_almost_equivalent

# requires the server to be built with 2D array support
SHAPES = [(), (0,), (0, 0), (1,), (5,), (2, 2), (5, 10)]
Expand Down Expand Up @@ -44,3 +47,58 @@ def test_from_numpy(self):
assert b.ndim == a.ndim
assert b.shape == a.shape
assert b.tolist() == a.tolist()

@pytest.mark.skip_if_max_rank_less_than(2)
@pytest.mark.parametrize("data_type", [ak.int64, ak.float64, ak.bool_])
@pytest.mark.parametrize("prob_size", pytest.prob_size)
def test_triu(self, data_type, prob_size):
from arkouda.array_api.creation_functions import triu as array_triu

size = int(sqrt(prob_size))

# test on one square and two non-square matrices

for rows, cols in [(size, size), (size + 1, size - 1), (size - 1, size + 1)]:
pda = ak.randint(1, 10, (rows, cols))
nda = pda.to_ndarray()
sweep = range(-(rows - 1), cols - 1) # sweeps the diagonal from LL to UR
for diag in sweep:
np_triu = np.triu(nda, diag)
ak_triu = array_triu(pda, k=diag)._array
assert_almost_equivalent(ak_triu, np_triu)

@pytest.mark.skip_if_max_rank_less_than(2)
@pytest.mark.parametrize("data_type", [ak.int64, ak.float64, ak.bool_])
@pytest.mark.parametrize("prob_size", pytest.prob_size)
def test_tril(self, data_type, prob_size):
from arkouda.array_api.creation_functions import tril as array_tril

size = int(sqrt(prob_size))

# test on one square and two non-square matrices

for rows, cols in [(size, size), (size + 1, size - 1), (size - 1, size + 1)]:
pda = ak.randint(1, 10, (rows, cols))
nda = pda.to_ndarray()
sweep = range(-(rows - 2), cols) # sweeps the diagonal from LL to UR
for diag in sweep:
np_tril = np.tril(nda, diag)
ak_tril = array_tril(pda, k=diag)._array
assert_almost_equivalent(np_tril, ak_tril)

@pytest.mark.skip_if_max_rank_less_than(2)
@pytest.mark.parametrize("data_type", [ak.int64, ak.float64, ak.bool_])
@pytest.mark.parametrize("prob_size", pytest.prob_size)
def test_eye(self, data_type, prob_size):
from arkouda.array_api.creation_functions import eye as array_eye

size = int(sqrt(prob_size))

# test on one square and two non-square matrices

for rows, cols in [(size, size), (size + 1, size - 1), (size - 1, size + 1)]:
sweep = range(-(cols - 1), rows) # sweeps the diagonal from LL to UR
for diag in sweep:
np_eye = np.eye(rows, cols, diag, dtype=data_type)
ak_eye = array_eye(rows, cols, k=diag, dtype=data_type)._array
assert_almost_equivalent(np_eye, ak_eye)
Loading

0 comments on commit acd624f

Please sign in to comment.