Skip to content

Commit

Permalink
Part 3 of argTypeReductionMessage refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Oct 22, 2024
1 parent 76408db commit 2e37708
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 110 deletions.
84 changes: 1 addition & 83 deletions src/ReductionMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -41,89 +41,7 @@ module ReductionMsg
Supports: 'sum', 'prod', 'min', 'max'
*/


@arkouda.registerND(cmd_prefix="reduce")
proc argTypeReductionMessage(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
use SliceReductionOps;
param pn = Reflection.getRoutineName();
const x = msgArgs.getValueOf("x"),
op = msgArgs.getValueOf("op"),
nAxes = msgArgs.get("nAxes").getIntValue(),
axesRaw = msgArgs.get("axis").toScalarArray(int, nAxes),
skipNan = msgArgs.get("skipNan").getBoolValue(),
rname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(x, st);

if !basicReductionOps.contains(op) {
const errorMsg = notImplementedError(pn,op,gEnt.dtype);
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}

proc computeReduction(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd);
type opType = if t == bool then int else t;

if nd == 1 || nAxes == 0 {
var s: opType;
select op {
when "sum" do s = sumSlice(eIn.a, eIn.a.domain, opType, skipNan);
when "prod" do s = prodSlice(eIn.a, eIn.a.domain, opType, skipNan);
when "min" do s = getMinSlice(eIn.a, eIn.a.domain, skipNan);
when "max" do s = getMaxSlice(eIn.a, eIn.a.domain, skipNan);
otherwise halt("unreachable");
}

const scalarValue = if (t == bool && (op == "min" || op == "max"))
then "bool " + bool2str(if s == 1 then true else false)
else (type2str(opType) + " " + type2fmt(opType)).format(s);
rmLogger.debug(getModuleName(),pn,getLineNumber(),scalarValue);
return new MsgTuple(scalarValue, MsgType.NORMAL);
} else {
const (valid, axes) = validateNegativeAxes(axesRaw, nd);
if !valid {
var errorMsg = "Invalid axis value(s) '%?' in slicing reduction".format(axesRaw);
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
} else {
const outShape = reducedShape(eIn.a.shape, axes);
var eOut = st.addEntry(rname, outShape, opType);

forall sliceIdx in domOffAxis(eIn.a.domain, axes) {
const sliceDom = domOnAxis(eIn.a.domain, sliceIdx, axes);
var s: opType;
select op {
when "sum" do s = sumSlice(eIn.a, sliceDom, opType, skipNan);
when "prod" do s = prodSlice(eIn.a, sliceDom, opType, skipNan);
when "min" do s = getMinSlice(eIn.a, sliceDom, skipNan);
when "max" do s = getMaxSlice(eIn.a, sliceDom, skipNan);
otherwise halt("unreachable");
}
eOut.a[sliceIdx] = s;
}

const repMsg = "created " + st.attrib(rname);
rmLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
}
}

select gEnt.dtype {
when DType.Int64 do return computeReduction(int);
when DType.UInt64 do return computeReduction(uint);
when DType.Float64 do return computeReduction(real);
when DType.Bool do return computeReduction(bool);
otherwise {
var errorMsg = notImplementedError(pn,dtype2str(gEnt.dtype));
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
}
}
}


proc reductionReturnType(type t) type
do return if t == bool then int else t;

Expand Down
27 changes: 0 additions & 27 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,30 +156,3 @@ def test_client_get_server_commands(self):
cmds = ak.client.get_server_commands()
for cmd in ["connect", "info", "str"]:
assert cmd in cmds

@pytest.mark.skip_if_max_rank_greater_than(9)
def test_client_array_dim_cmd_error(self):
"""
Tests that a user will get a helpful error message if they attempt to
use a multi-dimensional command when the server is not configured to
support multi-dimensional arrays of the given rank.
"""
with pytest.raises(RuntimeError) as cm:
resp = generic_msg("reduce10D")

err_msg = (
f"Error: Command 'reduce10D' is not supported with the current server configuration as the maximum array dimensionality is {ak.client.get_max_array_rank()}. "
f"Please recompile with support for at least 10D arrays"
)
cm.match(err_msg) # Asserts the error msg matches the expected value

def test_client_nd_unimplemented_error(self):
"""
Tests that a user will get a helpful error message if they attempt to
use a multi-dimensional command when only a 1D implementation exists.
"""
with pytest.raises(RuntimeError) as cm:
resp = generic_msg("connect2D")

err_msg = "Error: Command 'connect' is not supported for multidimensional arrays"
cm.match(err_msg) # Asserts the error msg matches the expected value

0 comments on commit 2e37708

Please sign in to comment.