From a3619027653f45fce4e4e3776546608580e83edc Mon Sep 17 00:00:00 2001 From: Amanda Potts Date: Fri, 15 Nov 2024 15:12:17 -0500 Subject: [PATCH] Closes #3300: shape function --- arkouda/numpy/__init__.py | 1 + arkouda/numpy/_utils.py | 51 ++++++++++++++++++++++++++++++++++ arkouda/numpy/dtypes/dtypes.py | 5 ++++ tests/numpy/utils_test.py | 23 +++++++++++++++ 4 files changed, 80 insertions(+) create mode 100644 arkouda/numpy/_utils.py create mode 100644 tests/numpy/utils_test.py diff --git a/arkouda/numpy/__init__.py b/arkouda/numpy/__init__.py index 5cb6627da7..3438745160 100644 --- a/arkouda/numpy/__init__.py +++ b/arkouda/numpy/__init__.py @@ -96,4 +96,5 @@ from arkouda.numpy.rec import * from ._numeric import * +from ._utils import * from ._manipulation_functions import * diff --git a/arkouda/numpy/_utils.py b/arkouda/numpy/_utils.py new file mode 100644 index 0000000000..07a9cc0b27 --- /dev/null +++ b/arkouda/numpy/_utils.py @@ -0,0 +1,51 @@ +from typing import Iterable, Tuple, Union + +from numpy import ndarray + +from arkouda.numpy.dtypes import all_scalars, isSupportedDType +from arkouda.pdarrayclass import pdarray +from arkouda.strings import Strings + +__all__ = ["shape"] + + +def shape(a: Union[pdarray, Strings, all_scalars]) -> Tuple: + """ + Return the shape of an array. + + Parameters + ---------- + a : pdarray + Input array. + + Returns + ------- + shape : tuple of ints + The elements of the shape tuple give the lengths of the + corresponding array dimensions. + + Examples + -------- + >>> import arkouda as ak + >>> ak.shape(ak.eye(3,2)) + (3, 2) + >>> ak.shape([[1, 3]]) + (1, 2) + >>> ak.shape([0]) + (1,) + >>> ak.shape(0) + () + + """ + if isinstance(a, (pdarray, Strings, ndarray, Iterable)) and not isinstance(a, str): + try: + result = a.shape + except AttributeError: + from arkouda import array + + result = array(a).shape + return result + elif isSupportedDType(a): + return () + else: + raise TypeError("shape requires type pdarray, ndarray, Iterable, or numeric scalar.") diff --git a/arkouda/numpy/dtypes/dtypes.py b/arkouda/numpy/dtypes/dtypes.py index ef4c1c2819..02609c2232 100644 --- a/arkouda/numpy/dtypes/dtypes.py +++ b/arkouda/numpy/dtypes/dtypes.py @@ -57,6 +57,7 @@ "intTypes", "int_scalars", "isSupportedBool", + "isSupportedDType", "isSupportedFloat", "isSupportedInt", "isSupportedNumber", @@ -320,6 +321,10 @@ def isSupportedBool(num): return isinstance(num, ARKOUDA_SUPPORTED_BOOLS) +def isSupportedDType(scalar): + return isinstance(scalar, ARKOUDA_SUPPORTED_DTYPES) + + def resolve_scalar_dtype(val: object) -> str: """ Try to infer what dtype arkouda_server should treat val as. diff --git a/tests/numpy/utils_test.py b/tests/numpy/utils_test.py new file mode 100644 index 0000000000..b10599022e --- /dev/null +++ b/tests/numpy/utils_test.py @@ -0,0 +1,23 @@ +import pytest + +import arkouda as ak +import numpy as np + + +class TestFromNumericFunctions: + + @pytest.mark.parametrize("x", [0, [0], [1, 2, 3], np.ndarray([0, 1, 2]), [[1, 3]], np.eye(3, 2)]) + def test_shape(self, x): + assert ak.shape(x) == np.shape(x) + + def test_shape_pdarray(self): + a = ak.arange(5) + assert ak.shape(a) == np.shape(a.to_ndarray()) + + def test_shape_strings(self): + a = ak.array(["a", "b", "c"]) + assert ak.shape(a) == np.shape(a.to_ndarray()) + + @pytest.mark.skip_if_max_rank_less_than(2) + def test_shape_multidim_pdarray(self): + assert ak.shape(ak.eye(3, 2)) == (3, 2)