Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes #3855: refactor boolReductionMsg #3876

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ServerModules.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ OperatorMsg
ParquetMsg
RandMsg
ReductionMsg
ReductionMsgFunctions
RegistrationMsg
SegmentedMsg
SequenceMsg
Expand Down
24 changes: 4 additions & 20 deletions arkouda/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ def max(

from arkouda import max as ak_max

arr = Array._new(ak_max(x._array, axis=axis))
if keepdims or axis is None or x.ndim == 1:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak_max(x._array, axis=axis, keepdims=keepdims))


# this is a temporary fix to get mean working with XArray
Expand Down Expand Up @@ -138,11 +134,7 @@ def min(

from arkouda import min as ak_min

arr = Array._new(ak_min(x._array, axis=axis))
if keepdims or axis is None or x.ndim == 1:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak_min(x._array, axis=axis, keepdims=keepdims))


def prod(
Expand Down Expand Up @@ -180,11 +172,7 @@ def prod(

from arkouda import prod as ak_prod

arr = Array._new(ak_prod(x_op, axis=axis))
if keepdims or axis is None or x.ndim == 1:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak_prod(x_op, axis=axis, keepdims=keepdims))


# Not working with XArray yet, pending a fix for:
Expand Down Expand Up @@ -277,11 +265,7 @@ def sum(

from arkouda import sum as ak_sum

arr = Array._new(ak_sum(x_op, axis=axis))
if keepdims or axis is None or x.ndim == 1:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak_sum(x_op, axis=axis, keepdims=keepdims))


# Not working with XArray yet, pending a fix for:
Expand Down
11 changes: 7 additions & 4 deletions arkouda/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from arkouda.numpy import cast as akcast
from arkouda.numpy import where
from arkouda.numpy.dtypes import bool_ as akbool
from arkouda.numpy.dtypes import bool_scalars
from arkouda.numpy.dtypes import dtype as akdtype
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.dtypes import int_scalars, resolve_scalar_dtype, str_, str_scalars
Expand Down Expand Up @@ -290,7 +291,7 @@ def standardize_categories(cls, arrays, NAvalue="N/A"):
new_categories = concatenate((new_categories, array([NAvalue])))
return [arr.set_categories(new_categories, NAvalue=NAvalue) for arr in arrays]

def equals(self, other) -> bool:
def equals(self, other) -> bool_scalars:
"""
Whether Categoricals are the same size and all entries are equal.

Expand Down Expand Up @@ -320,9 +321,11 @@ def equals(self, other) -> bool:
if other.size != self.size:
return False
else:
return akall(self == other)
else:
return False
result = akall(self == other)
if isinstance(result, (bool, np.bool_)):
return result

return False

def set_categories(self, new_categories, NAvalue=None):
"""
Expand Down
9 changes: 7 additions & 2 deletions arkouda/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from numpy import array as ndarray
from numpy import dtype as npdtype
Expand All @@ -13,6 +14,7 @@
from arkouda.groupbyclass import GroupBy, unique
from arkouda.numpy import cast as akcast
from arkouda.numpy.dtypes import bool_ as akbool
from arkouda.numpy.dtypes import bool_scalars
from arkouda.numpy.dtypes import float64 as akfloat64
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.pdarrayclass import RegistrationError, pdarray
Expand Down Expand Up @@ -290,7 +292,7 @@ def from_return_msg(cls, rep_msg):

return cls.factory(idx) if len(idx) > 1 else cls.factory(idx[0])

def equals(self, other: Index) -> bool:
def equals(self, other: Index) -> bool_scalars:
"""
Whether Indexes are the same size, and all entries are equal.

Expand Down Expand Up @@ -351,7 +353,10 @@ def equals(self, other: Index) -> bool:

return True
else:
return akall(self == other)
result = akall(self == other)
if isinstance(result, (bool, np.bool_)):
return result
return False

def memory_usage(self, unit="B"):
"""
Expand Down
2 changes: 1 addition & 1 deletion arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typeguard import typechecked

from arkouda.client import generic_msg
from arkouda.dtypes import str_ as akstr_
from arkouda.numpy.dtypes import str_ as akstr_
from arkouda.groupbyclass import GroupBy, groupable
from arkouda.numpy.dtypes import DTypes, bigint
from arkouda.numpy.dtypes import bool_ as ak_bool
Expand Down
7 changes: 7 additions & 0 deletions arkouda/numpy/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"int8",
"intTypes",
"int_scalars",
"isSupportedBool",
"isSupportedFloat",
"isSupportedInt",
"isSupportedNumber",
Expand Down Expand Up @@ -245,6 +246,8 @@ def __repr__(self) -> str:
return self.value


ARKOUDA_SUPPORTED_BOOLS = (bool, np.bool_)

ARKOUDA_SUPPORTED_INTS = (
int,
np.int8,
Expand Down Expand Up @@ -313,6 +316,10 @@ def isSupportedNumber(num):
return isinstance(num, ARKOUDA_SUPPORTED_NUMBERS)


def isSupportedBool(num):
return isinstance(num, ARKOUDA_SUPPORTED_BOOLS)


def resolve_scalar_dtype(val: object) -> str:
"""
Try to infer what dtype arkouda_server should treat val as.
Expand Down
Loading
Loading