From f5389f9040ab575f5974ac010190e318072d5878 Mon Sep 17 00:00:00 2001 From: drculhane Date: Wed, 23 Oct 2024 11:48:06 -0400 Subject: [PATCH 1/2] Revised and ready for review --- arkouda/numpy/_numeric.py | 56 +++++----- arkouda/numpy/dtypes/dtypes.py | 6 + src/EfuncMsg.chpl | 199 +++++++++++++++------------------ 3 files changed, 122 insertions(+), 139 deletions(-) diff --git a/arkouda/numpy/_numeric.py b/arkouda/numpy/_numeric.py index 1c61c9dcd2..14458ff8e8 100644 --- a/arkouda/numpy/_numeric.py +++ b/arkouda/numpy/_numeric.py @@ -15,7 +15,7 @@ from arkouda.numpy.dtypes import dtype as akdtype from arkouda.numpy.dtypes import float64 as ak_float64 from arkouda.numpy.dtypes import int64 as ak_int64 -from arkouda.numpy.dtypes import int64 as akint64 +from arkouda.numpy.dtypes import uint64 as ak_uint64 from arkouda.numpy.dtypes import ( int_scalars, isSupportedNumber, @@ -23,7 +23,7 @@ resolve_scalar_dtype, str_, ) -from arkouda.numpy.dtypes import uint64 as ak_uint64 +from arkouda.numpy.dtypes import _datatype_check from arkouda.pdarrayclass import all as ak_all from arkouda.pdarrayclass import any as ak_any from arkouda.pdarrayclass import argmax, create_pdarray, pdarray, sum @@ -103,6 +103,17 @@ class ErrorMode(Enum): return_validity = "return_validity" +# TODO: standardize error checking in python interface + +# merge_where comes in handy in arctan2 and some other functions. + + +def _merge_where(new_pda, where, ret): + new_pda = cast(new_pda, ret.dtype) + new_pda[where] = ret + return new_pda + + @typechecked def cast( pda: Union[pdarray, Strings, Categorical], # type: ignore @@ -1296,18 +1307,16 @@ def _trig_helper(pda: pdarray, func: str, where: Union[bool, pdarray] = True) -> Raises ------ TypeError - Raised if the parameter is not a pdarray - TypeError - Raised if where condition is not type Boolean + Raised if pda is not a pdarray or if is not real or int or uint, or if where is not Boolean """ + _datatype_check(pda.dtype, [ak_float64, ak_int64, ak_uint64], func) if where is True: repMsg = type_cast( str, generic_msg( - cmd=f"efunc{pda.ndim}D", + cmd=f"{func}<{pda.dtype},{pda.ndim}>", args={ - "func": func, - "array": pda, + "x": pda, }, ), ) @@ -1320,18 +1329,13 @@ def _trig_helper(pda: pdarray, func: str, where: Union[bool, pdarray] = True) -> repMsg = type_cast( str, generic_msg( - cmd=f"efunc{pda.ndim}D", + cmd=f"{func}<{pda.dtype},{pda.ndim}>", args={ - "func": func, - "array": pda[where], + "x": pda[where], }, ), ) - new_pda = pda[:] - ret = create_pdarray(repMsg) - new_pda = cast(new_pda, ret.dtype) - new_pda[where] = ret - return new_pda + return _merge_where(pda[:], where, create_pdarray(repMsg)) @typechecked @@ -1363,11 +1367,7 @@ def rad2deg(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray: elif where is False: return pda else: - new_pda = pda - ret = 180 * (pda[where] / np.pi) - new_pda = cast(new_pda, ret.dtype) - new_pda[where] = ret - return new_pda + return _merge_where(pda[:], where, 180*(pda[where]/np.pi)) @typechecked @@ -1399,11 +1399,7 @@ def deg2rad(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray: elif where is False: return pda else: - new_pda = pda - ret = np.pi * pda[where] / 180 - new_pda = cast(new_pda, ret.dtype) - new_pda[where] = ret - return new_pda + return _merge_where(pda[:], where, (np.pi*pda[where]/180)) def _hash_helper(a): @@ -2292,10 +2288,8 @@ def array_equal(pda_a: pdarray, pda_b: pdarray, equal_nan: bool = False): def putmask( - A : pdarray , - mask : pdarray, - Values : pdarray - ) : # doesn't return anything, as A is overwritten in place + A: pdarray, mask: pdarray, Values: pdarray +): # doesn't return anything, as A is overwritten in place """ Overwrites elements of A with elements from B based upon a mask array. Similar to numpy.putmask, where mask = False, A retains its original value, @@ -2363,7 +2357,7 @@ def putmask( return -def eye(rows: int_scalars, cols: int_scalars, diag: int_scalars = 0, dt: type = akint64): +def eye(rows: int_scalars, cols: int_scalars, diag: int_scalars = 0, dt: type = ak_int64): """ Return a pdarray with zeros everywhere except along a diagonal, which is all ones. The matrix need not be square. diff --git a/arkouda/numpy/dtypes/dtypes.py b/arkouda/numpy/dtypes/dtypes.py index faa94aadbc..429e61cbd1 100644 --- a/arkouda/numpy/dtypes/dtypes.py +++ b/arkouda/numpy/dtypes/dtypes.py @@ -25,6 +25,7 @@ ) __all__ = [ + "_datatype_check", "ARKOUDA_SUPPORTED_DTYPES", "DType", "DTypeObjects", @@ -84,6 +85,11 @@ } +def _datatype_check(the_dtype, allowed_list, name): + if not (the_dtype in allowed_list): + raise TypeError(f"{name} only implements types {allowed_list}") + + def dtype(x): # we had to create our own bigint type since numpy # gives them dtype=object there's no np equivalent diff --git a/src/EfuncMsg.chpl b/src/EfuncMsg.chpl index dde99962c9..44e8f56c19 100644 --- a/src/EfuncMsg.chpl +++ b/src/EfuncMsg.chpl @@ -42,6 +42,97 @@ module EfuncMsg :throws: `UndefinedSymbolError(name)` */ + +// This section is a rewrite of trig and hyp functions in new interface. +// This comment will be updated as other functions are rewritten, and deleted +// once the rewrite is complete. + + @arkouda.registerCommand (name="sin") + proc sine (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return sin(x); } + + proc sine (x : [?d] ?t) : [d] real throws + { throw new Error ("sin does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="cos") + proc cosine (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return cos(x); } + + proc cosine (x : [?d] ?t) : [d] real throws + { throw new Error ("cos does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="tan") + proc tangent (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return tan(x); } + + proc tangent (x : [?d] ?t) : [d] real throws + { throw new Error ("tan does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="arcsin") + proc arcsine (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return asin(x); } + + proc arcsine (x : [?d] ?t) : [d] real throws + { throw new Error ("arcsin does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="arccos") + proc arccosine (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return acos(x); } + + proc arccosine (x : [?d] ?t) : [d] real throws + { throw new Error ("arccos does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="arctan") + proc arctangent (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return atan(x); } + + proc arctangent (x : [?d] ?t) : [d] real throws + { throw new Error ("arctan does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="sinh") + proc hypsine (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return sinh(x); } + + proc hypsine (x : [?d] ?t) : [d] real throws + { throw new Error ("sinh does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="cosh") + proc hypcosine (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return cosh(x); } + + proc hypcosine (x : [?d] ?t) : [d] real throws + { throw new Error ("cosh does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="tanh") + proc hyptangent (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return tanh(x); } + + proc hyptangent (x : [?d] ?t) : [d] real throws + { throw new Error ("tanh does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="arcsinh") + proc archypsine (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return asinh(x); } + + proc archypsine (x : [?d] ?t) : [d] real throws + { throw new Error ("arcsinh does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="arccosh") + proc archypcosine (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return acosh(x); } + + proc archypcosine (x : [?d] ?t) : [d] real throws + { throw new Error ("arccosh does not support type %s".format(type2str(t))) ; } + + @arkouda.registerCommand (name="arctanh") + proc archyptangent (x : [?d] ?t) : [d] real throws + where (t==int || t==real || t==uint) { return atanh(x); } + + proc archyptangent (x : [?d] ?t) : [d] real throws + { throw new Error ("arctanh does not support type %s".format(type2str(t))) ; } + +// End of rewrite section -- delete this comment after all of EfuncMsg is rewritten. + @arkouda.registerND proc efuncMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { param pn = Reflection.getRoutineName(); @@ -101,42 +192,6 @@ module EfuncMsg return new MsgTuple(errorMsg, MsgType.ERROR); } } - when "sin" { - st.addEntry(rname, new shared SymEntry(sin(ea))); - } - when "cos" { - st.addEntry(rname, new shared SymEntry(cos(ea))); - } - when "tan" { - st.addEntry(rname, new shared SymEntry(tan(ea))); - } - when "arcsin" { - st.addEntry(rname, new shared SymEntry(asin(ea))); - } - when "arccos" { - st.addEntry(rname, new shared SymEntry(acos(ea))); - } - when "arctan" { - st.addEntry(rname, new shared SymEntry(atan(ea))); - } - when "sinh" { - st.addEntry(rname, new shared SymEntry(sinh(ea))); - } - when "cosh" { - st.addEntry(rname, new shared SymEntry(cosh(ea))); - } - when "tanh" { - st.addEntry(rname, new shared SymEntry(tanh(ea))); - } - when "arcsinh" { - st.addEntry(rname, new shared SymEntry(asinh(ea))); - } - when "arccosh" { - st.addEntry(rname, new shared SymEntry(acosh(ea))); - } - when "arctanh" { - st.addEntry(rname, new shared SymEntry(atanh(ea))); - } when "hash64" { overMemLimit(numBytes(int) * e.size); var a = st.addEntry(rname, e.tupShape, uint); @@ -253,42 +308,6 @@ module EfuncMsg return new MsgTuple(errorMsg, MsgType.ERROR); } } - when "sin" { - st.addEntry(rname, new shared SymEntry(sin(ea))); - } - when "cos" { - st.addEntry(rname, new shared SymEntry(cos(ea))); - } - when "tan" { - st.addEntry(rname, new shared SymEntry(tan(ea))); - } - when "arcsin" { - st.addEntry(rname, new shared SymEntry(asin(ea))); - } - when "arccos" { - st.addEntry(rname, new shared SymEntry(acos(ea))); - } - when "arctan" { - st.addEntry(rname, new shared SymEntry(atan(ea))); - } - when "sinh" { - st.addEntry(rname, new shared SymEntry(sinh(ea))); - } - when "cosh" { - st.addEntry(rname, new shared SymEntry(cosh(ea))); - } - when "tanh" { - st.addEntry(rname, new shared SymEntry(tanh(ea))); - } - when "arcsinh" { - st.addEntry(rname, new shared SymEntry(asinh(ea))); - } - when "arccosh" { - st.addEntry(rname, new shared SymEntry(acosh(ea))); - } - when "arctanh" { - st.addEntry(rname, new shared SymEntry(atanh(ea))); - } when "hash64" { overMemLimit(numBytes(real) * e.size); var a = st.addEntry(rname, e.tupShape, uint); @@ -397,42 +416,6 @@ module EfuncMsg return new MsgTuple(errorMsg, MsgType.ERROR); } } - when "sin" { - st.addEntry(rname, new shared SymEntry(sin(ea))); - } - when "cos" { - st.addEntry(rname, new shared SymEntry(cos(ea))); - } - when "tan" { - st.addEntry(rname, new shared SymEntry(tan(ea))); - } - when "arcsin" { - st.addEntry(rname, new shared SymEntry(asin(ea))); - } - when "arccos" { - st.addEntry(rname, new shared SymEntry(acos(ea))); - } - when "arctan" { - st.addEntry(rname, new shared SymEntry(atan(ea))); - } - when "sinh" { - st.addEntry(rname, new shared SymEntry(sinh(ea))); - } - when "cosh" { - st.addEntry(rname, new shared SymEntry(cosh(ea))); - } - when "tanh" { - st.addEntry(rname, new shared SymEntry(tanh(ea))); - } - when "arcsinh" { - st.addEntry(rname, new shared SymEntry(asinh(ea))); - } - when "arccosh" { - st.addEntry(rname, new shared SymEntry(acosh(ea))); - } - when "arctanh" { - st.addEntry(rname, new shared SymEntry(atanh(ea))); - } when "parity" { st.addEntry(rname, new shared SymEntry(parity(ea))); } From 207da7fdd6ea4b1f821f89db7bc8e0f252ca2e28 Mon Sep 17 00:00:00 2001 From: drculhane Date: Wed, 23 Oct 2024 12:02:56 -0400 Subject: [PATCH 2/2] also addresses issue 3860 --- tests/array_api/array_manipulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/array_api/array_manipulation.py b/tests/array_api/array_manipulation.py index 76b20e267b..7dc3c57937 100644 --- a/tests/array_api/array_manipulation.py +++ b/tests/array_api/array_manipulation.py @@ -303,7 +303,7 @@ def test_stack_unstack(self): assert bp.tolist() == b.tolist() assert cp.tolist() == c.tolist() - @pytest.mark.skip_if_max_rank_less_than(2) + @pytest.mark.skip_if_max_rank_less_than(3) def test_tile(self): a = randArr((2, 3))