From c21bb8cd824209ff110db3f9c0efddc9d6fcdd26 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Tue, 15 Oct 2024 10:17:51 -0600 Subject: [PATCH 01/14] add primitive where-clause evaluation (for instantiateAndRegister) to register_commands.py Signed-off-by: Jeremiah Corrado --- src/CastMsg.chpl | 3 +- src/registry/register_commands.py | 181 ++++++++++++++++++++++++++++-- 2 files changed, 176 insertions(+), 8 deletions(-) diff --git a/src/CastMsg.chpl b/src/CastMsg.chpl index 91453a3164..fda14e61c7 100644 --- a/src/CastMsg.chpl +++ b/src/CastMsg.chpl @@ -25,7 +25,8 @@ module CastMsg { type array_dtype_to, param array_nd: int ): MsgTuple throws - where !(isFloatingType(array_dtype_from) && array_dtype_to == bigint) && + where !((isRealType(array_dtype_from) || isImagType(array_dtype_from) || isComplexType(array_dtype_from)) + && array_dtype_to == bigint) && !(array_dtype_from == bigint && array_dtype_to == bool) { const a = st[msgArgs["name"]]: SymEntry(array_dtype_from, array_nd); diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index e7c3436c91..0925d32125 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -339,7 +339,7 @@ def clean_enum_name(name): def stamp_generic_command( - generic_proc_name, prefix, module_name, formals, line_num, is_user_proc + generic_proc_name, prefix, module_name, formals, line_num, iar_annotation ): """ Create code to stamp out and register a generic command using a generic @@ -376,8 +376,8 @@ def stamp_generic_command( stamp_formal_args = ", ".join([f"{k}={v}" for k, v in formals.items()]) - # use qualified naming if generic_proc belongs in a use defined module to avoid name conflicts - call = f"{module_name}.{generic_proc_name}" if is_user_proc else generic_proc_name + # use qualified naming if generic_proc belongs in a user defined module to avoid name conflicts + call = f"{module_name}.{generic_proc_name}" if iar_annotation else generic_proc_name proc = ( f"proc {stamp_name}(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): {RESPONSE_TYPE_NAME} throws do\n" @@ -590,7 +590,9 @@ def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries ) -def unpack_generic_symbol_arg(arg_name, symbol_class_name, symbol_count, symbol_param_class): +def unpack_generic_symbol_arg( + arg_name, symbol_class_name, symbol_count, symbol_param_class +): """ Generate the code to unpack a non-array symbol-table entry class (a class that inherits from 'AbstractSymEntry'). @@ -990,8 +992,152 @@ def gen_command_proc(name, return_type, formals, mod_name, config): return (command_proc, cmd_name, is_generic_command, command_formals) +# TODO: use the compiler's built-in support for where-clause evaluation and resolution +# instead of re-implementing it in a much less robust manner here +class WCNode: + def __init__(self, ast): + if isinstance(ast, chapel.OpCall): + if ast.is_binary_op(): + self.node = WCBinOP(ast) + else: + self.node = WCUnaryOP(ast) + elif isinstance(ast, chapel.FnCall): + # 'int(8)' for example should be treated as a literal type name, not a function call + call_name = ast.called_expression().name() + if call_name in chapel_scalar_types.keys(): + self.node = WCLiteral(call_name, list(ast.actuals())[0].text()) + else: + self.node = WCFunc(ast) + else: + self.node = WCLiteral(ast) + + def eval(self, args): + return self.node.eval(args) + + def __str__(self): + return self.node.__str__() + + def __repr__(self): + return self.node.__str__() + + +class WCBinOP(WCNode): + def __init__(self, ast): + self.op = ast.op() + actuals = list(ast.actuals()) + self.lhs = WCNode(actuals[0]) + self.rhs = WCNode(actuals[1]) + + def eval(self, args): + lhse = self.lhs.eval(args) + rhse = self.rhs.eval(args) + match self.op: + case "==": + return lhse == rhse + case "!=": + return lhse != rhse + case "<": + return int(lhse) < int(rhse) + case "<=": + return int(lhse) <= int(rhse) + case ">": + return int(lhse) > int(rhse) + case ">=": + return int(lhse) >= int(rhse) + case "&&": + return bool(lhse) and bool(rhse) + case "||": + return bool(lhse) or bool(rhse) + + def __str__(self): + return f"({self.lhs} {self.op} {self.rhs})" + + +class WCUnaryOP(WCNode): + def __init__(self, ast): + self.op = ast.op() + self.operand = WCNode(list(ast.actuals())[0]) + + def eval(self, args): + match self.op: + case "!": + return not bool(self.operand.eval(args)) + case "-": + return -int(self.operand.eval(args)) + + def __str__(self): + return f"{self.op}{self.operand}" + + +class WCFunc(WCNode): + def __init__(self, ast): + self.name = ast.called_expression().name() + self.actuals = [WCNode(a) for a in list(ast.actuals())] + + def eval(self, args): + # TODO: this is a really bad way to do this. the compiler should be leveraged much more heavily here + if self.name == "isIntegralType": + return self.actuals[0].eval(args) in [ + "int", + "int(8)", + "int(16)", + "int(32)", + "int(64)", + "uint", + "uint(8)", + "uint(16)", + "uint(32)", + "uint(64)", + ] + if self.name == "isRealType": + return self.actuals[0].eval(args) in ["real", "real(32)", "real(64)"] + if self.name == "isComplexType": + return self.actuals[0].eval(args) in [ + "complex", + "complex(64)", + "complex(128)", + ] + if self.name == "isImagType": + return self.actuals[0].eval(args) in ["imag", "imag(32)", "imag(64)"] + else: + error_message( + "evaluating where-clause", + f"general function calls not yet supported in where-clauses; ignoring function: {self.name}", + ) + return True + + def __str__(self): + return f"{self.name}({', '.join([str(a) for a in self.actuals])})" + + +class WCLiteral(WCNode): + def __init__(self, ast, width=None): + if width is not None: + self.value = f"{ast}({width})" + elif isinstance(ast, chapel.Identifier): + self.value = ast.name() + elif isinstance(ast, chapel.IntLiteral): + self.value = ast.text() + elif isinstance(ast, chapel.Dot): + self.value = ast.receiver().name() + "." + ast.field() + # 🥲 + if self.value == "BigInteger.bigint": + self.value = "bigint" + else: + raise ValueError("invalid where-clause literal") + + def eval(self, args): + if self.value in args: + return args[self.value] + else: + return self.value + + def __str__(self): + return self.value + + def stamp_out_command( - config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc + config, formals, name, cmd_prefix, mod_name, line_num, iar_annotation, wc ): """ Yield instantiations of a generic command with using the @@ -1007,15 +1153,28 @@ def stamp_out_command( * cmd_prefix: the prefix to use for the command names * mod_name: the name of the module containing the command procedure (or the user-defined procedure that the command calls) + * line_num: the line number of the annotated procedure + * iar_annotation: a boolean indicating whether the command procedure was annotated with 'instantiateAndRegister' + * wc: the where clause of the annotated procedure The name of the instantiated command will be in the format: 'cmd_prefix' where v1, v2, ... are the values of the generic formals """ formal_perms = generic_permutations(config, formals) + if iar_annotation and wc is not None: + wc_node = WCNode(wc) + print(name, wc_node) + else: + wc_node = None + for fp in formal_perms: + # skip instantiation for this permutation if the where clause evaluates to false + if wcn := wc_node: + if not wcn.eval(fp): + continue stamp = stamp_generic_command( - name, cmd_prefix, mod_name, fp, line_num, is_user_proc + name, cmd_prefix, mod_name, fp, line_num, iar_annotation ) yield stamp @@ -1104,6 +1263,7 @@ def register_commands(config, source_files): mod_name, line_num, False, + fn.where_clause(), ): file_stamps.append(stamp) except ValueError as e: @@ -1143,7 +1303,14 @@ def register_commands(config, source_files): try: for stamp in stamp_out_command( - config, gen_formals, name, command_prefix, mod_name, line_num, True + config, + gen_formals, + name, + command_prefix, + mod_name, + line_num, + True, + fn.where_clause(), ): file_stamps.append(stamp) count += 1 From b5758640c1aa5446f61d48282c4a61161243dfda Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Tue, 15 Oct 2024 11:04:01 -0600 Subject: [PATCH 02/14] remove uninstantiated overloads from commands annotated with instantiateAndRegister Signed-off-by: Jeremiah Corrado --- src/ArgSortMsg.chpl | 8 +----- src/CastMsg.chpl | 20 ------------- src/GenSymIO.chpl | 12 -------- src/IndexingMsg.chpl | 14 --------- src/LinalgMsg.chpl | 48 ++----------------------------- src/RandMsg.chpl | 41 +++++++++----------------- src/ReductionMsg.chpl | 10 ------- src/SetMsg.chpl | 6 ---- src/UtilMsg.chpl | 5 ---- src/registry/register_commands.py | 1 - 10 files changed, 17 insertions(+), 148 deletions(-) diff --git a/src/ArgSortMsg.chpl b/src/ArgSortMsg.chpl index 8e636255f6..9ec94f2b25 100644 --- a/src/ArgSortMsg.chpl +++ b/src/ArgSortMsg.chpl @@ -434,17 +434,11 @@ module ArgSortMsg axis = msgArgs["axis"].toScalar(int), symEntry = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd), vals = if (array_dtype == bool) then (symEntry.a:int) else (symEntry.a: array_dtype); - + const iv = argsortDefault(vals, algorithm=algorithm, axis); return st.insert(new shared SymEntry(iv)); } - proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8)) - { - return MsgTuple.error("argsort does not support the %s dtype".format(array_dtype:string)); - } - proc argsortStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws { const name = msgArgs["name"].toScalar(string), strings = getSegString(name, st), diff --git a/src/CastMsg.chpl b/src/CastMsg.chpl index fda14e61c7..a82e1aeee0 100644 --- a/src/CastMsg.chpl +++ b/src/CastMsg.chpl @@ -15,10 +15,6 @@ module CastMsg { private config const logChannel = ServerConfig.logChannel; const castLogger = new Logger(logLevel, logChannel); - proc isFloatingType(type t) param : bool { - return isRealType(t) || isImagType(t) || isComplexType(t); - } - @arkouda.instantiateAndRegister(prefix="cast") proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_from, @@ -41,22 +37,6 @@ module CastMsg { } } - // cannot cast float types to bigint, cannot cast bigint to bool - proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, - type array_dtype_from, - type array_dtype_to, - param array_nd: int - ): MsgTuple throws - where (isFloatingType(array_dtype_from) && array_dtype_to == bigint) || - (array_dtype_from == bigint && array_dtype_to == bool) - { - return MsgTuple.error( - "cannot cast array of type %s to %s".format( - type2str(array_dtype_from), - type2str(array_dtype_to) - )); - } - @arkouda.instantiateAndRegister(prefix="castToStrings") proc castArrayToStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws { const name = msgArgs["name"].toScalar(string); diff --git a/src/GenSymIO.chpl b/src/GenSymIO.chpl index 75a83bc1d7..9efa25b31c 100644 --- a/src/GenSymIO.chpl +++ b/src/GenSymIO.chpl @@ -43,12 +43,6 @@ module GenSymIO { return st.insert(new shared SymEntry(makeArrayFromBytes(msgArgs.payload, shape, array_dtype))); } - proc array(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_dtype == bigint - { - return MsgTuple.error("Array creation from binary payload is not supported for bigint arrays"); - } - proc makeArrayFromBytes(ref payload: bytes, shape: ?N*int, type t): [] t throws { var size = 1; for s in shape do size *= s; @@ -138,12 +132,6 @@ module GenSymIO { return MsgTuple.payload(bytes.createAdoptingBuffer(ptr:c_ptr(uint(8)), size, size)); } - proc tondarray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_dtype == bigint - { - return MsgTuple.error("cannot create ndarray from bigint array"); - } - /* * Utility proc to test casting a string to a specified type * :arg c: String to cast diff --git a/src/IndexingMsg.chpl b/src/IndexingMsg.chpl index 96a65a9256..be7c646dc0 100644 --- a/src/IndexingMsg.chpl +++ b/src/IndexingMsg.chpl @@ -206,12 +206,6 @@ module IndexingMsg } } - proc multiPDArrayIndex(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_a, type array_dtype_idx, param array_nd: int): MsgTuple throws - where array_dtype_idx != int && array_dtype_idx != uint - { - return MsgTuple.error("Invalid index type: %s; must be 'int' or 'uint'".format(type2str(array_dtype_idx))); - } - private proc multiIndexShape(inShape: ?N*int, idxDims: [?d] int, outSize: int): (bool, int, N*int) { var minShape: N*int = inShape, firstRank = -1; @@ -960,14 +954,6 @@ module IndexingMsg return st.insert(new shared SymEntry(y, x.max_bits)); } - proc takeAlongAxis(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, - type array_dtype_x, - type array_dtype_idx, - param array_nd: int - ): MsgTuple throws { - return MsgTuple.error("Cannot take along axis with non-integer index array"); - } - use CommandMap; registerFunction("arrayViewMixedIndex", arrayViewMixedIndexMsg, getModuleName()); registerFunction("[pdarray]", pdarrayIndexMsg, getModuleName()); diff --git a/src/LinalgMsg.chpl b/src/LinalgMsg.chpl index 855d35cf77..89d22169c4 100644 --- a/src/LinalgMsg.chpl +++ b/src/LinalgMsg.chpl @@ -61,13 +61,6 @@ module LinalgMsg { return st.insert(e); } - - proc eye(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("eye does not support the bigint dtype"); - } - // tril and triu are identical except for the argument they pass to triluHandler (true for upper, false for lower) // The zeros are written into the upper (or lower) triangle of the array, offset by the value of diag. @@ -79,11 +72,6 @@ module LinalgMsg { return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, false); } - proc tril(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_nd < 2 { - return MsgTuple.error("Array must be at least 2 dimensional for 'tril'"); - } - // Create an array from an existing array with its lower triangle zeroed out @arkouda.instantiateAndRegister @@ -92,13 +80,9 @@ module LinalgMsg { return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, true); } - proc triu(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_nd < 2 { - return MsgTuple.error("Array must be at least 2 dimensional for 'triu'"); - } - - // Fetch the arguments, call zeroTri, return result. - + // Fetch the arguments, call zeroTri, return result. + // TODO: support instantiating param bools with 'true' and 'false' s.t. we'd have 'triluHandler' and 'triluHandler' + // cmds if this procedure were annotated instead of the two above. proc triluHandler(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int, param upper: bool ): MsgTuple throws { @@ -195,16 +179,6 @@ module LinalgMsg { } - proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_nd < 2) && (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) { - return MsgTuple.error("Matrix multiplication with arrays of dimension < 2 is not supported"); - } - - proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) { - return MsgTuple.error("Matrix multiplication with arrays of bigint type is not supported"); - } - proc compute_result_type_matmul(type t1, type t2) type { if t1 == real || t2 == real then return real; if t1 == int || t2 == int then return int; @@ -366,22 +340,6 @@ module LinalgMsg { return bool; } - proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_nd < 2) && ((array_dtype_x1 != bool) || (array_dtype_x2 != bool)) - && (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) { - return MsgTuple.error("VecDot with arrays of dimension < 2 is not supported"); - } - - proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_dtype_x1 == bool) && (array_dtype_x2 == bool) { - return MsgTuple.error("VecDot with arrays both of type bool is not supported"); - } - - proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) { - return MsgTuple.error("VecDot with arrays of type bigint is not supported"); - } - // @arkouda.registerND(???) // proc tensorDotMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd1: int, param nd2: int): MsgTuple throws { // if nd < 3 { diff --git a/src/RandMsg.chpl b/src/RandMsg.chpl index 3d2e9b198b..69fafb30d1 100644 --- a/src/RandMsg.chpl +++ b/src/RandMsg.chpl @@ -75,12 +75,6 @@ module RandMsg return st.insert(e); } - proc randint(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("randint does not support the bigint dtype"); - } - @arkouda.instantiateAndRegister proc randomNormal(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd: int): MsgTuple throws { const shape = msgArgs["shape"].toScalarTuple(int, array_nd), @@ -117,12 +111,6 @@ module RandMsg return st.insert(new shared GeneratorSymEntry(generator, state)); } - proc createGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("createGenerator does not support the bigint dtype"); - } - @arkouda.instantiateAndRegister proc uniformGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws where array_dtype != BigInteger.bigint @@ -151,12 +139,6 @@ module RandMsg return st.insert(uniformEntry); } - proc uniformGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("uniformGenerator does not support the bigint dtype"); - } - /* Use the ziggurat method (https://en.wikipedia.org/wiki/Ziggurat_algorithm#Theory_of_operation) @@ -252,6 +234,9 @@ module RandMsg @arkouda.instantiateAndRegister proc standardNormalGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws + do return standardNormalGeneratorHelp(cmd, msgArgs, st, array_nd); + + proc standardNormalGeneratorHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws where array_nd == 1 { const name = msgArgs["name"], // generator name @@ -287,7 +272,7 @@ module RandMsg } - proc standardNormalGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws + proc standardNormalGeneratorHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws where array_nd > 1 { const name = msgArgs["name"], // generator name @@ -387,6 +372,9 @@ module RandMsg @arkouda.instantiateAndRegister proc standardExponential(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws + do return standardExponentialHelp(cmd, msgArgs, st, array_nd); + + proc standardExponentialHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws where array_nd == 1 { const name = msgArgs["name"], // generator name @@ -421,7 +409,7 @@ module RandMsg } } - proc standardExponential(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws + proc standardExponentialHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws where array_nd > 1 { const name = msgArgs["name"], // generator name @@ -567,12 +555,6 @@ module RandMsg } } - proc choice(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("choice does not support the bigint dtype"); - } - inline proc logisticGenerator(mu: real, scale: real, ref rs) { var U = rs.next(0, 1); @@ -693,7 +675,10 @@ module RandMsg } @arkouda.instantiateAndRegister - proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + do return shuffleHelp(cmd, msgArgs, st, array_dtype, array_nd); + + proc shuffleHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws where array_nd == 1 { const name = msgArgs["name"], @@ -715,7 +700,7 @@ module RandMsg return MsgTuple.success(); } - proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + proc shuffleHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws where array_nd != 1 { const name = msgArgs["name"], diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index dd69bf526e..febaed4866 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -380,16 +380,6 @@ module ReductionMsg // simple and efficient 'nonzero' implementation for 1D arrays - proc nonzero( - cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, - type array_dtype, - param array_nd: int - ): MsgTuple throws - where array_dtype == bigint - { - return MsgTuple.error("nonzero is not supported for bigint arrays"); - } - proc nonzero1D(x: [?d] ?t): [] int throws { const nTasksPerLoc = here.maxTaskPar; var nnzPerTask: [0.. Date: Tue, 15 Oct 2024 11:07:23 -0600 Subject: [PATCH 03/14] remove match statements from register_commands.py Signed-off-by: Jeremiah Corrado --- src/registry/register_commands.py | 55 ++++++++++++++++++------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index b8d2a3ccb0..7d78b78737 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -1031,23 +1031,29 @@ def __init__(self, ast): def eval(self, args): lhse = self.lhs.eval(args) rhse = self.rhs.eval(args) - match self.op: - case "==": - return lhse == rhse - case "!=": - return lhse != rhse - case "<": - return int(lhse) < int(rhse) - case "<=": - return int(lhse) <= int(rhse) - case ">": - return int(lhse) > int(rhse) - case ">=": - return int(lhse) >= int(rhse) - case "&&": - return bool(lhse) and bool(rhse) - case "||": - return bool(lhse) or bool(rhse) + + if self.op == "==": + return lhse == rhse + elif self.op == "!=": + return lhse != rhse + elif self.op == "<": + return int(lhse) < int(rhse) + elif self.op == "<=": + return int(lhse) <= int(rhse) + elif self.op == ">": + return int(lhse) > int(rhse) + elif self.op == ">=": + return int(lhse) >= int(rhse) + elif self.op == "&&": + return bool(lhse) and bool(rhse) + elif self.op == "||": + return bool(lhse) or bool(rhse) + else: + error_message( + "evaluating where-clause", + f"binary operator '{self.op}' not yet supported in where-clauses", + ) + return True def __str__(self): return f"({self.lhs} {self.op} {self.rhs})" @@ -1059,11 +1065,16 @@ def __init__(self, ast): self.operand = WCNode(list(ast.actuals())[0]) def eval(self, args): - match self.op: - case "!": - return not bool(self.operand.eval(args)) - case "-": - return -int(self.operand.eval(args)) + if self.op == "!": + return not bool(self.operand.eval(args)) + elif self.op == "-": + return -int(self.operand.eval(args)) + else: + error_message( + "evaluating where-clause", + f"unary operator '{self.op}' not yet supported in where-clauses", + ) + return True def __str__(self): return f"{self.op}{self.operand}" From f4a9d4b1532287a2cd34c96cf19e633b43c463db Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Tue, 15 Oct 2024 13:40:01 -0600 Subject: [PATCH 04/14] add error message for 'pad' w/ bigint Signed-off-by: Jeremiah Corrado --- arkouda/array_api/utility_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arkouda/array_api/utility_functions.py b/arkouda/array_api/utility_functions.py index 05eaf907d9..be24fde1fb 100644 --- a/arkouda/array_api/utility_functions.py +++ b/arkouda/array_api/utility_functions.py @@ -146,6 +146,9 @@ def pad( if mode != "constant": raise NotImplementedError(f"pad mode '{mode}' is not supported") + if array.dtype == ak.bigint: + raise RuntimeError("Error executing command: pad does not support dtype bigint") + if "constant_values" not in kwargs: cvals = 0 else: From 16395f1af302e79419774fae527119441173e3c0 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Tue, 15 Oct 2024 15:01:32 -0600 Subject: [PATCH 05/14] add preliminary support for where-clause evaluation in 'registerCommand' Signed-off-by: Jeremiah Corrado --- src/AryUtil.chpl | 12 +--- src/LinalgMsg.chpl | 5 -- src/MsgProcessing.chpl | 10 --- src/SortMsg.chpl | 28 ++++---- src/SparseMatrixMsg.chpl | 2 +- src/StatsMsg.chpl | 18 ------ src/UtilMsg.chpl | 10 --- src/registry/register_commands.py | 103 +++++++++++++++++++++--------- 8 files changed, 89 insertions(+), 99 deletions(-) diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index 71fa35100a..a4cec46bc5 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -945,9 +945,9 @@ module AryUtil flatten a multi-dimensional array into a 1D array */ @arkouda.registerCommand - proc flatten(const ref a: [?d] ?t): [] t throws - where a.rank > 1 - { + proc flatten(const ref a: [?d] ?t): [] t throws { + if a.rank == 1 then return a; + var flat = makeDistArray(d.size, t); // ranges of flat indices owned by each locale @@ -1004,12 +1004,6 @@ module AryUtil return flat; } - proc flatten(const ref a: [?d] ?t): [] t throws - where a.rank == 1 - { - return a; - } - // helper for computing an array element's index from its order record orderer { param rank: int; diff --git a/src/LinalgMsg.chpl b/src/LinalgMsg.chpl index 89d22169c4..a4b6d4e303 100644 --- a/src/LinalgMsg.chpl +++ b/src/LinalgMsg.chpl @@ -276,11 +276,6 @@ module LinalgMsg { return ret; } - proc transpose(array: [?d] ?t): [d] t throws - where d.rank < 2 { - throw new Error("Matrix transpose with arrays of dimension < 2 is not supported"); - } - /* Compute the generalized dot product of two tensors along the specified axis. diff --git a/src/MsgProcessing.chpl b/src/MsgProcessing.chpl index df95f5d7b0..6917822881 100644 --- a/src/MsgProcessing.chpl +++ b/src/MsgProcessing.chpl @@ -339,11 +339,6 @@ module MsgProcessing return msg; } - proc chunkInfoAsString(array: [?d] ?t): string throws - where (t != bool) && (t != int(64)) && (t != uint(64)) && (t != uint(8)) && (t != real){ - throw new Error("chunkInfo does not support dtype %s".format(t:string)); - } - @arkouda.registerCommand proc chunkInfoAsArray(array: [?d] ?t):[] int throws where (t == bool) || (t == int(64)) || (t == uint(64)) || (t == uint(8)) ||(t == real) { @@ -357,9 +352,4 @@ module MsgProcessing } return blockSizes; } - - proc chunkInfoAsArray(array: [?d] ?t): [d] int throws - where (t != bool) && (t != int(64)) && (t != uint(64)) && (t != uint(8)) && (t != real){ - throw new Error("chunkInfo does not support dtype %s".format(t:string)); - } } diff --git a/src/SortMsg.chpl b/src/SortMsg.chpl index 5b26a03596..a4516e63c5 100644 --- a/src/SortMsg.chpl +++ b/src/SortMsg.chpl @@ -31,9 +31,13 @@ module SortMsg /* sort takes pdarray and returns a sorted copy of the array */ @arkouda.registerCommand - proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws - where ((t == real) || (t == int) || (t == uint(64))) && (d.rank == 1) { + proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws + where ((t == real) || (t == int) || (t == uint(64))) + do return sortHelp(array, alg, axis); + proc sortHelp(array: [?d] ?t, alg: string, axis: int): [d] t throws + where d.rank == 1 + { var algorithm: SortingAlgorithm = ArgSortMsg.getSortingAlgorithm(alg); const itemsize = dtypeSize(whichDtype(t)); overMemLimit(radixSortLSD_keys_memEst(d.size, itemsize)); @@ -48,9 +52,9 @@ module SortMsg } } - proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws - where ((t == real) || (t==int) || (t==uint(64))) && (d.rank > 1) { - + proc sortHelp(array: [?d] ?t, alg: string, axis: int): [d] t throws + where d.rank > 1 + { var algorithm: SortingAlgorithm = ArgSortMsg.getSortingAlgorithm(alg); const itemsize = dtypeSize(whichDtype(t)); overMemLimit(radixSortLSD_keys_memEst(d.size, itemsize)); @@ -91,16 +95,11 @@ module SortMsg return sorted; } - proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws - where ((t != real) && (t!=int) && (t!=uint(64))) { - throw new Error("sort does not support type %s".format(type2str(t))); - } - // https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#array_api.searchsorted @arkouda.registerCommand proc searchSorted(x1: [?d1] real, x2: [?d2] real, side: string): [d2] int throws - where (d1.rank == 1) { - + where d1.rank == 1 + { if side != "left" && side != "right" { throw new Error("searchSorted side must be a string with value 'left' or 'right'."); } @@ -123,11 +122,6 @@ module SortMsg return ret; } - proc searchSorted(x1: [?d1] real, x2: [?d2] real, side: string): [d2] int throws - where (d1.rank != 1){ - throw new Error("searchSorted only arrays x1 of dimension 1."); - } - record leftCmp: relativeComparator { proc compare(a: real, b: real): int { if a < b then return -1; diff --git a/src/SparseMatrixMsg.chpl b/src/SparseMatrixMsg.chpl index dff40372a8..7ec09af880 100644 --- a/src/SparseMatrixMsg.chpl +++ b/src/SparseMatrixMsg.chpl @@ -63,7 +63,7 @@ module SparseMatrixMsg { return MsgTuple.fromResponses(responses); } - @arkouda.registerCommand("fill_sparse_vals") + @arkouda.registerCommand("fill_sparse_vals", ignoreWhereClause=true) proc fillSparseMatrixMsg(matrix: borrowed SparseSymEntry(?), vals: [?d] ?t /* matrix.etype */) throws where t == matrix.etype && d.rank == 1 do fillSparseMatrix(matrix.a, vals, matrix.matLayout); diff --git a/src/StatsMsg.chpl b/src/StatsMsg.chpl index 05a9bda2b9..1e67e55fca 100644 --- a/src/StatsMsg.chpl +++ b/src/StatsMsg.chpl @@ -85,15 +85,6 @@ module StatsMsg { return (+ reduce ((x:real - mx) * (y:real - my))) / (dx.size - 1):real; } - // above registration will instantiate `cov` for all combinations of array ranks - // even though it is only valid when the ranks are the same - // (respecting the where clause in the signature is future work for 'registerCommand') - proc cov(const ref x: [?dx], const ref y: [?dy]): real throws - where dx.rank != dy.rank - { - throw new Error("x and y must have the same rank"); - } - @arkouda.registerCommand() proc corr(const ref x: [?dx] ?tx, const ref y: [?dy] ?ty): real throws where dx.rank == dy.rank @@ -107,15 +98,6 @@ module StatsMsg { return cov(x, y) / (std(x, 1) * std(y, 1)); } - // above registration will instantiate `corr` for all combinations of array ranks - // even though it is only valid when the ranks are the same - // (respecting the where clause in the signature is future work for 'registerCommand') - proc corr(const ref x: [?dx], const ref y: [?dy]): real throws - where dx.rank != dy.rank - { - throw new Error("x and y must have the same rank"); - } - @arkouda.registerCommand() proc cumSum(const ref x: [?d] ?t, axis: int, includeInitial: bool): [] t throws { if d.rank == 1 { diff --git a/src/UtilMsg.chpl b/src/UtilMsg.chpl index 2f5fb1e054..20ded2c7f9 100644 --- a/src/UtilMsg.chpl +++ b/src/UtilMsg.chpl @@ -43,11 +43,6 @@ module UtilMsg { return y; } - proc clip(const ref x: [?d] ?t, min: real, max: real): [d] t throws - where (t != int) && (t != real) && (t != uint(8)) && (t != uint(64)){ - throw new Error("clip does not support dtype %s".format(t:string)); - } - /* Compute the n'th order discrete difference along a given axis @@ -95,11 +90,6 @@ module UtilMsg { } } - proc diff(x: [?d] ?t, n: int, axis: int): [d] t throws - where (t != real) && (t != int) && (t != uint(8)) && (t != uint(64)){ - throw new Error("diff does not support dtype %s".format(t:string)); - } - // helper to create a domain that's 'n' elements smaller in the 'axis' dimension private proc subDomain(shape: ?N*int, axis: int, n: int) { var rngs: N*range; diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index 7d78b78737..768b2ee6f5 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -7,7 +7,7 @@ DEFAULT_MODS = ["MsgProcessing", "GenSymIO"] -registerAttr = ("arkouda.registerCommand", ["name"]) +registerAttr = ("arkouda.registerCommand", ["name", "ignoreWhereClause"]) instAndRegisterAttr = ("arkouda.instantiateAndRegister", ["prefix"]) # chapel types and their numpy equivalents @@ -105,6 +105,9 @@ def __init__(self, name): def name(self): return self.name + def __str__(self): + return f"?{self.name}" + class FormalQueryRef: def __init__(self, name): @@ -113,6 +116,9 @@ def __init__(self, name): def name(self): return self.name + def __str__(self): + return f"QRef: '{self.name}'" + class StaticTypeInfo: def __init__(self, value): @@ -121,6 +127,9 @@ def __init__(self, value): def value(self): return self.value + def __str__(self): + return f"static: '{self.value}'" + class formalKind(Enum): ARRAY = 1 @@ -170,6 +179,9 @@ def stringify(self) -> str: else f"{self.storage_kind} {self.name}" ) + def __str__(self): + return f"{self.kind} [{self.storage_kind} {self.name}: {self.type_str}] (info: {self.info})" + def get_formals(fn, require_type_annotations): """ @@ -741,7 +753,10 @@ def gen_arg_unpacking(formals, config): """ Generate argument unpacking code for a message handler procedure - Returns the chapel code to unpack the arguments, and a list of generic arguments + Returns a tuple containing: + * the chapel code to unpack the arguments + * a list of generic arguments + * a map of array domain/type queries to their corresponding generic arguments """ unpack_lines = [] generic_args = [] @@ -852,7 +867,11 @@ def gen_arg_unpacking(formals, config): generic_args += scalar_args scalar_arg_counter += 1 - return ("\n".join(unpack_lines), generic_args) + return ( + "\n".join(unpack_lines), + generic_args, + {**array_domain_queries, **array_dtype_queries}, + ) def gen_user_function_call(name, arg_names, mod_name, user_rt): @@ -926,8 +945,8 @@ def gen_command_proc(name, return_type, formals, mod_name, config): * the chapel code for the command procedure * the name of the command procedure * a boolean indicating whether the command has generic (param/type) formals - * a list of tuples in the format (name, storage kind, type expression) - representing the generic formals of the command procedure + * a list of FormalTypeSpec representing the command procedure's generic formals + * a table of domain/type queries used in array formals mapped to their respective generic arguments proc (cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, ): MsgTuple throws { () @@ -938,7 +957,7 @@ def gen_command_proc(name, return_type, formals, mod_name, config): """ - arg_unpack, command_formals = gen_arg_unpacking(formals, config) + arg_unpack, command_formals, query_table = gen_arg_unpacking(formals, config) is_generic_command = len(command_formals) > 0 signature, cmd_name = gen_signature(name, command_formals) fn_call, result_name = gen_user_function_call( @@ -989,7 +1008,7 @@ def gen_command_proc(name, return_type, formals, mod_name, config): [signature, arg_unpack, fn_call, symbol_creation, response, "}"] ) - return (command_proc, cmd_name, is_generic_command, command_formals) + return (command_proc, cmd_name, is_generic_command, command_formals, query_table) # TODO: use the compiler's built-in support for where-clause evaluation and resolution @@ -1011,8 +1030,8 @@ def __init__(self, ast): else: self.node = WCLiteral(ast) - def eval(self, args): - return self.node.eval(args) + def eval(self, args, translation_table=None): + return self.node.eval(args, translation_table) def __str__(self): return self.node.__str__() @@ -1028,14 +1047,14 @@ def __init__(self, ast): self.lhs = WCNode(actuals[0]) self.rhs = WCNode(actuals[1]) - def eval(self, args): - lhse = self.lhs.eval(args) - rhse = self.rhs.eval(args) + def eval(self, args, translation_table=None): + lhse = self.lhs.eval(args, translation_table) + rhse = self.rhs.eval(args, translation_table) if self.op == "==": - return lhse == rhse + return str(lhse) == str(rhse) elif self.op == "!=": - return lhse != rhse + return str(lhse) != str(rhse) elif self.op == "<": return int(lhse) < int(rhse) elif self.op == "<=": @@ -1064,11 +1083,11 @@ def __init__(self, ast): self.op = ast.op() self.operand = WCNode(list(ast.actuals())[0]) - def eval(self, args): + def eval(self, args, translation_table=None): if self.op == "!": - return not bool(self.operand.eval(args)) + return not bool(self.operand.eval(args, translation_table)) elif self.op == "-": - return -int(self.operand.eval(args)) + return -int(self.operand.eval(args, translation_table)) else: error_message( "evaluating where-clause", @@ -1085,10 +1104,10 @@ def __init__(self, ast): self.name = ast.called_expression().name() self.actuals = [WCNode(a) for a in list(ast.actuals())] - def eval(self, args): + def eval(self, args, translation_table=None): # TODO: this is a really bad way to do this. the compiler should be leveraged much more heavily here if self.name == "isIntegralType": - return self.actuals[0].eval(args) in [ + return self.actuals[0].eval(args, translation_table) in [ "int", "int(8)", "int(16)", @@ -1101,15 +1120,23 @@ def eval(self, args): "uint(64)", ] if self.name == "isRealType": - return self.actuals[0].eval(args) in ["real", "real(32)", "real(64)"] + return self.actuals[0].eval(args, translation_table) in [ + "real", + "real(32)", + "real(64)", + ] if self.name == "isComplexType": - return self.actuals[0].eval(args) in [ + return self.actuals[0].eval(args, translation_table) in [ "complex", "complex(64)", "complex(128)", ] if self.name == "isImagType": - return self.actuals[0].eval(args) in ["imag", "imag(32)", "imag(64)"] + return self.actuals[0].eval(args, translation_table) in [ + "imag", + "imag(32)", + "imag(64)", + ] else: error_message( "evaluating where-clause", @@ -1134,12 +1161,16 @@ def __init__(self, ast, width=None): # 🥲 if self.value == "BigInteger.bigint": self.value = "bigint" + if self.value.endswith(".rank"): + self.value = self.value.split(".")[0] else: raise ValueError("invalid where-clause literal") - def eval(self, args): + def eval(self, args, translation_table=None): if self.value in args: return args[self.value] + elif translation_table is not None and self.value in translation_table: + return args[translation_table[self.value]] else: return self.value @@ -1148,7 +1179,15 @@ def __str__(self): def stamp_out_command( - config, formals, name, cmd_prefix, mod_name, line_num, iar_annotation, wc + config, + formals, + name, + cmd_prefix, + mod_name, + line_num, + iar_annotation, + wc, + query_table=None, ): """ Yield instantiations of a generic command with using the @@ -1167,13 +1206,14 @@ def stamp_out_command( * line_num: the line number of the annotated procedure * iar_annotation: a boolean indicating whether the command procedure was annotated with 'instantiateAndRegister' * wc: the where clause of the annotated procedure + * query_table: a dictionary mapping query names to their corresponding generic formal names The name of the instantiated command will be in the format: 'cmd_prefix' where v1, v2, ... are the values of the generic formals """ formal_perms = generic_permutations(config, formals) - if iar_annotation and wc is not None: + if wc is not None: wc_node = WCNode(wc) else: wc_node = None @@ -1181,7 +1221,7 @@ def stamp_out_command( for fp in formal_perms: # skip instantiation for this permutation if the where clause evaluates to false if wcn := wc_node: - if not wcn.eval(fp): + if not wcn.eval(fp, query_table): continue stamp = stamp_generic_command( name, cmd_prefix, mod_name, fp, line_num, iar_annotation @@ -1248,6 +1288,10 @@ def register_commands(config, source_files): else: command_prefix = name + ignore_where_clause = False + if iwc := attr_call["ignoreWhereClause"]: + ignore_where_clause = bool(iwc.value()) + if len(gen_formals) > 0: error_message( f"registering '{name}'", @@ -1256,8 +1300,8 @@ def register_commands(config, source_files): ) continue - (cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals) = gen_command_proc( - name, fn.return_type(), con_formals, mod_name, config + (cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals, query_table) = ( + gen_command_proc(name, fn.return_type(), con_formals, mod_name, config) ) file_stamps.append(cmd_proc) @@ -1273,7 +1317,8 @@ def register_commands(config, source_files): mod_name, line_num, False, - fn.where_clause(), + fn.where_clause() if not ignore_where_clause else None, + query_table, ): file_stamps.append(stamp) except ValueError as e: From 71cc4ba7e465c72d6946f534aaff9b7caef42ed8 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Wed, 16 Oct 2024 09:58:31 -0600 Subject: [PATCH 06/14] fix misinterpretation of explicit type widths (e.g., uint(64)) in where clauses Signed-off-by: Jeremiah Corrado --- arkouda/array_api/utility_functions.py | 6 +++ src/CastMsg.chpl | 3 +- src/registry/register_commands.py | 64 ++++++++++++++------------ 3 files changed, 42 insertions(+), 31 deletions(-) diff --git a/arkouda/array_api/utility_functions.py b/arkouda/array_api/utility_functions.py index be24fde1fb..0fcac23590 100644 --- a/arkouda/array_api/utility_functions.py +++ b/arkouda/array_api/utility_functions.py @@ -68,6 +68,9 @@ def clip(a: Array, a_min, a_max, /) -> Array: a_max : scalar The maximum value """ + if a.dtype == ak.bigint or a.dtype == ak.bool_: + raise RuntimeError(f"Error executing command: diff does not support dtype {a.dtype}") + return Array._new( create_pdarray( generic_msg( @@ -99,6 +102,9 @@ def diff(a: Array, /, n: int = 1, axis: int = -1, prepend=None, append=None) -> append : Array, optional Array to append to `a` along `axis` before calculating the difference. """ + if a.dtype == ak.bigint or a.dtype == ak.bool_: + raise RuntimeError(f"Error executing command: diff does not support dtype {a.dtype}") + if prepend is not None and append is not None: a_ = concat((prepend, a, append), axis=axis) elif prepend is not None: diff --git a/src/CastMsg.chpl b/src/CastMsg.chpl index a82e1aeee0..b1417b7df7 100644 --- a/src/CastMsg.chpl +++ b/src/CastMsg.chpl @@ -21,8 +21,7 @@ module CastMsg { type array_dtype_to, param array_nd: int ): MsgTuple throws - where !((isRealType(array_dtype_from) || isImagType(array_dtype_from) || isComplexType(array_dtype_from)) - && array_dtype_to == bigint) && + where !((isRealType(array_dtype_from) || isImagType(array_dtype_from) || isComplexType(array_dtype_from)) && array_dtype_to == bigint) && !(array_dtype_from == bigint && array_dtype_to == bool) { const a = st[msgArgs["name"]]: SymEntry(array_dtype_from, array_nd); diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index 768b2ee6f5..825f5f24a3 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -1106,36 +1106,34 @@ def __init__(self, ast): def eval(self, args, translation_table=None): # TODO: this is a really bad way to do this. the compiler should be leveraged much more heavily here + arg = self.actuals[0].eval(args, translation_table) if self.name == "isIntegralType": - return self.actuals[0].eval(args, translation_table) in [ - "int", - "int(8)", - "int(16)", - "int(32)", - "int(64)", - "uint", - "uint(8)", - "uint(16)", - "uint(32)", - "uint(64)", + return arg in [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", ] - if self.name == "isRealType": - return self.actuals[0].eval(args, translation_table) in [ - "real", - "real(32)", - "real(64)", + elif self.name == "isRealType": + return arg in [ + "float32", + "float64", ] - if self.name == "isComplexType": - return self.actuals[0].eval(args, translation_table) in [ + elif self.name == "isComplexType": + return arg in [ "complex", - "complex(64)", - "complex(128)", + "complex64", + "complex128", ] - if self.name == "isImagType": - return self.actuals[0].eval(args, translation_table) in [ + elif self.name == "isImagType": + return arg in [ "imag", - "imag(32)", - "imag(64)", + "imag32", + "imag64", ] else: error_message( @@ -1148,12 +1146,20 @@ def __str__(self): return f"{self.name}({', '.join([str(a) for a in self.actuals])})" +def canonicalize_type_name(name): + if name in chapel_scalar_types: + return chapel_scalar_types[name] + else: + return name + + class WCLiteral(WCNode): def __init__(self, ast, width=None): + # note: scalar type names are canonicalized to ensure 'int' == 'int(64)' (for example) if width is not None: - self.value = f"{ast}({width})" + self.value = canonicalize_type_name(f"{ast}({width})") elif isinstance(ast, chapel.Identifier): - self.value = ast.name() + self.value = canonicalize_type_name(ast.name()) elif isinstance(ast, chapel.IntLiteral): self.value = ast.text() elif isinstance(ast, chapel.Dot): @@ -1162,15 +1168,15 @@ def __init__(self, ast, width=None): if self.value == "BigInteger.bigint": self.value = "bigint" if self.value.endswith(".rank"): - self.value = self.value.split(".")[0] + self.value = self.value.split(".")[0] # ex: d1.rank -> d1 else: raise ValueError("invalid where-clause literal") def eval(self, args, translation_table=None): if self.value in args: - return args[self.value] + return canonicalize_type_name(args[self.value]) elif translation_table is not None and self.value in translation_table: - return args[translation_table[self.value]] + return canonicalize_type_name(args[translation_table[self.value]]) else: return self.value From 90f50b04eddbeb255a614efa71f60e1a5f7d17d5 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Wed, 16 Oct 2024 10:20:14 -0600 Subject: [PATCH 07/14] fix error message Signed-off-by: Jeremiah Corrado --- arkouda/array_api/utility_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arkouda/array_api/utility_functions.py b/arkouda/array_api/utility_functions.py index 0fcac23590..ea8c2b1e3b 100644 --- a/arkouda/array_api/utility_functions.py +++ b/arkouda/array_api/utility_functions.py @@ -69,7 +69,7 @@ def clip(a: Array, a_min, a_max, /) -> Array: The maximum value """ if a.dtype == ak.bigint or a.dtype == ak.bool_: - raise RuntimeError(f"Error executing command: diff does not support dtype {a.dtype}") + raise RuntimeError(f"Error executing command: diff clip not support dtype {a.dtype}") return Array._new( create_pdarray( From e2e9a7b823e452290aa885b3e873bb2293b98b04 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Wed, 16 Oct 2024 10:55:03 -0600 Subject: [PATCH 08/14] fix error message Signed-off-by: Jeremiah Corrado --- arkouda/array_api/utility_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arkouda/array_api/utility_functions.py b/arkouda/array_api/utility_functions.py index ea8c2b1e3b..ea79f4eb15 100644 --- a/arkouda/array_api/utility_functions.py +++ b/arkouda/array_api/utility_functions.py @@ -69,7 +69,7 @@ def clip(a: Array, a_min, a_max, /) -> Array: The maximum value """ if a.dtype == ak.bigint or a.dtype == ak.bool_: - raise RuntimeError(f"Error executing command: diff clip not support dtype {a.dtype}") + raise RuntimeError(f"Error executing command: clip does not support dtype {a.dtype}") return Array._new( create_pdarray( From 68ae907c845b8d124ab2216ee974d5306c3bc58f Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Thu, 17 Oct 2024 14:18:16 -0600 Subject: [PATCH 09/14] refactor 'sum' to not rely on where-clauses for dispatching Signed-off-by: Jeremiah Corrado --- arkouda/pdarrayclass.py | 7 +++---- src/AryUtil.chpl | 8 ++++++++ src/ReductionMsg.chpl | 32 +++++--------------------------- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 51bcccf24b..a69045b5b5 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -2757,12 +2757,11 @@ def sum( RuntimeError Raised if there's a server-side error thrown """ - axis_arry = _get_axis_pdarray(axis) repMsg = generic_msg( - cmd=f"sum<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>", - args={"x": pda, "axis": axis_arry, "skipNan": False}, + cmd=f"sum<{pda.dtype.name},{pda.ndim}>", + args={"x": pda, "axis": axis, "skipNan": False}, ) - if axis is None or len(axis_arry) == 0 or pda.ndim == 1: + if axis is None or len(axis) == 0 or pda.ndim == 1: return create_pdarray(cast(str, repMsg)).flatten()[0] else: return create_pdarray(cast(str, repMsg)) diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index a5b30bc99b..470d5bffd7 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -156,6 +156,10 @@ module AryUtil return (true, ret); } + proc validateNegativeAxes(axes: list(int), param nd: int): (bool, list(int)) { + return new list(validateNegativeAxes(axes.toArray(), nd)); + } + /* Get a domain that selects out the idx'th set of indices along the specified axes @@ -328,6 +332,10 @@ module AryUtil return ret; } + proc reducedShape(shape: ?N*int, axes: list(int)): N*int { + return reducedShape(shape, axes.toArray()); + } + /* Returns stats on a given array in form (int,int,real,real,real). diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index 17dadb142c..0de6ad69a0 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -7,6 +7,7 @@ module ReductionMsg use Reflection only; use CommAggregation; use BigInteger; + use List; use MultiTypeSymbolTable; use MultiTypeSymEntry; @@ -123,25 +124,13 @@ module ReductionMsg } } - - - @arkouda.registerCommand - proc sum(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] t throws - where (t==int || t==real || t==uint(64)) && (x.rank == 1) && (axis.rank == 1) { - return makeDistArray([(+ reduce x)]); - } - - proc sum(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] int throws - where (t==bool) && (x.rank == 1) && (axis.rank == 1) { - return makeDistArray([(+ reduce x:int)]); - } - - proc sum(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] throws - where (t==int || t==real || t==uint(64) || t==bool) && (x.rank != 1) && (axis.rank == 1) { + proc sum(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] throws + where t==int || t==real || t==uint(64) || t==bool + { use SliceReductionOps; - type opType = if t == bool then int else t; + if d.rank == 1 then return (+ reduce x:opType); const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { @@ -149,7 +138,6 @@ module ReductionMsg } else { const outShape = reducedShape(x.shape, axes); var ret = makeDistArray((...outShape), opType); - if (ret.size==1) { ret[ret.domain.low] = (+ reduce x:opType); }else{ @@ -160,16 +148,6 @@ module ReductionMsg } return ret; } - } - - proc sum(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [d2] t throws - where (t==int || t==real || t==uint(64) || t==bool) && (axis.rank != 1) { - throw new Error("sum only accepts axis of rank 1."); - } - - proc sum(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [d2] t throws - where (t!=int && t!=real && t!=uint(64) && t!=bool) { - throw new Error("sum does not support type %s".format(type2str(t))); } @arkouda.registerCommand From 0142daa71caaba83e9a86f41865da8e64360e58e Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Thu, 17 Oct 2024 15:41:55 -0600 Subject: [PATCH 10/14] refactor prod, max, and min to take their 'axis' argument as a list Signed-off-by: Jeremiah Corrado --- arkouda/pdarrayclass.py | 38 ++++++++------------ src/ReductionMsg.chpl | 79 ++++++----------------------------------- 2 files changed, 26 insertions(+), 91 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index a69045b5b5..81e97b6e35 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -2720,16 +2720,6 @@ def is_sorted(pda: pdarray) -> np.bool_: ) -def _get_axis_pdarray(axis: Optional[Union[int, Tuple[int, ...]]] = None): - from arkouda import array as ak_array - - axis_list = [] - if axis is not None: - axis_list = list(axis) if isinstance(axis, tuple) else [axis] - - return ak_array(axis_list, dtype="int64") - - @typechecked def sum( pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]] = None @@ -2757,11 +2747,13 @@ def sum( RuntimeError Raised if there's a server-side error thrown """ + axis_ = tuple(axis) if axis is not None else () repMsg = generic_msg( cmd=f"sum<{pda.dtype.name},{pda.ndim}>", args={"x": pda, "axis": axis, "skipNan": False}, ) - if axis is None or len(axis) == 0 or pda.ndim == 1: + if axis is None or len(axis_) == 0 or pda.ndim == 1: + # TODO: remove call to 'flatten' return create_pdarray(cast(str, repMsg)).flatten()[0] else: return create_pdarray(cast(str, repMsg)) @@ -2844,12 +2836,12 @@ def prod(pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Un RuntimeError Raised if there's a server-side error thrown """ - axis_arry = _get_axis_pdarray(axis) + axis_ = tuple(axis) if axis is not None else () repMsg = generic_msg( - cmd=f"prod<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>", - args={"x": pda, "axis": axis_arry, "skipNan": False}, + cmd=f"prod<{pda.dtype.name},{pda.ndim}>", + args={"x": pda, "axis": axis_, "skipNan": False}, ) - if axis is None or len(axis_arry) == 0 or pda.ndim == 1: + if axis is None or len(axis_) == 0 or pda.ndim == 1: return create_pdarray(cast(str, repMsg)).flatten()[0] else: return create_pdarray(cast(str, repMsg)) @@ -2881,12 +2873,12 @@ def min( RuntimeError Raised if there's a server-side error thrown """ - axis_arry = _get_axis_pdarray(axis) + axis_ = tuple(axis) if axis is not None else () repMsg = generic_msg( - cmd=f"min<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>", - args={"x": pda, "axis": axis_arry, "skipNan": False}, + cmd=f"min<{pda.dtype.name},{pda.ndim}>", + args={"x": pda, "axis": axis_, "skipNan": False}, ) - if axis is None or len(axis_arry) == 0 or pda.ndim == 1: + if axis is None or len(axis_) == 0 or pda.ndim == 1: return create_pdarray(cast(str, repMsg)).flatten()[0] else: return create_pdarray(cast(str, repMsg)) @@ -2919,12 +2911,12 @@ def max( RuntimeError Raised if there's a server-side error thrown """ - axis_arry = _get_axis_pdarray(axis) + axis_ = tuple(axis) if axis is not None else () repMsg = generic_msg( - cmd=f"max<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>", - args={"x": pda, "axis": axis_arry, "skipNan": False}, + cmd=f"max<{pda.dtype.name},{pda.ndim}>", + args={"x": pda, "axis": axis_, "skipNan": False}, ) - if axis is None or len(axis_arry) == 0 or pda.ndim == 1: + if axis is None or len(axis_) == 0 or pda.ndim == 1: return create_pdarray(cast(str, repMsg)).flatten()[0] else: return create_pdarray(cast(str, repMsg)) diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index 0de6ad69a0..dd7c94ab49 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -129,8 +129,9 @@ module ReductionMsg where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; + type opType = if t == bool then int else t; - if d.rank == 1 then return (+ reduce x:opType); + if d.rank == 1 then return makeDistArray([(+ reduce x:opType)]); const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { @@ -151,21 +152,12 @@ module ReductionMsg } @arkouda.registerCommand - proc prod(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] t throws - where (t==int || t==real || t==uint(64)) && (x.rank == 1) && (axis.rank == 1) { - return makeDistArray([(* reduce x)]); - } - - proc prod(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] int throws - where (t==bool) && (x.rank == 1) && (axis.rank == 1) { - return makeDistArray([(* reduce x:int)]); - } - - proc prod(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] throws - where (t==int || t==real || t==uint(64) || t==bool) && (x.rank != 1) && (axis.rank == 1) { + proc prod(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] throws + where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; type opType = if t == bool then int else t; + if d.rank == 1 then return makeDistArray([(* reduce x:opType)]); const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { @@ -183,35 +175,15 @@ module ReductionMsg } return ret; } - } - - proc prod(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [d2] t throws - where (t==int || t==real || t==uint(64) || t==bool) && (axis.rank != 1) { - throw new Error("prod only accepts axis of rank 1."); } - proc prod(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [d2] t throws - where (t!=int && t!=real && t!=uint(64) && t!=bool) { - throw new Error("prod does not support type %s".format(type2str(t))); - } - - @arkouda.registerCommand - proc max(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] t throws - where (t==int || t==real || t==uint(64)) && (x.rank == 1) && (axis.rank == 1) { - return makeDistArray([(max reduce x)]); - } - - proc max(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] int throws - where (t==bool) && (x.rank == 1) && (axis.rank == 1) { - return makeDistArray([(max reduce x:int)]); - } - - proc max(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] throws - where (t==int || t==real || t==uint(64) || t==bool) && (x.rank != 1) && (axis.rank == 1) { + proc max(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] throws + where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; type opType = if t == bool then int else t; + if d.rank == 1 then return makeDistArray([(max reduce x:opType)]); const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { @@ -229,34 +201,15 @@ module ReductionMsg } return ret; } - } - - proc max(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [d2] t throws - where (t==int || t==real || t==uint(64) || t==bool) && (axis.rank != 1) { - throw new Error("max only accepts axis of rank 1."); - } - - proc max(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [d2] t throws - where (t!=int && t!=real && t!=uint(64) && t!=bool) { - throw new Error("max does not support type %s".format(type2str(t))); } @arkouda.registerCommand - proc min(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] t throws - where (t==int || t==real || t==uint(64)) && (x.rank == 1) && (axis.rank == 1) { - return makeDistArray([(min reduce x)]); - } - - proc min(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] int throws - where (t==bool) && (x.rank == 1) && (axis.rank == 1) { - return makeDistArray([(min reduce x:int)]); - } - - proc min(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] throws - where (t==int || t==real || t==uint(64) || t==bool) && (x.rank != 1) && (axis.rank == 1) { + proc min(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] throws + where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; type opType = if t == bool then int else t; + if d.rank == 1 then return makeDistArray([(min reduce x:opType)]); const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { @@ -274,16 +227,6 @@ module ReductionMsg } return ret; } - } - - proc min(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [d2] t throws - where (t==int || t==real || t==uint(64) || t==bool) && (axis.rank != 1) { - throw new Error("min only accepts axis of rank 1."); - } - - proc min(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [d2] t throws - where (t!=int && t!=real && t!=uint(64) && t!=bool) { - throw new Error("min does not support type %s".format(type2str(t))); } /* From 7c6bb71f668025f1916abaa8075faf1cc4a6acae Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Thu, 17 Oct 2024 15:57:19 -0600 Subject: [PATCH 11/14] fix mypy errors and multi-dim build failure Signed-off-by: Jeremiah Corrado --- arkouda/pdarrayclass.py | 8 ++++---- src/AryUtil.chpl | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 81e97b6e35..4a20f740d3 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -2747,7 +2747,7 @@ def sum( RuntimeError Raised if there's a server-side error thrown """ - axis_ = tuple(axis) if axis is not None else () + axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis)) repMsg = generic_msg( cmd=f"sum<{pda.dtype.name},{pda.ndim}>", args={"x": pda, "axis": axis, "skipNan": False}, @@ -2836,7 +2836,7 @@ def prod(pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Un RuntimeError Raised if there's a server-side error thrown """ - axis_ = tuple(axis) if axis is not None else () + axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis)) repMsg = generic_msg( cmd=f"prod<{pda.dtype.name},{pda.ndim}>", args={"x": pda, "axis": axis_, "skipNan": False}, @@ -2873,7 +2873,7 @@ def min( RuntimeError Raised if there's a server-side error thrown """ - axis_ = tuple(axis) if axis is not None else () + axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis)) repMsg = generic_msg( cmd=f"min<{pda.dtype.name},{pda.ndim}>", args={"x": pda, "axis": axis_, "skipNan": False}, @@ -2911,7 +2911,7 @@ def max( RuntimeError Raised if there's a server-side error thrown """ - axis_ = tuple(axis) if axis is not None else () + axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis)) repMsg = generic_msg( cmd=f"max<{pda.dtype.name},{pda.ndim}>", args={"x": pda, "axis": axis_, "skipNan": False}, diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index 470d5bffd7..ce6b5d968f 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -157,7 +157,8 @@ module AryUtil } proc validateNegativeAxes(axes: list(int), param nd: int): (bool, list(int)) { - return new list(validateNegativeAxes(axes.toArray(), nd)); + const (valid, ret) = validateNegativeAxes(axes.toArray(), nd); + return (valid, new list(ret)); } /* From 322506b8c8cf8eb6d96f09a1cdd7f7fe90cea163 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Thu, 17 Oct 2024 16:09:28 -0600 Subject: [PATCH 12/14] create list-specific implementations of aryUtil helper procedures used in ReductionMsg. Fix 'axis' argument in 'ak.sum' Signed-off-by: Jeremiah Corrado --- arkouda/pdarrayclass.py | 2 +- src/AryUtil.chpl | 21 ++++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 4a20f740d3..18b9d16cf4 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -2750,7 +2750,7 @@ def sum( axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis)) repMsg = generic_msg( cmd=f"sum<{pda.dtype.name},{pda.ndim}>", - args={"x": pda, "axis": axis, "skipNan": False}, + args={"x": pda, "axis": axis_, "skipNan": False}, ) if axis is None or len(axis_) == 0 or pda.ndim == 1: # TODO: remove call to 'flatten' diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index ce6b5d968f..dd5fe3ed6a 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -157,8 +157,17 @@ module AryUtil } proc validateNegativeAxes(axes: list(int), param nd: int): (bool, list(int)) { - const (valid, ret) = validateNegativeAxes(axes.toArray(), nd); - return (valid, new list(ret)); + var ret = new list(int); + for a in axes { + if a >= 0 && a < nd { + ret.pushBack(a); + } else if a < 0 && a >= -nd { + ret.pushBack(nd + a); + } else { + return (false, ret); + } + } + return (true, ret); } /* @@ -334,7 +343,13 @@ module AryUtil } proc reducedShape(shape: ?N*int, axes: list(int)): N*int { - return reducedShape(shape, axes.toArray()); + var ret: N*int; + for param i in 0.. Date: Fri, 18 Oct 2024 08:40:41 -0600 Subject: [PATCH 13/14] fix return type change to reduction procedures Signed-off-by: Jeremiah Corrado --- src/Message.chpl | 5 ++++- src/ReductionMsg.chpl | 19 +++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/Message.chpl b/src/Message.chpl index 7addbc6f77..68e12a509d 100644 --- a/src/Message.chpl +++ b/src/Message.chpl @@ -120,8 +120,11 @@ module Message { proc type MsgTuple.fromScalar(scalar: ?t): MsgTuple throws { import NumPyDType; + const dTypeName = type2str(t); + if dTypeName == "undef" + then throw new Error("Unknown scalar type '%s' in MsgTuple.fromScalar".format(t:string)); return new MsgTuple( - msg = "%s %s".format(type2str(t), NumPyDType.type2fmt(t)).format(scalar), + msg = "%s %s".format(dTypeName, NumPyDType.type2fmt(t)).format(scalar), msgType = MsgType.NORMAL, msgFormat = MsgFormat.STRING, payload = b"" diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index dd7c94ab49..a0f628f375 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -124,13 +124,16 @@ module ReductionMsg } } + proc reductionReturnType(type t) type + do return if t == bool then int else t; + @arkouda.registerCommand - proc sum(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] throws + proc sum(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; - type opType = if t == bool then int else t; + type opType = reductionReturnType(t); if d.rank == 1 then return makeDistArray([(+ reduce x:opType)]); const (valid, axes) = validateNegativeAxes(axis, x.rank); @@ -152,11 +155,11 @@ module ReductionMsg } @arkouda.registerCommand - proc prod(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] throws + proc prod(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; - type opType = if t == bool then int else t; + type opType = reductionReturnType(t); if d.rank == 1 then return makeDistArray([(* reduce x:opType)]); const (valid, axes) = validateNegativeAxes(axis, x.rank); @@ -178,11 +181,11 @@ module ReductionMsg } @arkouda.registerCommand - proc max(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] throws + proc max(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; - type opType = if t == bool then int else t; + type opType = reductionReturnType(t); if d.rank == 1 then return makeDistArray([(max reduce x:opType)]); const (valid, axes) = validateNegativeAxes(axis, x.rank); @@ -204,11 +207,11 @@ module ReductionMsg } @arkouda.registerCommand - proc min(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] throws + proc min(ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; - type opType = if t == bool then int else t; + type opType = reductionReturnType(t); if d.rank == 1 then return makeDistArray([(min reduce x:opType)]); const (valid, axes) = validateNegativeAxes(axis, x.rank); From 9b18917f63fdc0f011521cdef8392655d6340116 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Fri, 18 Oct 2024 10:54:48 -0600 Subject: [PATCH 14/14] fix reducedShape helper for lists Signed-off-by: Jeremiah Corrado --- src/AryUtil.chpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index dd5fe3ed6a..5ef9642c90 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -345,7 +345,7 @@ module AryUtil proc reducedShape(shape: ?N*int, axes: list(int)): N*int { var ret: N*int; for param i in 0..