Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes 3809 moves trig and hyp fns to new interface #3863

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 25 additions & 31 deletions arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
from arkouda.numpy.dtypes import dtype as akdtype
from arkouda.numpy.dtypes import float64 as ak_float64
from arkouda.numpy.dtypes import int64 as ak_int64
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.dtypes import uint64 as ak_uint64
from arkouda.numpy.dtypes import (
int_scalars,
isSupportedNumber,
numeric_scalars,
resolve_scalar_dtype,
str_,
)
from arkouda.numpy.dtypes import uint64 as ak_uint64
from arkouda.numpy.dtypes import _datatype_check
from arkouda.pdarrayclass import all as ak_all
from arkouda.pdarrayclass import any as ak_any
from arkouda.pdarrayclass import argmax, create_pdarray, pdarray, sum
Expand Down Expand Up @@ -103,6 +103,17 @@ class ErrorMode(Enum):
return_validity = "return_validity"


# TODO: standardize error checking in python interface

# merge_where comes in handy in arctan2 and some other functions.


def _merge_where(new_pda, where, ret):
new_pda = cast(new_pda, ret.dtype)
new_pda[where] = ret
return new_pda


@typechecked
def cast(
pda: Union[pdarray, Strings, Categorical], # type: ignore
Expand Down Expand Up @@ -1296,18 +1307,16 @@ def _trig_helper(pda: pdarray, func: str, where: Union[bool, pdarray] = True) ->
Raises
------
TypeError
Raised if the parameter is not a pdarray
TypeError
Raised if where condition is not type Boolean
Raised if pda is not a pdarray or if is not real or int or uint, or if where is not Boolean
"""
_datatype_check(pda.dtype, [ak_float64, ak_int64, ak_uint64], func)
if where is True:
repMsg = type_cast(
str,
generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"{func}<{pda.dtype},{pda.ndim}>",
args={
"func": func,
"array": pda,
"x": pda,
},
),
)
Expand All @@ -1320,18 +1329,13 @@ def _trig_helper(pda: pdarray, func: str, where: Union[bool, pdarray] = True) ->
repMsg = type_cast(
str,
generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"{func}<{pda.dtype},{pda.ndim}>",
args={
"func": func,
"array": pda[where],
"x": pda[where],
},
),
)
new_pda = pda[:]
ret = create_pdarray(repMsg)
new_pda = cast(new_pda, ret.dtype)
new_pda[where] = ret
return new_pda
return _merge_where(pda[:], where, create_pdarray(repMsg))


@typechecked
Expand Down Expand Up @@ -1363,11 +1367,7 @@ def rad2deg(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray:
elif where is False:
return pda
else:
new_pda = pda
ret = 180 * (pda[where] / np.pi)
new_pda = cast(new_pda, ret.dtype)
new_pda[where] = ret
return new_pda
return _merge_where(pda[:], where, 180*(pda[where]/np.pi))


@typechecked
Expand Down Expand Up @@ -1399,11 +1399,7 @@ def deg2rad(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray:
elif where is False:
return pda
else:
new_pda = pda
ret = np.pi * pda[where] / 180
new_pda = cast(new_pda, ret.dtype)
new_pda[where] = ret
return new_pda
return _merge_where(pda[:], where, (np.pi*pda[where]/180))


def _hash_helper(a):
Expand Down Expand Up @@ -2292,10 +2288,8 @@ def array_equal(pda_a: pdarray, pda_b: pdarray, equal_nan: bool = False):


def putmask(
A : pdarray ,
mask : pdarray,
Values : pdarray
) : # doesn't return anything, as A is overwritten in place
A: pdarray, mask: pdarray, Values: pdarray
): # doesn't return anything, as A is overwritten in place
"""
Overwrites elements of A with elements from B based upon a mask array.
Similar to numpy.putmask, where mask = False, A retains its original value,
Expand Down Expand Up @@ -2363,7 +2357,7 @@ def putmask(
return


def eye(rows: int_scalars, cols: int_scalars, diag: int_scalars = 0, dt: type = akint64):
def eye(rows: int_scalars, cols: int_scalars, diag: int_scalars = 0, dt: type = ak_int64):
"""
Return a pdarray with zeros everywhere except along a diagonal, which is all ones.
The matrix need not be square.
Expand Down
6 changes: 6 additions & 0 deletions arkouda/numpy/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

__all__ = [
"_datatype_check",
"ARKOUDA_SUPPORTED_DTYPES",
"DType",
"DTypeObjects",
Expand Down Expand Up @@ -84,6 +85,11 @@
}


def _datatype_check(the_dtype, allowed_list, name):
if not (the_dtype in allowed_list):
raise TypeError(f"{name} only implements types {allowed_list}")


def dtype(x):
# we had to create our own bigint type since numpy
# gives them dtype=object there's no np equivalent
Expand Down
199 changes: 91 additions & 108 deletions src/EfuncMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,97 @@ module EfuncMsg
:throws: `UndefinedSymbolError(name)`
*/


// This section is a rewrite of trig and hyp functions in new interface.
// This comment will be updated as other functions are rewritten, and deleted
// once the rewrite is complete.

@arkouda.registerCommand (name="sin")
proc sine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return sin(x); }

proc sine (x : [?d] ?t) : [d] real throws
{ throw new Error ("sin does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="cos")
proc cosine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return cos(x); }

proc cosine (x : [?d] ?t) : [d] real throws
{ throw new Error ("cos does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="tan")
proc tangent (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return tan(x); }

proc tangent (x : [?d] ?t) : [d] real throws
{ throw new Error ("tan does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arcsin")
proc arcsine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return asin(x); }

proc arcsine (x : [?d] ?t) : [d] real throws
{ throw new Error ("arcsin does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arccos")
proc arccosine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return acos(x); }

proc arccosine (x : [?d] ?t) : [d] real throws
{ throw new Error ("arccos does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arctan")
proc arctangent (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return atan(x); }

proc arctangent (x : [?d] ?t) : [d] real throws
{ throw new Error ("arctan does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="sinh")
proc hypsine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return sinh(x); }

proc hypsine (x : [?d] ?t) : [d] real throws
{ throw new Error ("sinh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="cosh")
proc hypcosine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return cosh(x); }

proc hypcosine (x : [?d] ?t) : [d] real throws
{ throw new Error ("cosh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="tanh")
proc hyptangent (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return tanh(x); }

proc hyptangent (x : [?d] ?t) : [d] real throws
{ throw new Error ("tanh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arcsinh")
proc archypsine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return asinh(x); }

proc archypsine (x : [?d] ?t) : [d] real throws
{ throw new Error ("arcsinh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arccosh")
proc archypcosine (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return acosh(x); }

proc archypcosine (x : [?d] ?t) : [d] real throws
{ throw new Error ("arccosh does not support type %s".format(type2str(t))) ; }

@arkouda.registerCommand (name="arctanh")
proc archyptangent (x : [?d] ?t) : [d] real throws
where (t==int || t==real || t==uint) { return atanh(x); }

proc archyptangent (x : [?d] ?t) : [d] real throws
{ throw new Error ("arctanh does not support type %s".format(type2str(t))) ; }

// End of rewrite section -- delete this comment after all of EfuncMsg is rewritten.

@arkouda.registerND
proc efuncMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
Expand Down Expand Up @@ -101,42 +192,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "sin" {
st.addEntry(rname, new shared SymEntry(sin(ea)));
}
when "cos" {
st.addEntry(rname, new shared SymEntry(cos(ea)));
}
when "tan" {
st.addEntry(rname, new shared SymEntry(tan(ea)));
}
when "arcsin" {
st.addEntry(rname, new shared SymEntry(asin(ea)));
}
when "arccos" {
st.addEntry(rname, new shared SymEntry(acos(ea)));
}
when "arctan" {
st.addEntry(rname, new shared SymEntry(atan(ea)));
}
when "sinh" {
st.addEntry(rname, new shared SymEntry(sinh(ea)));
}
when "cosh" {
st.addEntry(rname, new shared SymEntry(cosh(ea)));
}
when "tanh" {
st.addEntry(rname, new shared SymEntry(tanh(ea)));
}
when "arcsinh" {
st.addEntry(rname, new shared SymEntry(asinh(ea)));
}
when "arccosh" {
st.addEntry(rname, new shared SymEntry(acosh(ea)));
}
when "arctanh" {
st.addEntry(rname, new shared SymEntry(atanh(ea)));
}
when "hash64" {
overMemLimit(numBytes(int) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
Expand Down Expand Up @@ -253,42 +308,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "sin" {
st.addEntry(rname, new shared SymEntry(sin(ea)));
}
when "cos" {
st.addEntry(rname, new shared SymEntry(cos(ea)));
}
when "tan" {
st.addEntry(rname, new shared SymEntry(tan(ea)));
}
when "arcsin" {
st.addEntry(rname, new shared SymEntry(asin(ea)));
}
when "arccos" {
st.addEntry(rname, new shared SymEntry(acos(ea)));
}
when "arctan" {
st.addEntry(rname, new shared SymEntry(atan(ea)));
}
when "sinh" {
st.addEntry(rname, new shared SymEntry(sinh(ea)));
}
when "cosh" {
st.addEntry(rname, new shared SymEntry(cosh(ea)));
}
when "tanh" {
st.addEntry(rname, new shared SymEntry(tanh(ea)));
}
when "arcsinh" {
st.addEntry(rname, new shared SymEntry(asinh(ea)));
}
when "arccosh" {
st.addEntry(rname, new shared SymEntry(acosh(ea)));
}
when "arctanh" {
st.addEntry(rname, new shared SymEntry(atanh(ea)));
}
when "hash64" {
overMemLimit(numBytes(real) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
Expand Down Expand Up @@ -397,42 +416,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "sin" {
st.addEntry(rname, new shared SymEntry(sin(ea)));
}
when "cos" {
st.addEntry(rname, new shared SymEntry(cos(ea)));
}
when "tan" {
st.addEntry(rname, new shared SymEntry(tan(ea)));
}
when "arcsin" {
st.addEntry(rname, new shared SymEntry(asin(ea)));
}
when "arccos" {
st.addEntry(rname, new shared SymEntry(acos(ea)));
}
when "arctan" {
st.addEntry(rname, new shared SymEntry(atan(ea)));
}
when "sinh" {
st.addEntry(rname, new shared SymEntry(sinh(ea)));
}
when "cosh" {
st.addEntry(rname, new shared SymEntry(cosh(ea)));
}
when "tanh" {
st.addEntry(rname, new shared SymEntry(tanh(ea)));
}
when "arcsinh" {
st.addEntry(rname, new shared SymEntry(asinh(ea)));
}
when "arccosh" {
st.addEntry(rname, new shared SymEntry(acosh(ea)));
}
when "arctanh" {
st.addEntry(rname, new shared SymEntry(atanh(ea)));
}
when "parity" {
st.addEntry(rname, new shared SymEntry(parity(ea)));
}
Expand Down
2 changes: 1 addition & 1 deletion tests/array_api/array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def test_stack_unstack(self):
assert bp.tolist() == b.tolist()
assert cp.tolist() == c.tolist()

@pytest.mark.skip_if_max_rank_less_than(2)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_tile(self):
a = randArr((2, 3))

Expand Down
Loading