Support where-clause evaluation in registration annotations #3841

Merged
9 changes: 9 additions & 0 deletions arkouda/array_api/utility_functions.py
@@ -68,6 +68,9 @@ def clip(a: Array, a_min, a_max, /) -> Array:
a_max : scalar
The maximum value
"""
if a.dtype == ak.bigint or a.dtype == ak.bool_:
raise RuntimeError(f"Error executing command: clip does not support dtype {a.dtype}")

return Array._new(
create_pdarray(
generic_msg(
@@ -99,6 +102,9 @@ def diff(a: Array, /, n: int = 1, axis: int = -1, prepend=None, append=None) ->
append : Array, optional
Array to append to `a` along `axis` before calculating the difference.
"""
if a.dtype == ak.bigint or a.dtype == ak.bool_:
raise RuntimeError(f"Error executing command: diff does not support dtype {a.dtype}")

if prepend is not None and append is not None:
a_ = concat((prepend, a, append), axis=axis)
elif prepend is not None:
@@ -146,6 +152,9 @@ def pad(
if mode != "constant":
raise NotImplementedError(f"pad mode '{mode}' is not supported")

if array.dtype == ak.bigint:
raise RuntimeError("Error executing command: pad does not support dtype bigint")

if "constant_values" not in kwargs:
cvals = 0
else:
43 changes: 17 additions & 26 deletions arkouda/pdarrayclass.py
@@ -2720,16 +2720,6 @@ def is_sorted(pda: pdarray) -> np.bool_:
)


def _get_axis_pdarray(axis: Optional[Union[int, Tuple[int, ...]]] = None):
from arkouda import array as ak_array

axis_list = []
if axis is not None:
axis_list = list(axis) if isinstance(axis, tuple) else [axis]

return ak_array(axis_list, dtype="int64")


@typechecked
def sum(
pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]] = None
@@ -2757,12 +2747,13 @@ def sum(
RuntimeError
Raised if there's a server-side error thrown
"""
axis_arry = _get_axis_pdarray(axis)
axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
repMsg = generic_msg(
cmd=f"sum<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>",
args={"x": pda, "axis": axis_arry, "skipNan": False},
cmd=f"sum<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "axis": axis_, "skipNan": False},
)
if axis is None or len(axis_arry) == 0 or pda.ndim == 1:
if axis is None or len(axis_) == 0 or pda.ndim == 1:
# TODO: remove call to 'flatten'
return create_pdarray(cast(str, repMsg)).flatten()[0]
else:
return create_pdarray(cast(str, repMsg))
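
Note on the change above: `sum` (and likewise `prod`, `min`, and `max` below) now sends the reduction axes to the server as a plain list instead of building a temporary pdarray via the removed `_get_axis_pdarray` helper, so the command name no longer encodes the axis array's rank. A minimal standalone sketch of the normalization rule, using an illustrative helper name that is not part of arkouda:

```python
from typing import List, Optional, Tuple, Union

def normalize_axis(axis: Optional[Union[int, Tuple[int, ...]]]) -> List[int]:
    """Mirror of the `axis_` expression above: None -> [], int -> [axis], tuple -> list(axis)."""
    if axis is None:
        return []
    return [axis] if isinstance(axis, int) else list(axis)

assert normalize_axis(None) == []
assert normalize_axis(1) == [1]
assert normalize_axis((0, 2)) == [0, 2]
```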
@@ -2845,12 +2836,12 @@ def prod(pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Un
RuntimeError
Raised if there's a server-side error thrown
"""
axis_arry = _get_axis_pdarray(axis)
axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
repMsg = generic_msg(
cmd=f"prod<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>",
args={"x": pda, "axis": axis_arry, "skipNan": False},
cmd=f"prod<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "axis": axis_, "skipNan": False},
)
if axis is None or len(axis_arry) == 0 or pda.ndim == 1:
if axis is None or len(axis_) == 0 or pda.ndim == 1:
return create_pdarray(cast(str, repMsg)).flatten()[0]
else:
return create_pdarray(cast(str, repMsg))
@@ -2882,12 +2873,12 @@ def min(
RuntimeError
Raised if there's a server-side error thrown
"""
axis_arry = _get_axis_pdarray(axis)
axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
repMsg = generic_msg(
cmd=f"min<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>",
args={"x": pda, "axis": axis_arry, "skipNan": False},
cmd=f"min<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "axis": axis_, "skipNan": False},
)
if axis is None or len(axis_arry) == 0 or pda.ndim == 1:
if axis is None or len(axis_) == 0 or pda.ndim == 1:
return create_pdarray(cast(str, repMsg)).flatten()[0]
else:
return create_pdarray(cast(str, repMsg))
@@ -2920,12 +2911,12 @@ def max(
RuntimeError
Raised if there's a server-side error thrown
"""
axis_arry = _get_axis_pdarray(axis)
axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
repMsg = generic_msg(
cmd=f"max<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>",
args={"x": pda, "axis": axis_arry, "skipNan": False},
cmd=f"max<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "axis": axis_, "skipNan": False},
)
if axis is None or len(axis_arry) == 0 or pda.ndim == 1:
if axis is None or len(axis_) == 0 or pda.ndim == 1:
return create_pdarray(cast(str, repMsg)).flatten()[0]
else:
return create_pdarray(cast(str, repMsg))
8 changes: 1 addition & 7 deletions src/ArgSortMsg.chpl
@@ -434,17 +434,11 @@ module ArgSortMsg
axis = msgArgs["axis"].toScalar(int),
symEntry = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
vals = if (array_dtype == bool) then (symEntry.a:int) else (symEntry.a: array_dtype);

const iv = argsortDefault(vals, algorithm=algorithm, axis);
return st.insert(new shared SymEntry(iv));
}

proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("argsort does not support the %s dtype".format(array_dtype:string));
}

proc argsortStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const name = msgArgs["name"].toScalar(string),
strings = getSegString(name, st),
36 changes: 27 additions & 9 deletions src/AryUtil.chpl
@@ -156,6 +156,20 @@ module AryUtil
return (true, ret);
}

proc validateNegativeAxes(axes: list(int), param nd: int): (bool, list(int)) {
var ret = new list(int);
for a in axes {
if a >= 0 && a < nd {
ret.pushBack(a);
} else if a < 0 && a >= -nd {
ret.pushBack(nd + a);
} else {
return (false, ret);
}
}
return (true, ret);
}
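
The new `validateNegativeAxes` accepts axes in the range [-nd, nd), wrapping negative values the NumPy way (axis `a < 0` maps to `nd + a`) and rejecting anything out of range. A standalone Python sketch of the same rule, for illustration only:

```python
from typing import List, Tuple

def validate_negative_axes(axes: List[int], nd: int) -> Tuple[bool, List[int]]:
    """Return (ok, normalized axes); negative axes wrap to nd + a, out-of-range axes fail."""
    out: List[int] = []
    for a in axes:
        if 0 <= a < nd:
            out.append(a)
        elif -nd <= a < 0:
            out.append(nd + a)
        else:
            return (False, out)
    return (True, out)

assert validate_negative_axes([0, -1], 3) == (True, [0, 2])
assert validate_negative_axes([3], 3) == (False, [])
```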

/*
Get a domain that selects out the idx'th set of indices along the specified axes

@@ -328,6 +342,16 @@
return ret;
}

proc reducedShape(shape: ?N*int, axes: list(int)): N*int {
var ret: N*int;
for param i in 0..<N {
if N == 1 || axes.size == 0 || axes.contains(i)
then ret[i] = 1;
else ret[i] = shape[i];
}
return ret;
}
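
`reducedShape` gives the shape of a reduction's result: every dimension listed in `axes` collapses to 1, and an empty axis list (or a 1-D input) collapses all of them. A standalone Python sketch of the same rule, for illustration only:

```python
from typing import List, Tuple

def reduced_shape(shape: Tuple[int, ...], axes: List[int]) -> Tuple[int, ...]:
    """Dimensions in `axes` (or all dimensions when `axes` is empty or shape is 1-D) become 1."""
    n = len(shape)
    return tuple(
        1 if (n == 1 or len(axes) == 0 or i in axes) else shape[i]
        for i in range(n)
    )

assert reduced_shape((4, 5, 6), [1]) == (4, 1, 6)
assert reduced_shape((4, 5, 6), []) == (1, 1, 1)
assert reduced_shape((7,), []) == (1,)
```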

/*
Returns stats on a given array in form (int,int,real,real,real).

@@ -947,9 +971,9 @@
flatten a multi-dimensional array into a 1D array
*/
@arkouda.registerCommand
proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank > 1
{
proc flatten(const ref a: [?d] ?t): [] t throws {
if a.rank == 1 then return a;

var flat = makeDistArray(d.size, t);

// ranges of flat indices owned by each locale
Expand Down Expand Up @@ -1006,12 +1030,6 @@ module AryUtil
return flat;
}

proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank == 1
{
return a;
}

// helper for computing an array element's index from its order
record orderer {
param rank: int;
22 changes: 1 addition & 21 deletions src/CastMsg.chpl
@@ -15,17 +15,13 @@ module CastMsg {
private config const logChannel = ServerConfig.logChannel;
const castLogger = new Logger(logLevel, logChannel);

proc isFloatingType(type t) param : bool {
return isRealType(t) || isImagType(t) || isComplexType(t);
}

@arkouda.instantiateAndRegister(prefix="cast")
proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_from,
type array_dtype_to,
param array_nd: int
): MsgTuple throws
where !(isFloatingType(array_dtype_from) && array_dtype_to == bigint) &&
where !((isRealType(array_dtype_from) || isImagType(array_dtype_from) || isComplexType(array_dtype_from)) && array_dtype_to == bigint) &&
!(array_dtype_from == bigint && array_dtype_to == bool)
{
const a = st[msgArgs["name"]]: SymEntry(array_dtype_from, array_nd);
@@ -40,22 +36,6 @@
}
}

// cannot cast float types to bigint, cannot cast bigint to bool
proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_from,
type array_dtype_to,
param array_nd: int
): MsgTuple throws
where (isFloatingType(array_dtype_from) && array_dtype_to == bigint) ||
(array_dtype_from == bigint && array_dtype_to == bool)
{
return MsgTuple.error(
"cannot cast array of type %s to %s".format(
type2str(array_dtype_from),
type2str(array_dtype_to)
));
}

@arkouda.instantiateAndRegister(prefix="castToStrings")
proc castArrayToStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws {
const name = msgArgs["name"].toScalar(string);
12 changes: 0 additions & 12 deletions src/GenSymIO.chpl
@@ -43,12 +43,6 @@ module GenSymIO {
return st.insert(new shared SymEntry(makeArrayFromBytes(msgArgs.payload, shape, array_dtype)));
}

proc array(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("Array creation from binary payload is not supported for bigint arrays");
}

proc makeArrayFromBytes(ref payload: bytes, shape: ?N*int, type t): [] t throws {
var size = 1;
for s in shape do size *= s;
@@ -138,12 +132,6 @@
return MsgTuple.payload(bytes.createAdoptingBuffer(ptr:c_ptr(uint(8)), size, size));
}

proc tondarray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("cannot create ndarray from bigint array");
}

/*
* Utility proc to test casting a string to a specified type
* :arg c: String to cast
14 changes: 0 additions & 14 deletions src/IndexingMsg.chpl
@@ -206,12 +206,6 @@ module IndexingMsg
}
}

proc multiPDArrayIndex(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_a, type array_dtype_idx, param array_nd: int): MsgTuple throws
where array_dtype_idx != int && array_dtype_idx != uint
{
return MsgTuple.error("Invalid index type: %s; must be 'int' or 'uint'".format(type2str(array_dtype_idx)));
}

private proc multiIndexShape(inShape: ?N*int, idxDims: [?d] int, outSize: int): (bool, int, N*int) {
var minShape: N*int = inShape,
firstRank = -1;
@@ -960,14 +954,6 @@
return st.insert(new shared SymEntry(y, x.max_bits));
}

proc takeAlongAxis(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_x,
type array_dtype_idx,
param array_nd: int
): MsgTuple throws {
return MsgTuple.error("Cannot take along axis with non-integer index array");
}

use CommandMap;
registerFunction("arrayViewMixedIndex", arrayViewMixedIndexMsg, getModuleName());
registerFunction("[pdarray]", pdarrayIndexMsg, getModuleName());
53 changes: 3 additions & 50 deletions src/LinalgMsg.chpl
@@ -61,13 +61,6 @@ module LinalgMsg {
return st.insert(e);
}


proc eye(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws
where array_dtype == BigInteger.bigint
{
return MsgTuple.error("eye does not support the bigint dtype");
}

// tril and triu are identical except for the argument they pass to triluHandler (true for upper, false for lower)
// The zeros are written into the upper (or lower) triangle of the array, offset by the value of diag.

@@ -79,11 +72,6 @@
return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, false);
}

proc tril(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd < 2 {
return MsgTuple.error("Array must be at least 2 dimensional for 'tril'");
}

// Create an array from an existing array with its lower triangle zeroed out

@arkouda.instantiateAndRegister
@@ -92,13 +80,9 @@
return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, true);
}

proc triu(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd < 2 {
return MsgTuple.error("Array must be at least 2 dimensional for 'triu'");
}

// Fetch the arguments, call zeroTri, return result.

// Fetch the arguments, call zeroTri, return result.
// TODO: support instantiating param bools with 'true' and 'false' s.t. we'd have 'triluHandler<true>' and 'triluHandler<false>'
// cmds if this procedure were annotated instead of the two above.
proc triluHandler(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype, param array_nd: int, param upper: bool
): MsgTuple throws {
@@ -195,16 +179,6 @@

}

proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_nd < 2) && (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) {
return MsgTuple.error("Matrix multiplication with arrays of dimension < 2 is not supported");
}

proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) {
return MsgTuple.error("Matrix multiplication with arrays of bigint type is not supported");
}

proc compute_result_type_matmul(type t1, type t2) type {
if t1 == real || t2 == real then return real;
if t1 == int || t2 == int then return int;
@@ -302,11 +276,6 @@
return ret;
}

proc transpose(array: [?d] ?t): [d] t throws
where d.rank < 2 {
throw new Error("Matrix transpose with arrays of dimension < 2 is not supported");
}

/*
Compute the generalized dot product of two tensors along the specified axis.

@@ -366,22 +335,6 @@
return bool;
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_nd < 2) && ((array_dtype_x1 != bool) || (array_dtype_x2 != bool))
&& (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) {
return MsgTuple.error("VecDot with arrays of dimension < 2 is not supported");
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == bool) && (array_dtype_x2 == bool) {
return MsgTuple.error("VecDot with arrays both of type bool is not supported");
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) {
return MsgTuple.error("VecDot with arrays of type bigint is not supported");
}

// @arkouda.registerND(???)
// proc tensorDotMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd1: int, param nd2: int): MsgTuple throws {
// if nd < 3 {