Skip to content

Commit

Permalink
Array API Set functions (Bears-R-Us#3070)
Browse files Browse the repository at this point in the history
* implement array api set functions

Signed-off-by: Jeremiah Corrado <[email protected]>

* add SetMsg to Array API server config file

Signed-off-by: Jeremiah Corrado <[email protected]>

* improve unflatten helper in setMsg

Signed-off-by: Jeremiah Corrado <[email protected]>

---------

Signed-off-by: Jeremiah Corrado <[email protected]>
  • Loading branch information
jeremiah-corrado authored Apr 3, 2024
1 parent a84a35f commit 180b7b8
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 36 deletions.
20 changes: 2 additions & 18 deletions ServerModulesArrayApi.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,21 @@
# a module from a build

ArgSortMsg
ArraySetopsMsg
BroadcastMsg
CastMsg
ConcatenateMsg
CSVMsg
DataFrameIndexingMsg
EfuncMsg
EncodingMsg
FlattenMsg
HashMsg
HDF5Msg
HistogramMsg
In1dMsg
IndexingMsg
JoinEqWithDTMsg
KExtremeMsg
# LinalgMsg
LinalgMsg
LogMsg
ManipulationMsg
OperatorMsg
ParquetMsg
RandMsg
ReductionMsg
RegistrationMsg
SegmentedMsg
SequenceMsg
SetMsg
SortMsg
StatsMsg
TimeClassMsg
TransferMsg
UniqueMsg

# Add additional modules located outside
# of the Arkouda src/ directory below.
Expand Down
70 changes: 58 additions & 12 deletions arkouda/array_api/_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from ._array_object import Array

from typing import NamedTuple
from typing import NamedTuple, cast

import arkouda as ak
from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray


class UniqueAllResult(NamedTuple):
Expand All @@ -25,22 +26,67 @@ class UniqueInverseResult(NamedTuple):


def unique_all(x: Array, /) -> UniqueAllResult:
    """
    Compute the unique-element summary of ``x`` on the server.

    Sends the ``uniqueAll{ndim}D`` command with ``x``'s underlying array and
    wraps each of the four arrays named in the reply in an ``Array``.

    Parameters
    ----------
    x : Array
        Input array; its ``ndim`` selects the server command.

    Returns
    -------
    UniqueAllResult
        ``values``, ``indices``, ``inverse_indices`` and ``counts``, taken in
        that order from the '+'-separated server reply.
    """
    # The stale ``raise ValueError("unique_all not implemented")`` placeholder
    # was removed: it made everything below unreachable.
    resp = cast(
        str,
        generic_msg(
            cmd=f"uniqueAll{x.ndim}D",
            args={"name": x._array},
        ),
    )

    # Each '+'-separated piece of the reply describes one created array.
    arrays = [Array._new(create_pdarray(r)) for r in resp.split('+')]

    return UniqueAllResult(
        values=arrays[0],
        indices=arrays[1],
        inverse_indices=arrays[2],
        counts=arrays[3],
    )


def unique_counts(x: Array, /) -> UniqueCountsResult:
    """
    Compute the unique values of ``x`` and their counts on the server.

    Sends the ``uniqueCounts{ndim}D`` command with ``x``'s underlying array
    and wraps each of the two arrays named in the reply in an ``Array``.

    Parameters
    ----------
    x : Array
        Input array; its ``ndim`` selects the server command.

    Returns
    -------
    UniqueCountsResult
        ``values`` and ``counts``, taken in that order from the
        '+'-separated server reply.
    """
    # The stale ``raise ValueError("unique_counts not implemented")``
    # placeholder was removed: it made everything below unreachable.
    resp = cast(
        str,
        generic_msg(
            cmd=f"uniqueCounts{x.ndim}D",
            args={"name": x._array},
        ),
    )

    # Each '+'-separated piece of the reply describes one created array.
    arrays = [Array._new(create_pdarray(r)) for r in resp.split('+')]

    return UniqueCountsResult(
        values=arrays[0],
        counts=arrays[1],
    )


def unique_inverse(x: Array, /) -> UniqueInverseResult:
    """
    Compute the unique values of ``x`` and the inverse indices on the server.

    Sends the ``uniqueInverse{ndim}D`` command with ``x``'s underlying array
    and wraps each of the two arrays named in the reply in an ``Array``.

    Parameters
    ----------
    x : Array
        Input array; its ``ndim`` selects the server command.

    Returns
    -------
    UniqueInverseResult
        ``values`` and ``inverse_indices``, taken in that order from the
        '+'-separated server reply.
    """
    # The stale ``raise ValueError("unique_inverse not implemented")``
    # placeholder was removed: it made everything below unreachable.
    resp = cast(
        str,
        generic_msg(
            cmd=f"uniqueInverse{x.ndim}D",
            args={"name": x._array},
        ),
    )

    # Each '+'-separated piece of the reply describes one created array.
    arrays = [Array._new(create_pdarray(r)) for r in resp.split('+')]

    return UniqueInverseResult(
        values=arrays[0],
        inverse_indices=arrays[1],
    )

def unique_values(x: Array, /) -> Array:
    """
    Compute the unique values of ``x`` on the server.

    Sends the ``uniqueValues{ndim}D`` command with ``x``'s underlying array
    and wraps the single array named in the reply in an ``Array``.

    Parameters
    ----------
    x : Array
        Input array; its ``ndim`` selects the server command.

    Returns
    -------
    Array
        The array created by the server for this command.
    """
    # Only the server-backed implementation is kept; the earlier
    # ``ak.unique``-based definition of this same function was superseded
    # and would have been shadowed by this one anyway.
    resp = cast(
        str,
        generic_msg(
            cmd=f"uniqueValues{x.ndim}D",
            args={"name": x._array},
        ),
    )
    return Array._new(create_pdarray(resp))
2 changes: 1 addition & 1 deletion src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ module AryUtil
*/
proc broadcastShape(sa: ?Na*int, sb: ?Nb*int, param N: int): N*int throws {
var s: N*int;
for param i in 0..<N by -1 do {
for param i in 0..<N by -1 {
const n1 = Na - N + i,
n2 = Nb - N + i,
d1 = if n1 < 0 then 1 else sa[n1],
Expand Down
233 changes: 233 additions & 0 deletions src/SetMsg.chpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
module SetMsg {
use Message;
use MultiTypeSymbolTable;
use MultiTypeSymEntry;
use ServerConfig;
use Logging;
use ServerErrorStrings;
use ServerErrors;
use AryUtil;
use CommAggregation;
use RadixSortLSD;
use Unique;

use ArkoudaAryUtilCompat;

use Reflection;

// Logging configuration for this module. Both settings default to the
// server-wide values from ServerConfig; being `config const`, they can be
// overridden when the program is launched.
private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
// Module-level logger used by every message handler below.
const sLogger = new Logger(logLevel, logChannel);

/*
  Message handler that stores the unique values of the array named by the
  "name" argument under a freshly generated symbol-table name.

  The input is flattened when `nd > 1`, sorted with `radixSortLSD_keys`, and
  passed to `uniqueFromSorted` (counts disabled).

  Replies with "created <attributes>" on success, or an error message when
  the array's dtype is not one of int64, uint64, float64 or bool.
*/
@arkouda.registerND
proc uniqueValuesMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
    param pn = Reflection.getRoutineName();
    const name = msgArgs.getValueOf("name");
    const rname = st.nextName();

    var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

    // Typed worker: does the actual computation once the element type is known.
    proc uniqueValuesOf(type eltType): MsgTuple throws {
        const entry = toSymEntry(gEnt, eltType, nd);
        // rank-1 input is used as-is; higher ranks are flattened first
        const flatVals = if nd == 1 then entry.a else flatten(entry.a);

        const sortedVals = radixSortLSD_keys(flatVals);
        const uniqueVals = uniqueFromSorted(sortedVals, needCounts=false);

        st.addEntry(rname, createSymEntry(uniqueVals));

        const repMsg = "created " + st.attrib(rname);
        sLogger.info(getModuleName(), pn, getLineNumber(), repMsg);
        return new MsgTuple(repMsg, MsgType.NORMAL);
    }

    // Dispatch on the runtime dtype to a compile-time element type.
    select gEnt.dtype {
        when DType.Int64 { return uniqueValuesOf(int); }
        // when DType.UInt8 do return getUniqueVals(uint(8));
        when DType.UInt64 { return uniqueValuesOf(uint); }
        when DType.Float64 { return uniqueValuesOf(real); }
        when DType.Bool { return uniqueValuesOf(bool); }
        otherwise {
            const errorMsg = notImplementedError(pn, gEnt.dtype);
            sLogger.error(getModuleName(), pn, getLineNumber(), errorMsg);
            return new MsgTuple(errorMsg, MsgType.ERROR);
        }
    }
}

/*
  Message handler that stores the unique values of the array named by the
  "name" argument, together with their counts, under two freshly generated
  symbol-table names.

  The input is flattened when `nd > 1`, sorted with `radixSortLSD_keys`, and
  passed to `uniqueFromSorted`, whose two results are stored.

  Replies with "created <values>+created <counts>" on success, or an error
  message when the array's dtype is not one of int64, uint64, float64 or bool.
*/
@arkouda.registerND
proc uniqueCountsMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
    param pn = Reflection.getRoutineName();
    const name = msgArgs.getValueOf("name");
    const uname = st.nextName();
    const cname = st.nextName();

    var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

    // Typed worker: does the actual computation once the element type is known.
    proc uniqueCountsOf(type eltType): MsgTuple throws {
        const entry = toSymEntry(gEnt, eltType, nd);
        // rank-1 input is used as-is; higher ranks are flattened first
        const flatVals = if nd == 1 then entry.a else flatten(entry.a);

        const sortedVals = radixSortLSD_keys(flatVals);
        const (uniqueVals, valCounts) = uniqueFromSorted(sortedVals);

        st.addEntry(uname, createSymEntry(uniqueVals));
        st.addEntry(cname, createSymEntry(valCounts));

        // one "created ..." descriptor per output array, separated by '+'
        const valuesPart = "created " + st.attrib(uname);
        const countsPart = "+created " + st.attrib(cname);
        const repMsg = valuesPart + countsPart;
        sLogger.info(getModuleName(), pn, getLineNumber(), repMsg);
        return new MsgTuple(repMsg, MsgType.NORMAL);
    }

    // Dispatch on the runtime dtype to a compile-time element type.
    select gEnt.dtype {
        when DType.Int64 { return uniqueCountsOf(int); }
        // when DType.UInt8 do return getUniqueVals(uint(8));
        when DType.UInt64 { return uniqueCountsOf(uint); }
        when DType.Float64 { return uniqueCountsOf(real); }
        when DType.Bool { return uniqueCountsOf(bool); }
        otherwise {
            const errorMsg = notImplementedError(pn, gEnt.dtype);
            sLogger.error(getModuleName(), pn, getLineNumber(), errorMsg);
            return new MsgTuple(errorMsg, MsgType.ERROR);
        }
    }
}

/*
  Message handler that stores the unique values of the array named by the
  "name" argument, together with the inverse indices, under two freshly
  generated symbol-table names.

  The input is flattened when `nd > 1` and passed to `uniqueSortWithInverse`;
  the first and third results are stored (the second is ignored). For
  `nd > 1` the inverse array is reshaped to the input's shape with
  `unflatten` before being stored.

  Replies with "created <values>+created <inverse>" on success, or an error
  message when the array's dtype is not one of int64, uint64, float64 or bool.
*/
@arkouda.registerND
proc uniqueInverseMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
    param pn = Reflection.getRoutineName();
    const name = msgArgs.getValueOf("name");
    const uname = st.nextName();
    const iname = st.nextName();

    var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

    // Typed worker: does the actual computation once the element type is known.
    proc uniqueInverseOf(type eltType): MsgTuple throws {
        const entry = toSymEntry(gEnt, eltType, nd);
        // rank-1 input is used as-is; higher ranks are flattened first
        const flatVals = if nd == 1 then entry.a else flatten(entry.a);

        const (uniqueVals, _, inverse) = uniqueSortWithInverse(flatVals);
        st.addEntry(uname, createSymEntry(uniqueVals));
        // the inverse array takes the shape of the input array
        st.addEntry(iname, createSymEntry(if nd == 1 then inverse else unflatten(inverse, entry.a.shape)));

        // one "created ..." descriptor per output array, separated by '+'
        const valuesPart = "created " + st.attrib(uname);
        const inversePart = "+created " + st.attrib(iname);
        const repMsg = valuesPart + inversePart;
        sLogger.info(getModuleName(), pn, getLineNumber(), repMsg);
        return new MsgTuple(repMsg, MsgType.NORMAL);
    }

    // Dispatch on the runtime dtype to a compile-time element type.
    select gEnt.dtype {
        when DType.Int64 { return uniqueInverseOf(int); }
        // when DType.UInt8 do return getUniqueVals(uint(8));
        when DType.UInt64 { return uniqueInverseOf(uint); }
        when DType.Float64 { return uniqueInverseOf(real); }
        when DType.Bool { return uniqueInverseOf(bool); }
        otherwise {
            const errorMsg = notImplementedError(pn, gEnt.dtype);
            sLogger.error(getModuleName(), pn, getLineNumber(), errorMsg);
            return new MsgTuple(errorMsg, MsgType.ERROR);
        }
    }
}

/*
  Message handler that stores four arrays derived from the array named by
  the "name" argument under four freshly generated symbol-table names, in
  this order: unique values, indices, inverse indices, counts.

  The input is flattened when `nd > 1` and passed to `uniqueSortWithInverse`
  with `needIndices=true`. For `nd > 1` the inverse array is reshaped to the
  input's shape with `unflatten` before being stored.

  Replies with four "created <attributes>" descriptors joined by '+', or an
  error message when the array's dtype is not one of int64, uint64, float64
  or bool.
*/
@arkouda.registerND
proc uniqueAllMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
    param pn = Reflection.getRoutineName();
    const name = msgArgs.getValueOf("name"),
          rnames = for 0..<4 do st.nextName();

    var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

    // Typed worker: does the actual computation once the element type is known.
    proc getUniqueVals(type t): MsgTuple throws {
        const eIn = toSymEntry(gEnt, t, nd),
              // rank-1 input is used as-is; higher ranks are flattened first
              eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

        const (eUnique, eCounts, inv, eIndices) = uniqueSortWithInverse(eFlat, needIndices=true);
        st.addEntry(rnames[0], createSymEntry(eUnique));
        st.addEntry(rnames[1], createSymEntry(eIndices));
        st.addEntry(rnames[2], createSymEntry(if nd == 1 then inv else unflatten(inv, eIn.a.shape)));
        st.addEntry(rnames[3], createSymEntry(eCounts));

        // Build the '+'-separated reply with a plain loop so that an error
        // thrown by `st.attrib` propagates through this throwing proc
        // instead of halting the whole program, as the previous `try!` did.
        var repMsg: string;
        for i in rnames.domain {
            if i != rnames.domain.low then repMsg += "+";
            repMsg += "created " + st.attrib(rnames[i]);
        }
        sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
        return new MsgTuple(repMsg, MsgType.NORMAL);
    }

    // Dispatch on the runtime dtype to a compile-time element type.
    select gEnt.dtype {
        when DType.Int64 do return getUniqueVals(int);
        // when DType.UInt8 do return getUniqueVals(uint(8));
        when DType.UInt64 do return getUniqueVals(uint);
        when DType.Float64 do return getUniqueVals(real);
        when DType.Bool do return getUniqueVals(bool);
        otherwise {
            var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
            sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
            return new MsgTuple(errorMsg, MsgType.ERROR);
        }
    }
}

// TODO: put this in AryUtil or some other common module after merging with #3056
/*
  Copy the rank-1 array `aFlat` into a new array of the given `shape`.

  The output is filled one last-dimension slice at a time: each slice
  receives the contiguous run of `aFlat` whose bounds are the orders (as
  computed by `orderer.indexToOrder`) of the slice's first and last index.
*/
private proc unflatten(const ref aFlat: [?d] ?t, shape: ?N*int): [] t throws {
    param lastAxis = N - 1;
    var unflat = makeDistArray((...shape), t);
    const lastDim = unflat.domain.dim(lastAxis);

    // visit every index combination of the leading dimensions; each task
    // gets its own `orderer` for translating indices to flat positions
    forall offIdx in domOffAxis(unflat.domain, lastAxis) with (const ord = new orderer(unflat.domain.shape)) {
        // leading (all-but-last) components of the current index
        var lead: lastAxis*int;
        for i in 0..<lastAxis do lead[i] = offIdx[i];

        // flat positions of the first and last element of this slice
        const first = ord.indexToOrder(((...lead), lastDim.low));
        const last = ord.indexToOrder(((...lead), lastDim.high));

        const ndSlice = ((...lead), lastDim);
        unflat[(...ndSlice)] = aFlat[first..last];
    }

    return unflat;
}

// TODO: put this in AryUtil or some other common module after merging with #3056
/*
  Copy the multi-dimensional array `a` (rank > 1) into a new rank-1 array
  over `0..<d.size`.

  The input is read one last-dimension slice at a time: each slice is
  written to the contiguous run of the output whose bounds are the orders
  (as computed by `orderer.indexToOrder`) of the slice's first and last index.
*/
private proc flatten(const ref a: [?d] ?t): [] t throws
    where a.rank > 1
{
    param lastAxis = d.rank - 1;
    var flat = makeDistArray({0..<d.size}, t);
    const lastDim = d.dim(lastAxis);

    // visit every index combination of the leading dimensions; each task
    // gets its own `orderer` for translating indices to flat positions
    forall offIdx in domOffAxis(d, lastAxis) with (const ord = new orderer(d.shape)) {
        // leading (all-but-last) components of the current index
        var lead: lastAxis*int;
        for i in 0..<lastAxis do lead[i] = offIdx[i];

        // flat positions of the first and last element of this slice
        const first = ord.indexToOrder(((...lead), lastDim.low));
        const last = ord.indexToOrder(((...lead), lastDim.high));

        const ndSlice = ((...lead), lastDim);
        flat[first..last] = a[(...ndSlice)];
    }

    return flat;
}

/*
  Helper that converts a multi-dimensional index into its position in a
  flattened array where the last dimension varies fastest.

  `accumRankSizes[k]` holds the product of the extents of the last `k`
  dimensions of the shape (so `accumRankSizes[0] == 1`).
*/
record orderer {
    param rank: int;
    const accumRankSizes: [0..<rank] int;

    /*
      Build the per-dimension multipliers for an array of the given `shape`.

      The multipliers are an exclusive running product of the extents taken
      from the last dimension to the first. They are accumulated directly
      rather than as `(* scan sizes) / sizes`, which divides by zero as soon
      as any dimension has extent 0.
    */
    proc init(shape: ?N*int) {
        this.rank = N;
        var multipliers: [0..<N] int;
        var running = 1;
        for i in 0..<N {
            multipliers[i] = running;
            running *= shape[N - i - 1];
        }
        this.accumRankSizes = multipliers;
    }

    // index -> order for the input array's indices
    // e.g., order = k + (nz * j) + (nz * ny * i)
    inline proc indexToOrder(idx: rank*int): int {
        var order = 0;
        for param i in 0..<rank do order += idx[i] * accumRankSizes[rank - i - 1];
        return order;
    }
}
}
Loading

0 comments on commit 180b7b8

Please sign in to comment.