diff --git a/arkouda/__init__.py b/arkouda/__init__.py index a6659bf99e..c87dbc1581 100644 --- a/arkouda/__init__.py +++ b/arkouda/__init__.py @@ -9,6 +9,7 @@ from arkouda.client import * from arkouda.client_dtypes import * from arkouda.dtypes import * +from arkouda.decorators import * from arkouda.pdarrayclass import * from arkouda.sorting import * from arkouda.pdarraysetops import * diff --git a/arkouda/categorical.py b/arkouda/categorical.py index ed90ee7860..4d9f799450 100644 --- a/arkouda/categorical.py +++ b/arkouda/categorical.py @@ -16,10 +16,11 @@ import numpy as np # type: ignore from typeguard import typechecked +from arkouda.decorators import objtypedec from arkouda.dtypes import bool as akbool from arkouda.dtypes import int64 as akint64 -from arkouda.dtypes import int_scalars, resolve_scalar_dtype, str_scalars +from arkouda.dtypes import int_scalars, resolve_scalar_dtype, str_scalars, str_ from arkouda.groupbyclass import GroupBy, unique from arkouda.infoclass import information, list_registry from arkouda.logger import getArkoudaLogger @@ -39,6 +40,7 @@ __all__ = ["Categorical"] +@objtypedec class Categorical: """ Represents an array of values belonging to named categories. Converting a @@ -77,7 +79,6 @@ class Categorical: BinOps = frozenset(["==", "!="]) RegisterablePieces = frozenset(["categories", "codes", "permutation", "segments", "_akNAcode"]) RequiredPieces = frozenset(["categories", "codes", "_akNAcode"]) - objtype = "category" permutation = None segments = None @@ -137,8 +138,13 @@ def __init__(self, values, **kwargs) -> None: self.nlevels = self.categories.size self.ndim = self.codes.ndim self.shape = self.codes.shape + self.dtype = str_ self.name: Optional[str] = None + @property + def objtype(self): + return self.objtype + @classmethod @typechecked def from_codes( diff --git a/arkouda/decorators.py b/arkouda/decorators.py new file mode 100644 index 0000000000..08a9cfec12 --- /dev/null +++ b/arkouda/decorators.py @@ -0,0 +1,3 @@ +def objtypedec(orig_cls): + orig_cls.objtype = orig_cls.__name__ + return orig_cls diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index fd8eca09d9..7b3101a6b4 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -170,7 +170,7 @@ module AryUtil thisSize = g.size; hasStr = true; } - when "category" { + when "Categorical" { // passed only Categorical.codes.name to be sorted on var g = getGenericTypedArrayEntry(name, st); thisSize = g.size; diff --git a/src/UniqueMsg.chpl b/src/UniqueMsg.chpl index e1110ff21f..fa85c03c27 100644 --- a/src/UniqueMsg.chpl +++ b/src/UniqueMsg.chpl @@ -118,7 +118,7 @@ module UniqueMsg for (name, objtype, i) in zip(names, types, 0..) { var newName = st.nextName(); select objtype { - when "pdarray", "category" { + when "pdarray", "Categorical" { var g = getGenericTypedArrayEntry(name, st); // Gathers unique values, stores in SymTab, and returns repMsg chunk proc gatherHelper(type t) throws { @@ -239,7 +239,7 @@ module UniqueMsg } for (name, objtype, i) in zip(names, types, 0..) { select objtype { - when "pdarray", "category" { + when "pdarray", "Categorical" { var g = getGenericTypedArrayEntry(name, st); select g.dtype { when DType.Int64 { diff --git a/tests/categorical_test.py b/tests/categorical_test.py index 1a73b85170..5a83834ce9 100644 --- a/tests/categorical_test.py +++ b/tests/categorical_test.py @@ -64,7 +64,7 @@ def testBaseCategorical(self): ).all() ) self.assertEqual(10, cat.size) - self.assertEqual("category", cat.objtype) + self.assertEqual("Categorical", cat.objtype) with self.assertRaises(ValueError): ak.Categorical(ak.arange(0, 5, 10)) @@ -229,7 +229,7 @@ def testConcatenate(self): catTwo = self._getCategorical("string-two", 51) resultCat = catOne.concatenate([catTwo]) - self.assertEqual("category", resultCat.objtype) + self.assertEqual("Categorical", resultCat.objtype) self.assertIsInstance(resultCat, ak.Categorical) self.assertEqual(100, resultCat.size) @@ -239,7 +239,7 @@ def testConcatenate(self): self.assertFalse(resultCat.segments) resultCat = ak.concatenate([catOne, catOne], ordered=False) - self.assertEqual("category", resultCat.objtype) + self.assertEqual("Categorical", resultCat.objtype) self.assertIsInstance(resultCat, ak.Categorical) self.assertEqual(100, resultCat.size)