From 218496e976f5f47ff87a1e75e4b9068944023429 Mon Sep 17 00:00:00 2001 From: ajpotts Date: Wed, 23 Oct 2024 13:43:25 -0400 Subject: [PATCH] Part 3 of argTypeReductionMessage refactor (#3854) Co-authored-by: Amanda Potts --- src/ReductionMsg.chpl | 91 +------------------------------------------ tests/client_test.py | 27 ------------- 2 files changed, 1 insertion(+), 117 deletions(-) diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index a0f628f375..ea36fed539 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -34,96 +34,7 @@ module ReductionMsg const basicReductionOps = {"sum", "prod", "min", "max"}, boolReductionOps = {"any", "all", "is_sorted", "is_locally_sorted"}, idxReductionOps = {"argmin", "argmax"}; - - /* - Compute an array reduction along one or more axes - (where the result has the same data type as the input array) - - Supports: 'sum', 'prod', 'min', 'max' - */ - - - @arkouda.registerND(cmd_prefix="reduce") - proc argTypeReductionMessage(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { - use SliceReductionOps; - param pn = Reflection.getRoutineName(); - const x = msgArgs.getValueOf("x"), - op = msgArgs.getValueOf("op"), - nAxes = msgArgs.get("nAxes").getIntValue(), - axesRaw = msgArgs.get("axis").toScalarArray(int, nAxes), - skipNan = msgArgs.get("skipNan").getBoolValue(), - rname = st.nextName(); - - var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(x, st); - - if !basicReductionOps.contains(op) { - const errorMsg = notImplementedError(pn,op,gEnt.dtype); - rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - - proc computeReduction(type t): MsgTuple throws { - const eIn = toSymEntry(gEnt, t, nd); - type opType = if t == bool then int else t; - - if nd == 1 || nAxes == 0 { - var s: opType; - select op { - when "sum" do s = sumSlice(eIn.a, eIn.a.domain, opType, skipNan); - when "prod" do s = prodSlice(eIn.a, eIn.a.domain, opType, skipNan); - when "min" do s = getMinSlice(eIn.a, eIn.a.domain, skipNan); - when "max" do s = getMaxSlice(eIn.a, eIn.a.domain, skipNan); - otherwise halt("unreachable"); - } - - const scalarValue = if (t == bool && (op == "min" || op == "max")) - then "bool " + bool2str(if s == 1 then true else false) - else (type2str(opType) + " " + type2fmt(opType)).format(s); - rmLogger.debug(getModuleName(),pn,getLineNumber(),scalarValue); - return new MsgTuple(scalarValue, MsgType.NORMAL); - } else { - const (valid, axes) = validateNegativeAxes(axesRaw, nd); - if !valid { - var errorMsg = "Invalid axis value(s) '%?' in slicing reduction".format(axesRaw); - rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg); - return new MsgTuple(errorMsg,MsgType.ERROR); - } else { - const outShape = reducedShape(eIn.a.shape, axes); - var eOut = st.addEntry(rname, outShape, opType); - - forall sliceIdx in domOffAxis(eIn.a.domain, axes) { - const sliceDom = domOnAxis(eIn.a.domain, sliceIdx, axes); - var s: opType; - select op { - when "sum" do s = sumSlice(eIn.a, sliceDom, opType, skipNan); - when "prod" do s = prodSlice(eIn.a, sliceDom, opType, skipNan); - when "min" do s = getMinSlice(eIn.a, sliceDom, skipNan); - when "max" do s = getMaxSlice(eIn.a, sliceDom, skipNan); - otherwise halt("unreachable"); - } - eOut.a[sliceIdx] = s; - } - - const repMsg = "created " + st.attrib(rname); - rmLogger.info(getModuleName(),pn,getLineNumber(),repMsg); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - } - } - - select gEnt.dtype { - when DType.Int64 do return computeReduction(int); - when DType.UInt64 do return computeReduction(uint); - when DType.Float64 do return computeReduction(real); - when DType.Bool do return computeReduction(bool); - otherwise { - var errorMsg = notImplementedError(pn,dtype2str(gEnt.dtype)); - rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg); - return new MsgTuple(errorMsg,MsgType.ERROR); - } - } - } - + proc reductionReturnType(type t) type do return if t == bool then int else t; diff --git a/tests/client_test.py b/tests/client_test.py index 267aaab3c8..d92247a20c 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -156,30 +156,3 @@ def test_client_get_server_commands(self): cmds = ak.client.get_server_commands() for cmd in ["connect", "info", "str"]: assert cmd in cmds - - @pytest.mark.skip_if_max_rank_greater_than(9) - def test_client_array_dim_cmd_error(self): - """ - Tests that a user will get a helpful error message if they attempt to - use a multi-dimensional command when the server is not configured to - support multi-dimensional arrays of the given rank. - """ - with pytest.raises(RuntimeError) as cm: - resp = generic_msg("reduce10D") - - err_msg = ( - f"Error: Command 'reduce10D' is not supported with the current server configuration as the maximum array dimensionality is {ak.client.get_max_array_rank()}. " - f"Please recompile with support for at least 10D arrays" - ) - cm.match(err_msg) # Asserts the error msg matches the expected value - - def test_client_nd_unimplemented_error(self): - """ - Tests that a user will get a helpful error message if they attempt to - use a multi-dimensional command when only a 1D implementation exists. - """ - with pytest.raises(RuntimeError) as cm: - resp = generic_msg("connect2D") - - err_msg = "Error: Command 'connect' is not supported for multidimensional arrays" - cm.match(err_msg) # Asserts the error msg matches the expected value