Skip to content

Commit

Permalink
Part of Bears-R-Us#3708: array_api to call functions from arkouda.pda…
Browse files Browse the repository at this point in the history
…rray_creation
  • Loading branch information
ajpotts committed Oct 25, 2024
1 parent e35b204 commit 14a37e9
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 158 deletions.
79 changes: 16 additions & 63 deletions arkouda/array_api/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast

from arkouda.client import generic_msg
import numpy as np
from arkouda.pdarrayclass import create_pdarray, pdarray, _to_pdarray
from arkouda.pdarraycreation import scalar_array

from arkouda.numpy.dtypes import dtype as akdtype
from arkouda.numpy.dtypes import resolve_scalar_dtype
from arkouda.pdarrayclass import _to_pdarray, pdarray

if TYPE_CHECKING:
from ._typing import (
Expand All @@ -17,6 +16,7 @@
NestedSequence,
SupportsBufferProtocol,
)

import arkouda as ak


Expand Down Expand Up @@ -83,9 +83,7 @@ def asarray(
elif isinstance(obj, np.ndarray):
return Array._new(_to_pdarray(obj, dt=dtype))
else:
raise ValueError(
"asarray not implemented for 'NestedSequence' or 'SupportsBufferProtocol'"
)
raise ValueError("asarray not implemented for 'NestedSequence' or 'SupportsBufferProtocol'")


def arange(
Expand Down Expand Up @@ -155,9 +153,7 @@ def empty(
)


def empty_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, without initializing entries.
"""
Expand Down Expand Up @@ -217,17 +213,8 @@ def eye(
if n_cols is not None:
cols = n_cols

repMsg = generic_msg(
cmd="eye",
args={
"dtype": np.dtype(dtype).name,
"rows": n_rows,
"cols": cols,
"diag": k,
},
)

return Array._new(create_pdarray(repMsg))
from arkouda import dtype as akdtype
return Array._new(ak.eye(rows=n_rows, cols=cols, diag=k, dt=akdtype(dtype)))


def from_dlpack(x: object, /) -> Array:
Expand Down Expand Up @@ -312,9 +299,7 @@ def ones(
return a


def ones_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, filled with ones.
"""
Expand All @@ -328,33 +313,18 @@ def tril(x: Array, /, *, k: int = 0) -> Array:
"""
from .array_object import Array

repMsg = generic_msg(
cmd=f"tril{x._array.ndim}D",
args={
"array": x._array.name,
"diag": k,
},
)

return Array._new(create_pdarray(repMsg))
return Array._new(ak.tril(x._array, diag=k))


def triu(x: Array, /, *, k: int = 0) -> Array:
"""
Create a new array with the values from `x` above the `k`-th diagonal, and
all other elements zero.
"""
from .array_object import Array

repMsg = generic_msg(
cmd=f"triu{x._array.ndim}D",
args={
"array": x._array.name,
"diag": k,
},
)
from .array_object import Array

return Array._new(create_pdarray(repMsg))
return Array._new(ak.triu(x._array, k))


def zeros(
Expand All @@ -372,31 +342,14 @@ def zeros(
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")

if isinstance(shape, tuple):
if shape == ():
return Array._new(scalar_array(0, dtype=dtype))
else:
ndim = len(shape)
else:
if shape == 0:
return Array._new(scalar_array(0, dtype=dtype))
else:
ndim = 1

dtype = akdtype(dtype) # normalize dtype
dtype_name = cast(np.dtype, dtype).name
return_dtype = akdtype(dtype)
if dtype is None:
return_dtype = akdtype(ak.float64)

repMsg = generic_msg(
cmd=f"create<{dtype_name},{ndim}>",
args={"shape": shape},
)
return Array._new(ak.zeros(shape, return_dtype))

return Array._new(create_pdarray(repMsg))


def zeros_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, filled with zeros.
"""
Expand Down
24 changes: 4 additions & 20 deletions arkouda/array_api/elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise AND of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand Down Expand Up @@ -141,10 +138,7 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise OR of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand All @@ -169,10 +163,7 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise XOR of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand Down Expand Up @@ -410,14 +401,7 @@ def logical_not(x: Array, /) -> Array:
"""
Compute the element-wise logical NOT of a boolean array.
"""
repMsg = ak.generic_msg(
cmd=f"efunc{x._array.ndim}D",
args={
"func": "not",
"array": x._array,
},
)
return Array._new(ak.create_pdarray(repMsg))
return ~x


def logical_or(x1: Array, x2: Array, /) -> Array:
Expand Down
65 changes: 9 additions & 56 deletions arkouda/array_api/linalg.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,15 @@
from .array_object import Array

from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray, broadcast_if_needed


def matmul(x1: Array, x2: Array, /) -> Array:
"""
Matrix product of two arrays.
"""
from .array_object import Array

if x1._array.ndim < 2 and x2._array.ndim < 2:
raise ValueError(
"matmul requires at least one array argument to have more than two dimensions"
)

x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1._array, x2._array)
from arkouda import matmul as ak_matmul

repMsg = generic_msg(
cmd=f"matMul{len(x1b.shape)}D",
args={
"x1": x1b.name,
"x2": x2b.name,
},
)

if tmp_x1:
del x1b
if tmp_x2:
del x2b
from .array_object import Array

return Array._new(create_pdarray(repMsg))
return Array._new(ak_matmul(x1._array, x2._array))


def tensordot():
Expand All @@ -44,42 +23,16 @@ def matrix_transpose(x: Array) -> Array:
"""
Matrix product of two arrays.
"""
from .array_object import Array
from arkouda import transpose as ak_transpose

if x._array.ndim < 2:
raise ValueError(
"matrix_transpose requires the array to have more than two dimensions"
)

repMsg = generic_msg(
cmd=f"transpose{x._array.ndim}D",
args={
"array": x._array.name,
},
)

return Array._new(create_pdarray(repMsg))


def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
from .array_object import Array

x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1._array, x2._array)
return Array._new(ak_transpose(x._array))

