diff --git a/PROTO_tests/tests/extrema_test.py b/PROTO_tests/tests/extrema_test.py index ad292ed74b..dbf395cd36 100644 --- a/PROTO_tests/tests/extrema_test.py +++ b/PROTO_tests/tests/extrema_test.py @@ -3,12 +3,14 @@ import arkouda as ak -NO_UINT = ["int64", "float64", "bool"] +NUMERIC_TYPES = ["int64", "uint64", "float64", "bool"] def make_np_arrays(size, dtype): if dtype == "int64": return np.random.randint(-(2**32), 2**32, size=size, dtype=dtype) + elif dtype == "uint64": + return ak.cast(ak.randint(-(2**32), 2**32, size=size), dtype) elif dtype == "float64": return np.random.uniform(-(2**32), 2**32, size=size) elif dtype == "bool": @@ -18,10 +20,9 @@ def make_np_arrays(size, dtype): class TestExtrema: @pytest.mark.parametrize("prob_size", pytest.prob_size) - @pytest.mark.parametrize("dtype", ["int64", "float64"]) + @pytest.mark.parametrize("dtype", ["int64", "uint64", "float64"]) def test_extrema(self, prob_size, dtype): - # TODO add testing for uint once #2695 is completed - pda = ak.randint(-(2**32), 2**32, size=prob_size, dtype=dtype) + pda = ak.array(make_np_arrays(prob_size, dtype)) ak_sorted = ak.sort(pda) K = prob_size // 2 @@ -33,9 +34,8 @@ def test_extrema(self, prob_size, dtype): assert (ak.maxk(pda, K) == ak_sorted[-K:]).all() assert (pda[ak.argmaxk(pda, K)] == ak_sorted[-K:]).all() - @pytest.mark.parametrize("dtype", NO_UINT) + @pytest.mark.parametrize("dtype", NUMERIC_TYPES) def test_argmin_and_argmax(self, dtype): - # TODO add testing for uint once #2695 is completed np_arr = make_np_arrays(1000, dtype) ak_arr = ak.array(np_arr) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index d518691f07..16eb48ef54 100755 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -103,7 +103,7 @@ def unescape(s): mydtype = dtype(dtname) if mydtype == bigint: # we have to strip off quotes prior to 1.32 - if value[0] == "\"": + if value[0] == '"': return int(value[1:-1]) else: return int(value) @@ -832,13 +832,13 @@ def max(self) -> numpy_scalars: """ return max(self) - def argmin(self) -> np.int64: + def argmin(self) -> Union[np.int64, np.uint64]: """ Return the index of the first occurrence of the array min value """ return argmin(self) - def argmax(self) -> np.int64: + def argmax(self) -> Union[np.int64, np.uint64]: """ Return the index of the first occurrence of the array max value. """ @@ -2179,7 +2179,7 @@ def max(pda: pdarray) -> numpy_scalars: @typechecked -def argmin(pda: pdarray) -> np.int64: +def argmin(pda: pdarray) -> Union[np.int64, np.uint64]: """ Return the index of the first occurrence of the array min value. @@ -2190,7 +2190,7 @@ def argmin(pda: pdarray) -> np.int64: Returns ------- - np.int64 + Union[np.int64, np.uint64] The index of the argmin calculated from the pda Raises @@ -2205,7 +2205,7 @@ def argmin(pda: pdarray) -> np.int64: @typechecked -def argmax(pda: pdarray) -> np.int64: +def argmax(pda: pdarray) -> Union[np.int64, np.uint64]: """ Return the index of the first occurrence of the array max value. @@ -2216,7 +2216,7 @@ def argmax(pda: pdarray) -> np.int64: Returns ------- - np.int64 + Union[np.int64, np.uint64] The index of the argmax calculated from the pda Raises diff --git a/src/KExtremeMsg.chpl b/src/KExtremeMsg.chpl index 5fde2f95aa..4607b98e7c 100644 --- a/src/KExtremeMsg.chpl +++ b/src/KExtremeMsg.chpl @@ -48,23 +48,34 @@ module KExtremeMsg select(gEnt.dtype) { when (DType.Int64) { var e = toSymEntry(gEnt,int); - var aV; - - if !returnIndices { - aV = computeExtremaValues(e.a, k); - } else { - aV = computeExtremaIndices(e.a, k); - } - + var aV = if !returnIndices then computeExtremaValues(e.a, k) else computeExtremaIndices(e.a, k); st.addEntry(vname, new shared SymEntry(aV)); repMsg = "created " + st.attrib(vname); keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); return new MsgTuple(repMsg, MsgType.NORMAL); } + when (DType.UInt64) { + var e = toSymEntry(gEnt,uint); + if !returnIndices { + var aV = computeExtremaValues(e.a, k); + st.addEntry(vname, new shared SymEntry(aV)); + + repMsg = "created " + st.attrib(vname); + keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); + return new MsgTuple(repMsg, MsgType.NORMAL); + } else { + var aV = computeExtremaIndices(e.a, k); + st.addEntry(vname, new shared SymEntry(aV)); + + repMsg = "created " + st.attrib(vname); + keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + } when (DType.Float64) { + var e = toSymEntry(gEnt,real); if !returnIndices { - var e = toSymEntry(gEnt,real); var aV = computeExtremaValues(e.a, k); st.addEntry(vname, new shared SymEntry(aV)); @@ -72,7 +83,6 @@ module KExtremeMsg keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); return new MsgTuple(repMsg, MsgType.NORMAL); } else { - var e = toSymEntry(gEnt,real); var aV = computeExtremaIndices(e.a, k); st.addEntry(vname, new shared SymEntry(aV)); @@ -109,33 +119,42 @@ module KExtremeMsg select(gEnt.dtype) { when (DType.Int64) { var e = toSymEntry(gEnt,int); - var aV; - if !returnIndices { - aV = computeExtremaValues(e.a, k, false); - } else { - aV = computeExtremaIndices(e.a, k, false); - } - + var aV = if !returnIndices then computeExtremaValues(e.a, k, false) else computeExtremaIndices(e.a, k, false); st.addEntry(vname, new shared SymEntry(aV)); repMsg = "created " + st.attrib(vname); keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); return new MsgTuple(repMsg, MsgType.NORMAL); } - when (DType.Float64) { + when (DType.UInt64) { + var e = toSymEntry(gEnt,uint); if !returnIndices { - var e = toSymEntry(gEnt,real); var aV = computeExtremaValues(e.a, k, false); - st.addEntry(vname, new shared SymEntry(aV)); repMsg = "created " + st.attrib(vname); keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); return new MsgTuple(repMsg, MsgType.NORMAL); } else { - var e = toSymEntry(gEnt,real); var aV = computeExtremaIndices(e.a, k, false); + st.addEntry(vname, new shared SymEntry(aV)); + repMsg = "created " + st.attrib(vname); + keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + } + when (DType.Float64) { + var e = toSymEntry(gEnt,real); + if !returnIndices { + var aV = computeExtremaValues(e.a, k, false); + st.addEntry(vname, new shared SymEntry(aV)); + + repMsg = "created " + st.attrib(vname); + keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); + return new MsgTuple(repMsg, MsgType.NORMAL); + } else { + var aV = computeExtremaIndices(e.a, k, false); st.addEntry(vname, new shared SymEntry(aV)); repMsg = "created " + st.attrib(vname);