Skip to content

Commit

Permalink
Closes 3809 moves trig and hyp fns to new interface (Bears-R-Us#3863)
Browse files Browse the repository at this point in the history
* Revised and ready for review

* also addresses issue 3860

---------

Co-authored-by: drculhane <[email protected]>
  • Loading branch information
drculhane and drculhane authored Oct 24, 2024
1 parent 2b71be1 commit e35b204
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 140 deletions.
56 changes: 25 additions & 31 deletions arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
from arkouda.numpy.dtypes import dtype as akdtype
from arkouda.numpy.dtypes import float64 as ak_float64
from arkouda.numpy.dtypes import int64 as ak_int64
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.dtypes import uint64 as ak_uint64
from arkouda.numpy.dtypes import (
int_scalars,
isSupportedNumber,
numeric_scalars,
resolve_scalar_dtype,
str_,
)
from arkouda.numpy.dtypes import uint64 as ak_uint64
from arkouda.numpy.dtypes import _datatype_check
from arkouda.pdarrayclass import all as ak_all
from arkouda.pdarrayclass import any as ak_any
from arkouda.pdarrayclass import argmax, create_pdarray, pdarray, sum
Expand Down Expand Up @@ -103,6 +103,17 @@ class ErrorMode(Enum):
return_validity = "return_validity"


# TODO: standardize error checking in python interface

# merge_where comes in handy in arctan2 and some other functions.


def _merge_where(new_pda, where, ret):
new_pda = cast(new_pda, ret.dtype)
new_pda[where] = ret
return new_pda


@typechecked
def cast(
pda: Union[pdarray, Strings, Categorical], # type: ignore
Expand Down Expand Up @@ -1296,18 +1307,16 @@ def _trig_helper(pda: pdarray, func: str, where: Union[bool, pdarray] = True) ->
Raises
------
TypeError
Raised if the parameter is not a pdarray
TypeError
Raised if where condition is not type Boolean
Raised if pda is not a pdarray or if is not real or int or uint, or if where is not Boolean
"""
_datatype_check(pda.dtype, [ak_float64, ak_int64, ak_uint64], func)
if where is True:
repMsg = type_cast(
str,
generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"{func}<{pda.dtype},{pda.ndim}>",
args={
"func": func,
"array": pda,
"x": pda,
},
),
)
Expand All @@ -1320,18 +1329,13 @@ def _trig_helper(pda: pdarray, func: str, where: Union[bool, pdarray] = True) ->
repMsg = type_cast(
str,
generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"{func}<{pda.dtype},{pda.ndim}>",
args={
"func": func,
"array": pda[where],
"x": pda[where],
},
),
)
new_pda = pda[:]
ret = create_pdarray(repMsg)
new_pda = cast(new_pda, ret.dtype)
new_pda[where] = ret
return new_pda
return _merge_where(pda[:], where, create_pdarray(repMsg))


@typechecked
Expand Down Expand Up @@ -1363,11 +1367,7 @@ def rad2deg(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray:
elif where is False:
return pda
else:
new_pda = pda
ret = 180 * (pda[where] / np.pi)
new_pda = cast(new_pda, ret.dtype)
new_pda[where] = ret
return new_pda
return _merge_where(pda[:], where, 180*(pda[where]/np.pi))


@typechecked
Expand Down Expand Up @@ -1399,11 +1399,7 @@ def deg2rad(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray:
elif where is False:
return pda
else:
new_pda = pda
ret = np.pi * pda[where] / 180
new_pda = cast(new_pda, ret.dtype)
new_pda[where] = ret
return new_pda
return _merge_where(pda[:], where, (np.pi*pda[where]/180))


def _hash_helper(a):
Expand Down Expand Up @@ -2292,10 +2288,8 @@ def array_equal(pda_a: pdarray, pda_b: pdarray, equal_nan: bool = False):


def putmask(
A : pdarray ,
mask : pdarray,
Values : pdarray
) : # doesn't return anything, as A is overwritten in place
A: pdarray, mask: pdarray, Values: pdarray
): # doesn't return anything, as A is overwritten in place
"""
Overwrites elements of A with elements from B based upon a mask array.
Similar to numpy.putmask, where mask = False, A retains its original value,
Expand Down Expand Up @@ -2363,7 +2357,7 @@ def putmask(
return


def eye(rows: int_scalars, cols: int_scalars, diag: int_scalars = 0, dt: type = akint64):
def eye(rows: int_scalars, cols: int_scalars, diag: int_scalars = 0, dt: type = ak_int64):
"""
Return a pdarray with zeros everywhere except along a diagonal, which is all ones.
The matrix need not be square.
Expand Down
6 changes: 6 additions & 0 deletions arkouda/numpy/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

__all__ = [
"_datatype_check",
"ARKOUDA_SUPPORTED_DTYPES",
"DType",
"DTypeObjects",
Expand Down Expand Up @@ -84,6 +85,11 @@
}


def _datatype_check(the_dtype, allowed_list, name):
if not (the_dtype in allowed_list):
raise TypeError(f"{name} only implements types {allowed_list}")


def dtype(x):
# we had to create our own bigint type since numpy
# gives them dtype=object there's no np equivalent
Expand Down
199 changes: 91 additions & 108 deletions src/EfuncMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,97 @@ module EfuncMsg
:throws: `UndefinedSymbolError(name)`
*/


// This section is a rewrite of trig and hyp functions in new interface.
// This comment will be updated as other functions are rewritten, and deleted
// once the rewrite is complete.

@arkouda.registerCommand (name="sin")
proc sine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return sin(x); }

proc sine (x : [?d] ?t) : [d] real throws
{ throw new Error ("sin does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="cos")
proc cosine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return cos(x); }

proc cosine (x : [?d] ?t) : [d] real throws
{ throw new Error ("cos does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="tan")
proc tangent (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return tan(x); }

proc tangent (x : [?d] ?t) : [d] real throws
{ throw new Error ("tan does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arcsin")
proc arcsine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return asin(x); }

proc arcsine (x : [?d] ?t) : [d] real throws
{ throw new Error ("arcsin does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arccos")
proc arccosine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return acos(x); }

proc arccosine (x : [?d] ?t) : [d] real throws
{ throw new Error ("arccos does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arctan")
proc arctangent (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return atan(x); }

proc arctangent (x : [?d] ?t) : [d] real throws
{ throw new Error ("arctan does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="sinh")
proc hypsine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return sinh(x); }

proc hypsine (x : [?d] ?t) : [d] real throws
{ throw new Error ("sinh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="cosh")
proc hypcosine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return cosh(x); }

proc hypcosine (x : [?d] ?t) : [d] real throws
{ throw new Error ("cosh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="tanh")
proc hyptangent (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return tanh(x); }

proc hyptangent (x : [?d] ?t) : [d] real throws
{ throw new Error ("tanh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arcsinh")
proc archypsine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return asinh(x); }

proc archypsine (x : [?d] ?t) : [d] real throws
{ throw new Error ("arcsinh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arccosh")
proc archypcosine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return acosh(x); }

proc archypcosine (x : [?d] ?t) : [d] real throws
{ throw new Error ("arccosh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arctanh")
proc archyptangent (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return atanh(x); }

proc archyptangent (x : [?d] ?t) : [d] real throws
{ throw new Error ("arctanh does not support type %s".format(type2str(t))) ; }

// End of rewrite section -- delete this comment after all of EfuncMsg is rewritten.

@arkouda.registerND
proc efuncMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
Expand Down Expand Up @@ -101,42 +192,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "sin" {
st.addEntry(rname, new shared SymEntry(sin(ea)));
}
when "cos" {
st.addEntry(rname, new shared SymEntry(cos(ea)));
}
when "tan" {
st.addEntry(rname, new shared SymEntry(tan(ea)));
}
when "arcsin" {
st.addEntry(rname, new shared SymEntry(asin(ea)));
}
when "arccos" {
st.addEntry(rname, new shared SymEntry(acos(ea)));
}
when "arctan" {
st.addEntry(rname, new shared SymEntry(atan(ea)));
}
when "sinh" {
st.addEntry(rname, new shared SymEntry(sinh(ea)));
}
when "cosh" {
st.addEntry(rname, new shared SymEntry(cosh(ea)));
}
when "tanh" {
st.addEntry(rname, new shared SymEntry(tanh(ea)));
}
when "arcsinh" {
st.addEntry(rname, new shared SymEntry(asinh(ea)));
}
when "arccosh" {
st.addEntry(rname, new shared SymEntry(acosh(ea)));
}
when "arctanh" {
st.addEntry(rname, new shared SymEntry(atanh(ea)));
}
when "hash64" {
overMemLimit(numBytes(int) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
Expand Down Expand Up @@ -253,42 +308,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "sin" {
st.addEntry(rname, new shared SymEntry(sin(ea)));
}
when "cos" {
st.addEntry(rname, new shared SymEntry(cos(ea)));
}
when "tan" {
st.addEntry(rname, new shared SymEntry(tan(ea)));
}
when "arcsin" {
st.addEntry(rname, new shared SymEntry(asin(ea)));
}
when "arccos" {
st.addEntry(rname, new shared SymEntry(acos(ea)));
}
when "arctan" {
st.addEntry(rname, new shared SymEntry(atan(ea)));
}
when "sinh" {
st.addEntry(rname, new shared SymEntry(sinh(ea)));
}
when "cosh" {
st.addEntry(rname, new shared SymEntry(cosh(ea)));
}
when "tanh" {
st.addEntry(rname, new shared SymEntry(tanh(ea)));
}
when "arcsinh" {
st.addEntry(rname, new shared SymEntry(asinh(ea)));
}
when "arccosh" {
st.addEntry(rname, new shared SymEntry(acosh(ea)));
}
when "arctanh" {
st.addEntry(rname, new shared SymEntry(atanh(ea)));
}
when "hash64" {
overMemLimit(numBytes(real) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
Expand Down Expand Up @@ -397,42 +416,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "sin" {
st.addEntry(rname, new shared SymEntry(sin(ea)));
}
when "cos" {
st.addEntry(rname, new shared SymEntry(cos(ea)));
}
when "tan" {
st.addEntry(rname, new shared SymEntry(tan(ea)));
}
when "arcsin" {
st.addEntry(rname, new shared SymEntry(asin(ea)));
}
when "arccos" {
st.addEntry(rname, new shared SymEntry(acos(ea)));
}
when "arctan" {
st.addEntry(rname, new shared SymEntry(atan(ea)));
}
when "sinh" {
st.addEntry(rname, new shared SymEntry(sinh(ea)));
}
when "cosh" {
st.addEntry(rname, new shared SymEntry(cosh(ea)));
}
when "tanh" {
st.addEntry(rname, new shared SymEntry(tanh(ea)));
}
when "arcsinh" {
st.addEntry(rname, new shared SymEntry(asinh(ea)));
}
when "arccosh" {
st.addEntry(rname, new shared SymEntry(acosh(ea)));
}
when "arctanh" {
st.addEntry(rname, new shared SymEntry(atanh(ea)));
}
when "parity" {
st.addEntry(rname, new shared SymEntry(parity(ea)));
}
Expand Down
2 changes: 1 addition & 1 deletion tests/array_api/array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def test_stack_unstack(self):
assert bp.tolist() == b.tolist()
assert cp.tolist() == c.tolist()

@pytest.mark.skip_if_max_rank_less_than(2)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_tile(self):
a = randArr((2, 3))

Expand Down

0 comments on commit e35b204

Please sign in to comment.