Skip to content

Commit

Permalink
Closes Bears-R-Us#3782: flip function to match numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Oct 3, 2024
1 parent a44dd0f commit e071811
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 1 deletion.
2 changes: 1 addition & 1 deletion arkouda/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Categorical:
BinOps = frozenset(["==", "!="])
RegisterablePieces = frozenset(["categories", "codes", "permutation", "segments", "_akNAcode"])
RequiredPieces = frozenset(["categories", "codes", "_akNAcode"])
permutation = None
permutation: Union[pdarray, None] = None
segments = None
objType = "Categorical"
dtype = akdtype(str_) # this is being set for now because Categoricals only supported on Strings
Expand Down
1 change: 1 addition & 0 deletions arkouda/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,4 @@
from arkouda.numpy.rec import *

from ._numeric import *
from ._manipulation_functions import *
88 changes: 88 additions & 0 deletions arkouda/numpy/_manipulation_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# from __future__ import annotations

from typing import Optional
from typing import Tuple
from typing import Union
from typing import cast

from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray
from arkouda.pdarrayclass import pdarray
from arkouda.strings import Strings
from arkouda.categorical import Categorical


# Names exported by a star-import of this module; only `flip` is public.
__all__ = ["flip"]


def flip(
    x: Union[pdarray, Strings, Categorical], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None
) -> Union[pdarray, Strings, Categorical]:
    """
    Reverse an array's values along a particular axis or axes.

    Parameters
    ----------
    x : pdarray, Strings, or Categorical
        The array to reverse. The shape of the array is preserved, but the
        elements are reordered.
    axis : int or Tuple[int, ...], optional
        The axis or axes along which to flip the array. If None, flip the array along all axes.
        Only consulted for pdarray input; Strings and Categorical inputs are
        reversed end to end regardless of this argument.

    Returns
    -------
    pdarray, Strings, or Categorical
        An array with the entries of axis reversed.

    Raises
    ------
    IndexError
        If the server request for a pdarray flip fails with a RuntimeError.
    TypeError
        If x is not a pdarray, Strings, or Categorical.

    Note
    ----
    This differs from numpy as it actually reverses the data, rather than presenting a view.
    """
    if isinstance(x, pdarray):
        # Normalise `axis` into a list so the server always receives a sequence.
        axisList = []
        if axis is not None:
            axisList = list(axis) if isinstance(axis, tuple) else [axis]

        # The server command is specialised on dtype and rank; "flipAll"
        # is used when no axis was requested.
        cmd = f"flipAll<{x.dtype},{x.ndim}>" if axis is None else f"flip<{x.dtype},{x.ndim}>"
        try:
            return create_pdarray(
                cast(
                    str,
                    generic_msg(
                        cmd=cmd,
                        args={
                            "name": x,
                            "nAxes": len(axisList),
                            "axis": axisList,
                        },
                    ),
                )
            )
        except RuntimeError as e:
            # Chain the original error so the server-side cause stays visible.
            raise IndexError(f"Failed to flip array: {e}") from e
    elif isinstance(x, Categorical):
        permutation = None
        if isinstance(x.permutation, pdarray):
            # The cached permutation stores positions into `codes`. After the
            # codes are reversed, the element formerly at position p sits at
            # size - 1 - p, so every stored position is mirrored. Reversing
            # the order of the permutation instead would leave it pointing at
            # the wrong elements and no longer match `segments`.
            # NOTE(review): assumes permutation is the grouping argsort of
            # codes, with segments delimiting the groups - confirm.
            permutation = (x.codes.size - 1) - x.permutation
        return Categorical.from_codes(
            codes=flip(x.codes),
            categories=x.categories,
            permutation=permutation,
            segments=x.segments,
        )
    elif isinstance(x, Strings):
        rep_msg = generic_msg(
            cmd="flipString", args={"objType": x.objType, "obj": x.entry, "size": x.size}
        )
        return Strings.from_return_msg(cast(str, rep_msg))
    else:
        raise TypeError("flip only accepts type pdarray, Strings, or Categorical.")
