Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes 3878 - refactors rounding functions to new interface, pulls hash functions into their own procs #3898

Merged
merged 11 commits into from
Nov 20, 2024
51 changes: 22 additions & 29 deletions arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, List, Sequence, Tuple, TypeVar, Union
from typing import cast as type_cast
from typing import no_type_check

import numpy as np
from typeguard import typechecked

Expand All @@ -26,13 +25,7 @@
from arkouda.numpy.dtypes import _datatype_check
from arkouda.pdarrayclass import all as ak_all
from arkouda.pdarrayclass import any as ak_any
from arkouda.pdarrayclass import (
argmax,
broadcast_if_needed,
create_pdarray,
pdarray,
sum,
)
from arkouda.pdarrayclass import argmax, broadcast_if_needed, create_pdarray, pdarray, sum
from arkouda.pdarraycreation import array, linspace, scalar_array
from arkouda.sorting import sort
from arkouda.strings import Strings
Expand Down Expand Up @@ -281,10 +274,10 @@ def ceil(pda: pdarray) -> pdarray:
>>> ak.ceil(ak.linspace(1.1,5.5,5))
array([2, 3, 4, 5, 6])
"""
_datatype_check(pda.dtype, [float], 'ceil')
repMsg = generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"ceil<{pda.dtype},{pda.ndim}>",
args={
"func": "ceil",
"array": pda,
},
)
Expand Down Expand Up @@ -315,11 +308,11 @@ def floor(pda: pdarray) -> pdarray:
>>> ak.floor(ak.linspace(1.1,5.5,5))
array([1, 2, 3, 4, 5])
"""
_datatype_check(pda.dtype, [float], 'floor')
repMsg = generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"floor<{pda.dtype},{pda.ndim}>",
args={
"func": "floor",
"array": pda,
"pda": pda,
},
)
return create_pdarray(type_cast(str, repMsg))
Expand Down Expand Up @@ -349,11 +342,11 @@ def round(pda: pdarray) -> pdarray:
>>> ak.round(ak.array([1.1, 2.5, 3.14159]))
array([1, 3, 3])
"""
_datatype_check(pda.dtype, [float], 'round')
repMsg = generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"round<{pda.dtype},{pda.ndim}>",
args={
"func": "round",
"array": pda,
"pda": pda,
},
)
return create_pdarray(type_cast(str, repMsg))
Expand Down Expand Up @@ -383,10 +376,10 @@ def trunc(pda: pdarray) -> pdarray:
>>> ak.trunc(ak.array([1.1, 2.5, 3.14159]))
array([1, 2, 3])
"""
_datatype_check(pda.dtype, [float], 'trunc')
repMsg = generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"trunc<{pda.dtype},{pda.ndim}>",
args={
"func": "trunc",
"array": pda,
},
)
Expand Down Expand Up @@ -1362,7 +1355,7 @@ def rad2deg(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray:
elif where is False:
return pda
else:
return _merge_where(pda[:], where, 180 * (pda[where] / np.pi))
return _merge_where(pda[:], where, 180*(pda[where]/np.pi))


@typechecked
Expand Down Expand Up @@ -1394,7 +1387,7 @@ def deg2rad(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray:
elif where is False:
return pda
else:
return _merge_where(pda[:], where, (np.pi * pda[where] / 180))
return _merge_where(pda[:], where, (np.pi*pda[where]/180))


def _hash_helper(a):
Expand Down Expand Up @@ -1521,13 +1514,14 @@ def hash(
def _hash_single(pda: pdarray, full: bool = True):
if pda.dtype == bigint:
return hash(pda.bigint_to_uint_arrays())
_datatype_check(pda.dtype, [float, int, ak_uint64], 'hash')
hname = "hash128" if full else "hash64"
repMsg = type_cast(
str,
generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"{hname}<{pda.dtype},{pda.ndim}>",
args={
"func": "hash128" if full else "hash64",
"array": pda,
"x": pda,
},
),
)
Expand Down Expand Up @@ -2588,19 +2582,18 @@ def matmul(pdaLeft: pdarray, pdaRight: pdarray):
"""
if pdaLeft.ndim != pdaRight.ndim:
raise ValueError("matmul requires matrices of matching rank.")

cmd = f"matmul<{pdaLeft.dtype},{pdaRight.dtype},{pdaLeft.ndim}>"
args = {
"x1": pdaLeft,
"x2": pdaRight,
}
repMsg = generic_msg(
cmd=cmd,
args=args,
return create_pdarray(
generic_msg(
cmd=cmd,
args=args,
)
)

return create_pdarray(repMsg)


def vecdot(x1: pdarray, x2: pdarray):
"""
Expand Down
128 changes: 53 additions & 75 deletions src/EfuncMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,59 @@ module EfuncMsg
@arkouda.registerCommand(name="isfinite")
proc isfinite_ (pda : [?d] real) : [d] bool { return (isFinite(pda)) ; }

// Annotated for registration under the command name "floor".
// The `where` clause restricts this generic proc to arrays whose element
// type `t` is `real`. It applies `floor` to `pda` and returns the result as
// a `real` array declared over the same domain `d` as the input.
@arkouda.registerCommand (name="floor")
proc floor_ (pda : [?d] ?t) : [d] real throws
where (t==real) { return floor(pda); }

// Annotated for registration under the command name "ceil".
// The `where` clause restricts this generic proc to arrays whose element
// type `t` is `real`. It applies `ceil` to `pda` and returns the result as
// a `real` array declared over the same domain `d` as the input.
@arkouda.registerCommand (name="ceil")
proc ceil_ (pda : [?d] ?t) : [d] real throws
where (t==real) { return ceil(pda); }

// Annotated for registration under the command name "round".
// The `where` clause restricts this generic proc to arrays whose element
// type `t` is `real`. It applies `round` to `pda` and returns the result as
// a `real` array declared over the same domain `d` as the input.
@arkouda.registerCommand (name="round")
proc round_ (pda : [?d] ?t) : [d] real throws
where (t==real) { return round(pda); }

// Annotated for registration under the command name "trunc".
// The `where` clause restricts this generic proc to arrays whose element
// type `t` is `real`. It applies `trunc` to `pda` and returns the result as
// a `real` array declared over the same domain `d` as the input.
@arkouda.registerCommand (name="trunc")
proc trunc_ (pda : [?d] ?t) : [d] real throws
where (t==real) { return trunc(pda); }

// Hashes are more of a challenge to unhook from the old interface, but they
// have been pulled out into their own functions.

// hash64: element-wise sipHash64 of the array named by message argument "x".
//
// The `where` clause limits instantiation to element types real, int and
// uint, and to rank 1.
//
// Steps:
//   1. Look up the symbol-table entry named by "x" and cast it to
//      SymEntry(array_dtype, array_nd).
//   2. Reserve a fresh symbol name and call overMemLimit with the input's
//      size in bytes.
//   3. Add a uint entry with the input's shape and fill each element with
//      sipHash64 of the corresponding input element, cast to uint.
//
// Returns a NORMAL MsgTuple whose message is "created " followed by the
// attributes of the new entry.
@arkouda.instantiateAndRegister
proc hash64 (cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int) : MsgTuple throws
  where ((array_dtype==real || array_dtype==int || array_dtype==uint) && array_nd==1) {
  // Input array; the former unused `efunc` lookup of the same argument was dropped.
  const e = st[msgArgs["x"]]: SymEntry(array_dtype,array_nd);
  const rname = st.nextName();
  overMemLimit(numBytes(array_dtype)*e.size);
  var a = st.addEntry(rname, e.tupShape, uint);
  forall (ai, x) in zip (a.a, e.a) {
    ai = sipHash64(x) : uint ;
  }
  const repMsg = "created " + st.attrib(rname);
  eLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
  return new MsgTuple(repMsg, MsgType.NORMAL);
}

// hash128: element-wise sipHash128 of the array named by message argument "x".
//
// The `where` clause limits instantiation to element types real, int and
// uint, and to rank 1.
//
// Each sipHash128 result is cast to a (uint, uint) tuple. The first component
// is stored in the entry named `rname2`, the second in the entry named
// `rname`. Note that `rname` is reserved before `rname2`, but the reply lists
// `rname2` first.
//
// Returns a NORMAL MsgTuple whose message is
//   "created <attrib of rname2>+created <attrib of rname>".
@arkouda.instantiateAndRegister
proc hash128 (cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int) : MsgTuple throws
  where ((array_dtype==real || array_dtype==int || array_dtype==uint) && array_nd==1) {
  // Input array; the former unused `efunc` lookup of the same argument was dropped.
  const e = st[msgArgs["x"]]: SymEntry(array_dtype,array_nd);
  // Keep the original reservation order: `rname` first, then `rname2`.
  const rname = st.nextName();
  const rname2 = st.nextName();
  // Two output arrays are created, hence the factor of 2.
  overMemLimit(numBytes(array_dtype) * e.size * 2);
  var a1 = st.addEntry(rname2, e.tupShape, uint);
  var a2 = st.addEntry(rname, e.tupShape, uint);
  forall (a1i, a2i, x) in zip(a1.a, a2.a, e.a) {
    (a1i, a2i) = sipHash128(x): (uint, uint);
  }
  // The two "created ..." segments are joined by "+", first array first.
  const repMsg = "created " + st.attrib(rname2) + "+" +
                 "created " + st.attrib(rname);
  eLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
  return new MsgTuple(repMsg, MsgType.NORMAL);
}

// End of rewrite section -- delete this comment after all of EfuncMsg is rewritten.

Expand All @@ -162,9 +215,6 @@ module EfuncMsg
ref ea = e.a;
select efunc
{
when "round" {
st.addEntry(rname, new shared SymEntry(ea));
}
when "sgn" {
st.addEntry(rname, new shared SymEntry(sgn(ea)));
}
Expand All @@ -190,25 +240,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "hash64" {
overMemLimit(numBytes(int) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
forall (ai, x) in zip(a.a, e.a) {
ai = sipHash64(x): uint;
}
}
when "hash128" {
overMemLimit(numBytes(int) * e.size * 2);
var rname2 = st.nextName();
var a1 = st.addEntry(rname2, e.tupShape, uint);
var a2 = st.addEntry(rname, e.tupShape, uint);
forall (a1i, a2i, x) in zip(a1.a, a2.a, e.a) {
(a1i, a2i) = sipHash128(x): (uint, uint);
}
// Put first array's attrib in repMsg and let common
// code append second array's attrib
repMsg += "created " + st.attrib(rname2) + "+";
}
when "popcount" {
st.addEntry(rname, new shared SymEntry(popCount(ea)));
}
Expand Down Expand Up @@ -236,18 +267,6 @@ module EfuncMsg
ref ea = e.a;
select efunc
{
when "ceil" {
st.addEntry(rname, new shared SymEntry(ceil(ea)));
}
when "floor" {
st.addEntry(rname, new shared SymEntry(floor(ea)));
}
when "round" {
st.addEntry(rname, new shared SymEntry(round(ea)));
}
when "trunc" {
st.addEntry(rname, new shared SymEntry(trunc(ea)));
}
when "sgn" {
st.addEntry(rname, new shared SymEntry(sgn(ea)));
}
Expand All @@ -273,25 +292,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "hash64" {
overMemLimit(numBytes(real) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
forall (ai, x) in zip(a.a, e.a) {
ai = sipHash64(x): uint;
}
}
when "hash128" {
overMemLimit(numBytes(real) * e.size * 2);
var rname2 = st.nextName();
var a1 = st.addEntry(rname2, e.tupShape, uint);
var a2 = st.addEntry(rname, e.tupShape, uint);
forall (a1i, a2i, x) in zip(a1.a, a2.a, e.a) {
(a1i, a2i) = sipHash128(x): (uint, uint);
}
// Put first array's attrib in repMsg and let common
// code append second array's attrib
repMsg += "created " + st.attrib(rname2) + "+";
}
otherwise {
var errorMsg = notImplementedError(pn,efunc,gEnt.dtype);
eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -353,9 +353,6 @@ module EfuncMsg
when "ctz" {
st.addEntry(rname, new shared SymEntry(ctz(ea)));
}
when "round" {
st.addEntry(rname, new shared SymEntry(ea));
}
when "sgn" {
st.addEntry(rname, new shared SymEntry(sgn(ea)));
}
Expand Down Expand Up @@ -384,25 +381,6 @@ module EfuncMsg
when "parity" {
st.addEntry(rname, new shared SymEntry(parity(ea)));
}
when "hash64" {
overMemLimit(numBytes(uint) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
forall (ai, x) in zip(a.a, e.a) {
ai = sipHash64(x): uint;
}
}
when "hash128" {
overMemLimit(numBytes(uint) * e.size * 2);
var rname2 = st.nextName();
var a1 = st.addEntry(rname2, e.tupShape, uint);
var a2 = st.addEntry(rname, e.tupShape, uint);
forall (a1i, a2i, x) in zip(a1.a, a2.a, e.a) {
(a1i, a2i) = sipHash128(x): (uint, uint);
}
// Put first array's attrib in repMsg and let common
// code append second array's attrib
repMsg += "created " + st.attrib(rname2) + "+";
}
when "not" {
st.addEntry(rname, new shared SymEntry(!e.a));
}
Expand Down
Loading
Loading