Skip to content

Commit

Permalink
Closes Bears-R-Us#3782: flip function to match numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Sep 26, 2024
1 parent 8dae0c5 commit 17fdd76
Show file tree
Hide file tree
Showing 5 changed files with 7,024 additions and 56 deletions.
1 change: 1 addition & 0 deletions arkouda/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,4 @@
from arkouda.numpy.rec import *

from ._numeric import *
from ._manipulation_functions import *
80 changes: 80 additions & 0 deletions arkouda/numpy/_manipulation_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# from __future__ import annotations

from typing import Optional
from typing import Tuple
from typing import Union
from typing import cast

from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray
from arkouda.pdarrayclass import pdarray
from arkouda.strings import Strings
from arkouda.categorical import Categorical


__all__ = ["flip"]


def flip(
x: Union[pdarray, Strings, Categorical], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None
) -> Union[pdarray, Strings, Categorical]:
"""
Reverse an array's values along a particular axis or axes.
Parameters
----------
x : pdarray, Strings, or Categorical
Reverse the order of elements in an array along the given axis.
The shape of the array is preserved, but the elements are reordered.
axis : int or Tuple[int, ...], optional
The axis or axes along which to flip the array. If None, flip the array along all axes.
Returns
-------
pdarray, Strings, or Categorical
An array with the entries of axis reversed.
Note
----
This differs from numpy as it actually reverses the data, rather than presenting a view.
"""
axisList = []
if axis is not None:
axisList = list(axis) if isinstance(axis, tuple) else [axis]

if isinstance(x, pdarray):
try:
return create_pdarray(
cast(
str,
generic_msg(
cmd=(
f"flipAll<{x.dtype},{x.ndim}>"
if axis is None
else f"flip<{x.dtype},{x.ndim}>"
),
args={
"name": x,
"nAxes": len(axisList),
"axis": axisList,
},
),
)
)

except RuntimeError as e:
raise IndexError(f"Failed to flip array: {e}")
elif isinstance(x, Categorical):
return Categorical.from_codes(
codes=flip(x.codes),
categories=x.categories,
permutation=flip(x.permutation),
segments=x.segments,
)
elif isinstance(x, Strings):
# return x[::-1]
rep_msg = generic_msg(
cmd="flipString", args={"objType": x.objType, "obj": x.entry, "size": x.size}
)
return Strings.from_return_msg(cast(str, rep_msg))
else:
raise TypeError("flip only accepts type pdarray, Strings, or Categorical.")
49 changes: 49 additions & 0 deletions src/SegmentedMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module SegmentedMsg {
use Map;
use CTypes;
use IOUtils;
use CommAggregation;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
Expand Down Expand Up @@ -1186,6 +1187,53 @@ module SegmentedMsg {
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc flipStringMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
var pn = Reflection.getRoutineName();
var repMsg: string;
const objtype = msgArgs.getValueOf("objType").toUpper(): ObjType;
const name = msgArgs.getValueOf("obj");
// const size = msgArgs['size'].toScalar(int);

// check to make sure symbols defined
st.checkTable(name);
var strings = getSegString(name, st);

var flippedVals: [strings.values.a.domain] uint(8) = makeDistArray(strings.values.a.domain, uint(8));
const lengths = strings.getLengths();

ref origVals = strings.values.a;
ref offs = strings.offsets.a;
// const size = origVals.size;

// ref lengths_high = lengths.domain.high;

ref retOffs = makeDistArray(lengths.domain, int);
forall i in lengths.domain with (var valAgg = newDstAggregator(int)) {
valAgg.copy(retOffs[lengths.domain.high - i], lengths[i]);
}
retOffs = (+ scan retOffs) - retOffs;

ref retOffsReversed = makeDistArray(lengths.domain, int);
forall i in lengths.domain with (var valAgg = newDstAggregator(int)) {
valAgg.copy(retOffsReversed[lengths.domain.high - i], retOffs[i]);
}

forall (off, off2, len) in zip(offs, retOffsReversed, lengths) with (var valAgg = newSrcAggregator(uint(8))) {
var i = 0;
for b in interpretAsBytes(origVals, off..#len, borrow=true) {
valAgg.copy(flippedVals[off2 + i], b:uint(8));
i += 1;
}
}

var retString = getSegString(retOffs, flippedVals, st);
repMsg = "created " + st.attrib(retString.name) + "+created bytes.size %?".format(retString.nBytes);

smLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}


use CommandMap;
registerFunction("segmentLengths", segmentLengthsMsg, getModuleName());
registerFunction("caseChange", caseChangeMsg, getModuleName());
Expand All @@ -1210,4 +1258,5 @@ module SegmentedMsg {
registerFunction("segmentedWhere", segmentedWhereMsg, getModuleName());
registerFunction("segmentedFull", segmentedFullMsg, getModuleName());
registerFunction("getSegStringProperty", getSegStringPropertyMsg, getModuleName());
registerFunction("flipString", flipStringMsg, getModuleName());
}
Loading

0 comments on commit 17fdd76

Please sign in to comment.