33 changes: 33 additions & 0 deletions src/SegmentedMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module SegmentedMsg {
use Map;
use CTypes;
use IOUtils;
use CommAggregation;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
Expand Down Expand Up @@ -1186,6 +1187,37 @@ module SegmentedMsg {
return new MsgTuple(repMsg, MsgType.NORMAL);
}

// Server handler for the "flipString" command: builds a new SegString whose
// entries are those of the SegString named by the "obj" argument, in reverse
// order. The bytes inside each individual string are left in their original
// order. Replies with the usual "created ... +created bytes.size N" message.
proc flipStringMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  const name = msgArgs.getValueOf("obj");

  // check to make sure symbols defined
  st.checkTable(name);
  var strings = getSegString(name, st);
  ref origVals = strings.values.a;
  ref offs = strings.offsets.a;
  const lengths = strings.getLengths();
  // Index of the last string; used to mirror positions (i -> lastIdx - i).
  const lastIdx = lengths.domain.high;

  // Declared `var` (not `ref`): this array is freshly allocated here, owned
  // by this procedure, reassigned below and then handed to getSegString.
  var retOffs = makeDistArray(lengths.domain, int);
  // Scatter each string's length to its mirrored slot.
  forall i in lengths.domain with (var lenAgg = newDstAggregator(int)) {
    lenAgg.copy(retOffs[lastIdx - i], lengths[i]);
  }
  // Inclusive scan minus the element itself: turns the reversed lengths into
  // the starting offset of each string in the output buffer.
  retOffs = (+ scan retOffs) - retOffs;

  var flippedVals = makeDistArray(strings.values.a.domain, uint(8));
  forall (off, len, j) in zip(offs, lengths, ..#offs.size) with (var valAgg = newDstAggregator(uint(8))) {
    // Copy string j byte-by-byte to the start offset of its mirrored slot.
    var i = 0;
    for b in interpretAsBytes(origVals, off..#len, borrow=true) {
      valAgg.copy(flippedVals[retOffs[lastIdx - j] + i], b:uint(8));
      i += 1;
    }
  }

  var retString = getSegString(retOffs, flippedVals, st);
  var repMsg = "created " + st.attrib(retString.name) + "+created bytes.size %?".format(retString.nBytes);
  smLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
  return new MsgTuple(repMsg, MsgType.NORMAL);
}

use CommandMap;
registerFunction("segmentLengths", segmentLengthsMsg, getModuleName());
registerFunction("caseChange", caseChangeMsg, getModuleName());
Expand All @@ -1210,4 +1242,5 @@ module SegmentedMsg {
registerFunction("segmentedWhere", segmentedWhereMsg, getModuleName());
registerFunction("segmentedFull", segmentedFullMsg, getModuleName());
registerFunction("getSegStringProperty", getSegStringPropertyMsg, getModuleName());
registerFunction("flipString", flipStringMsg, getModuleName());
}
48 changes: 48 additions & 0 deletions tests/manipulation_functions_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
import pytest

import arkouda as ak
from arkouda.categorical import Categorical
from arkouda.testing import assert_equal

# Seed passed to the random string generators in the tests below.
# NOTE(review): `pytest.seed` is presumably attached by the project's test
# configuration - confirm where it is set.
seed = pytest.seed


class TestJoin:
    """Checks that ak.flip reverses pdarray, Strings and Categorical inputs."""

    @pytest.mark.parametrize("size", pytest.prob_size)
    @pytest.mark.parametrize("dtype", [int, ak.int64, ak.uint64, float, ak.float64, bool, ak.bool_])
    def test_flip_pdarray(self, size, dtype):
        """A flipped 1-D array must equal the same array sliced with [::-1]."""
        original = ak.arange(size, dtype=dtype)
        flipped = ak.flip(original)
        expected = original[::-1]
        assert_equal(flipped, expected)

    @pytest.mark.skip_if_max_rank_less_than(3)
    @pytest.mark.parametrize("size", pytest.prob_size)
    @pytest.mark.parametrize("dtype", [ak.int64, ak.uint64, ak.float64])
    def test_flip_multi_dim(self, size, dtype):
        """Flipping every axis of an arange cube maps value v to (n - 1) - v."""
        total = size * 4
        cube = ak.arange(total, dtype=dtype).reshape((2, 2, size))
        flipped = ak.flip(cube)
        expected = (total - 1) - cube
        assert_equal(flipped, expected)

    @pytest.mark.skip_if_max_rank_less_than(3)
    @pytest.mark.parametrize("size", pytest.prob_size)
    def test_flip_multi_dim_bool(self, size):
        """Same check as the multi-dim test, for boolean data."""
        total = size * 4
        cube = ak.arange(total, dtype=bool).reshape((2, 2, size))
        flipped = ak.flip(cube)
        expected = ak.cast((total - 1) - cube, dt=ak.bool_)
        assert_equal(flipped, expected)

    @pytest.mark.parametrize("size", pytest.prob_size)
    def test_flip_string(self, size):
        """A flipped Strings object must equal its [::-1] slice."""
        strings = ak.random_strings_uniform(1, 2, size, seed=seed)
        flipped = ak.flip(strings)
        assert_equal(flipped, strings[::-1])

    @pytest.mark.parametrize("size", pytest.prob_size)
    def test_flip_categorical(self, size):
        """Flip must match [::-1] for Categoricals built two different ways."""
        strings = ak.random_strings_uniform(1, 2, size, seed=seed)
        from_strings = Categorical(strings)
        assert_equal(ak.flip(from_strings), from_strings[::-1])

        # test case when c.permutation = None
        from_pandas = Categorical(from_strings.to_pandas())
        assert_equal(ak.flip(from_pandas), from_pandas[::-1])

0 comments on commit e071811

Please sign in to comment.