repMsg = generic_msg(
cmd=f"vecdot{len(x1b.shape)}D",
args={
"x1": x1b.name,
"x2": x2b.name,
"bcShape": x1b.shape,
"axis": axis,
},
)

if tmp_x1:
del x1b
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
from arkouda import vecdot as ak_vecdot

if tmp_x2:
del x2b
from .array_object import Array

return Array._new(create_pdarray(repMsg))
return Array._new(ak_vecdot(x1._array, x2._array))
57 changes: 40 additions & 17 deletions arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import TYPE_CHECKING, List, Sequence, Tuple, TypeVar, Union
from typing import cast as type_cast
from typing import no_type_check
from arkouda.groupbyclass import groupable

import numpy as np
from typeguard import typechecked

from arkouda.client import generic_msg
from arkouda.dtypes import str_ as akstr_
from arkouda.groupbyclass import GroupBy
from arkouda.groupbyclass import GroupBy, groupable
from arkouda.numpy.dtypes import DTypes, bigint
from arkouda.numpy.dtypes import bool_ as ak_bool
from arkouda.numpy.dtypes import dtype as akdtype
Expand All @@ -26,7 +26,13 @@
from arkouda.numpy.dtypes import _datatype_check
from arkouda.pdarrayclass import all as ak_all
from arkouda.pdarrayclass import any as ak_any
from arkouda.pdarrayclass import argmax, create_pdarray, pdarray, sum
from arkouda.pdarrayclass import (
argmax,
broadcast_if_needed,
create_pdarray,
pdarray,
sum,
)
from arkouda.pdarraycreation import array, linspace, scalar_array
from arkouda.sorting import sort
from arkouda.strings import Strings
Expand Down Expand Up @@ -2593,18 +2599,26 @@ def matmul(pdaLeft: pdarray, pdaRight: pdarray):
"""
if pdaLeft.ndim != pdaRight.ndim:
raise ValueError("matmul requires matrices of matching rank.")

x1, x2, tmp_x1, tmp_x2 = broadcast_if_needed(pdaLeft, pdaRight)

cmd = f"matmul<{pdaLeft.dtype},{pdaRight.dtype},{pdaLeft.ndim}>"
args = {
"x1": pdaLeft,
"x2": pdaRight,
"x1": x1,
"x2": x2,
}
return create_pdarray(
generic_msg(
cmd=cmd,
args=args,
)
repMsg = generic_msg(
cmd=cmd,
args=args,
)

if tmp_x1:
del x1
if tmp_x2:
del x2

return create_pdarray(repMsg)


def vecdot(x1: pdarray, x2: pdarray):
"""
Expand Down Expand Up @@ -2641,16 +2655,25 @@ def vecdot(x1: pdarray, x2: pdarray):
raise ValueError("vecdot requires matrices of matching rank.")
if x1.ndim < 2:
raise ValueError("vector requires matrices of rank 2 or more.")

x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1, x2)

cmd = f"vecdot<{x1.dtype},{x2.dtype},{x1.ndim}>"
args = {
"x1": x1,
"x2": x2,
"x1": x1b,
"x2": x2b,
"bcShape": tuple(x1.shape),
"axis": 0,
}
return create_pdarray(
generic_msg(
cmd=cmd,
args=args,
)

repMsg = generic_msg(
cmd=cmd,
args=args,
)

if tmp_x1:
del x1
if tmp_x2:
del x2

return create_pdarray(repMsg)
3 changes: 3 additions & 0 deletions arkouda/pdarraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,9 @@ def zeros(
if ndim > get_max_array_rank():
raise ValueError(f"array rank {ndim} exceeds maximum of {get_max_array_rank()}")

if shape == ():
return scalar_array(0, dtype=dtype)

repMsg = generic_msg(cmd=f"create<{dtype_name},{ndim}>", args={"shape": shape})

return create_pdarray(repMsg, max_bits=max_bits)
Expand Down
Loading

0 comments on commit 14a37e9

Please sign in to comment.