Skip to content

Commit

Permalink
Closes #3855: refactor boolReductionMsg
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Nov 4, 2024
1 parent 2bb26b1 commit d0b8307
Show file tree
Hide file tree
Showing 26 changed files with 654 additions and 457 deletions.
1 change: 1 addition & 0 deletions ServerModules.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ OperatorMsg
ParquetMsg
RandMsg
ReductionMsg
ReductionMsgFunctions
RegistrationMsg
SegmentedMsg
SequenceMsg
Expand Down
24 changes: 4 additions & 20 deletions arkouda/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ def max(

from arkouda import max as ak_max

arr = Array._new(ak_max(x._array, axis=axis))
if keepdims or axis is None or x.ndim == 1:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak_max(x._array, axis=axis, keepdims=keepdims))


# this is a temporary fix to get mean working with XArray
Expand Down Expand Up @@ -138,11 +134,7 @@ def min(

from arkouda import min as ak_min

arr = Array._new(ak_min(x._array, axis=axis))
if keepdims or axis is None or x.ndim == 1:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak_min(x._array, axis=axis, keepdims=keepdims))


def prod(
Expand Down Expand Up @@ -180,11 +172,7 @@ def prod(

from arkouda import prod as ak_prod

arr = Array._new(ak_prod(x_op, axis=axis))
if keepdims or axis is None or x.ndim == 1:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak_prod(x_op, axis=axis, keepdims=keepdims))


# Not working with XArray yet, pending a fix for:
Expand Down Expand Up @@ -277,11 +265,7 @@ def sum(

from arkouda import sum as ak_sum

arr = Array._new(ak_sum(x_op, axis=axis))
if keepdims or axis is None or x.ndim == 1:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak_sum(x_op, axis=axis, keepdims=keepdims))


# Not working with XArray yet, pending a fix for:
Expand Down
11 changes: 7 additions & 4 deletions arkouda/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from arkouda.numpy import cast as akcast
from arkouda.numpy import where
from arkouda.numpy.dtypes import bool_ as akbool
from arkouda.numpy.dtypes import bool_scalars
from arkouda.numpy.dtypes import dtype as akdtype
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.dtypes import int_scalars, resolve_scalar_dtype, str_, str_scalars
Expand Down Expand Up @@ -290,7 +291,7 @@ def standardize_categories(cls, arrays, NAvalue="N/A"):
new_categories = concatenate((new_categories, array([NAvalue])))
return [arr.set_categories(new_categories, NAvalue=NAvalue) for arr in arrays]

def equals(self, other) -> bool:
def equals(self, other) -> bool_scalars:
"""
Whether Categoricals are the same size and all entries are equal.
Expand Down Expand Up @@ -320,9 +321,11 @@ def equals(self, other) -> bool:
if other.size != self.size:
return False
else:
return akall(self == other)
else:
return False
result = akall(self == other)
if isinstance(result, (bool, np.bool_)):
return result

return False

def set_categories(self, new_categories, NAvalue=None):
"""
Expand Down
9 changes: 7 additions & 2 deletions arkouda/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from numpy import array as ndarray
from numpy import dtype as npdtype
Expand All @@ -13,6 +14,7 @@
from arkouda.groupbyclass import GroupBy, unique
from arkouda.numpy import cast as akcast
from arkouda.numpy.dtypes import bool_ as akbool
from arkouda.numpy.dtypes import bool_scalars
from arkouda.numpy.dtypes import float64 as akfloat64
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.pdarrayclass import RegistrationError, pdarray
Expand Down Expand Up @@ -290,7 +292,7 @@ def from_return_msg(cls, rep_msg):

return cls.factory(idx) if len(idx) > 1 else cls.factory(idx[0])

def equals(self, other: Index) -> bool:
def equals(self, other: Index) -> bool_scalars:
"""
Whether Indexes are the same size, and all entries are equal.
Expand Down Expand Up @@ -351,7 +353,10 @@ def equals(self, other: Index) -> bool:

return True
else:
return akall(self == other)
result = akall(self == other)
if isinstance(result, (bool, np.bool_)):
return result
return False

def memory_usage(self, unit="B"):
"""
Expand Down
2 changes: 1 addition & 1 deletion arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typeguard import typechecked

from arkouda.client import generic_msg
from arkouda.dtypes import str_ as akstr_
from arkouda.numpy.dtypes import str_ as akstr_
from arkouda.groupbyclass import GroupBy, groupable
from arkouda.numpy.dtypes import DTypes, bigint
from arkouda.numpy.dtypes import bool_ as ak_bool
Expand Down
7 changes: 7 additions & 0 deletions arkouda/numpy/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"int8",
"intTypes",
"int_scalars",
"isSupportedBool",
"isSupportedFloat",
"isSupportedInt",
"isSupportedNumber",
Expand Down Expand Up @@ -245,6 +246,8 @@ def __repr__(self) -> str:
return self.value


ARKOUDA_SUPPORTED_BOOLS = (bool, np.bool_)

ARKOUDA_SUPPORTED_INTS = (
int,
np.int8,
Expand Down Expand Up @@ -313,6 +316,10 @@ def isSupportedNumber(num):
return isinstance(num, ARKOUDA_SUPPORTED_NUMBERS)


def isSupportedBool(num):
return isinstance(num, ARKOUDA_SUPPORTED_BOOLS)


def resolve_scalar_dtype(val: object) -> str:
"""
Try to infer what dtype arkouda_server should treat val as.
Expand Down
Loading

0 comments on commit d0b8307

Please sign in to comment.