Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse sum helper to util #2976

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions arkouda/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,45 @@ def attach_all(names: list):
return {n: attach(n) for n in names}


def sparse_sum_help(idx1, idx2, val1, val2):
    """
    Helper for summing two sparse matrices together

    The work is done server-side by the ``sparseSumHelp`` command, which
    replies with two ``created ...`` descriptors joined by a ``+``; each
    descriptor is turned back into a pdarray here.

    Parameters
    -----------
    idx1: pdarray
        indices for the first sparse matrix
    idx2: pdarray
        indices for the second sparse matrix
    val1: pdarray
        values for the first sparse matrix
    val2: pdarray
        values for the second sparse matrix

    Returns
    --------
    (pdarray, pdarray)
        indices and values for the summed sparse matrix

    Notes
    -----
    The server handler accepts int64 indices paired with either int64 or
    uint64 values; any other dtype combination is reported as not implemented.

    Examples
    --------
    >>> idx1 = ak.array([0, 1, 3, 4, 7, 9])
    >>> idx2 = ak.array([0, 1, 3, 6, 9])
    >>> vals1 = idx1
    >>> vals2 = ak.array([10, 11, 13, 16, 19])
    >>> ak.util.sparse_sum_help(idx1, idx2, vals1, vals2)
    (array([0 1 3 4 6 7 9]), array([10 12 16 4 16 7 28]))

    >>> ak.GroupBy(ak.concatenate([idx1, idx2])).sum(ak.concatenate((vals1, vals2)))
    (array([0 1 3 4 6 7 9]), array([10 12 16 4 16 7 28]))
    """
    repMsg = generic_msg(
        cmd="sparseSumHelp", args={"idx1": idx1, "idx2": idx2, "val1": val1, "val2": val2}
    )
    # Split only on the first "+" so the reply yields exactly two descriptors.
    inds, vals = repMsg.split("+", maxsplit=1)
    return create_pdarray(inds), create_pdarray(vals)


def broadcast_dims(sa: Sequence[int], sb: Sequence[int]) -> Tuple[int, ...]:
"""
Algorithm to determine shape of broadcasted PD array given two array shapes
Expand Down
57 changes: 57 additions & 0 deletions src/ArraySetops.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module ArraySetops
use Indexing;
use AryUtil;
use In1d;
use CommAggregation;

// returns intersection of 2 arrays
proc intersect1d(a: [] ?t, b: [] t, assume_unique: bool) throws {
Expand Down Expand Up @@ -112,4 +113,60 @@ module ArraySetops
}
return uniqueSort(aux, false);
}

// Concatenate idx1 and idx2, sort the result, and return the sorted keys
// together with the permutation that produced them.
//
// aSize and bSize are currently unused; they are kept in the signature for
// callers.
//
// Returns a tuple (sortedKeys, ranks); ranks[j] is the position in the
// concatenated input of the element that ended up at sorted position j
// (NOTE(review): relies on radixSortLSD yielding (key, rank) pairs — confirm).
proc sortHelper(idx1: [?D] ?t, idx2: [] t, aSize, bSize) throws {
  // Per the original author's note the two inputs are expected to be sorted
  // already, so a linear merge should eventually replace this full sort,
  // with a size-based choice between merging and sorting.
  const total = idx1.size + idx2.size;
  var keys = makeDistArray(total, t);
  var ranks = makeDistArray(total, int);

  // Unpack each sorted pair into the two result arrays.
  const sortedPairs = radixSortLSD(concatArrays(idx1, idx2));
  forall (k, r, pair) in zip(keys, ranks, sortedPairs) {
    k = pair(0);
    r = pair(1);
  }
  return (keys, ranks);
}

// Sum two sparse operands given as parallel (index, value) arrays.
//
// idx1 / val1: indices and values of the first operand
// idx2 / val2: indices and values of the second operand
//
// Returns a tuple (uIdx, ret): the distinct indices in sorted order and, for
// each, the sum of the values that share that index.
//
// NOTE(review): the combine step below adds at most two values per index
// (the first occurrence and its immediate successor), so each of idx1 and
// idx2 is assumed to contain no repeated indices — confirm with callers.
proc sparseSumHelper(const ref idx1: [] ?t, const ref idx2: [] t, const ref val1: [] ?t2, const ref val2: [] t2) throws {
  const allocSize = idx1.size + idx2.size;
  const (sortedIdx, perm) = sortHelper(idx1, idx2, idx1.size, idx2.size);

  // Gather the concatenated values into the same order as sortedIdx.
  var permutedVals = makeDistArray(allocSize, t2);
  const vals = concatArrays(val1, val2);
  forall (p, i) in zip(permutedVals, perm) with (var agg = newSrcAggregator(t2)) {
    agg.copy(p, vals[i]);
  }

  // Flag the first element of every run of equal indices. The lowest position
  // is always a first occurrence; handling it inside the loop (instead of a
  // separate write to index 0) keeps this safe when the inputs are empty.
  const sD = sortedIdx.domain;
  var firstOccurrence = makeDistArray(sD, bool);
  forall (f, s, i) in zip(firstOccurrence, sortedIdx, sD) {
    // most of the time sortedIdx[i-1] should be local since we are block distributed,
    // so we only have to fetch at locale boundaries
    f = (i == sD.low) || (sortedIdx[i-1] != s);
  }

  // we have to do a first pass through data to calculate the size of the return array
  const numUnique = + reduce firstOccurrence;
  var uIdx = makeDistArray(numUnique, t);
  var ret = makeDistArray(numUnique, t2);

  // Exclusive prefix sum of the flags: the output slot for each first occurrence.
  const retIdx = + scan firstOccurrence - firstOccurrence;
  forall (s, p, i, f, rIdx) in zip(sortedIdx, permutedVals, sD, firstOccurrence, retIdx) with (var idxAgg = newDstAggregator(t),
                                                                                               var valAgg = newDstAggregator(t2)) {
    if f { // skip if we are not the first occurrence
      idxAgg.copy(uIdx[rIdx], s);
      if i == sD.high || sortedIdx[i+1] != s {
        // index appears once: copy its value through
        valAgg.copy(ret[rIdx], p);
      }
      else {
        // index appears in both inputs: add the neighbouring value.
        // Aggregation is deliberately not used here; the original author was
        // unsure whether it would amount to remote-to-remote aggregation.
        // valAgg.copy(ret[rIdx], p + permutedVals[i+1]);
        ret[rIdx] = p + permutedVals[i+1];
      }
    }
  }
  return (uIdx, ret);
}
}
63 changes: 63 additions & 0 deletions src/ArraySetopsMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,73 @@ module ArraySetopsMsg
moduleName=getModuleName(),
errorClass="ErrorWithContext");
}

// Message handler for the "sparseSumHelp" command.
//
// Reads the symbol-table names of four arrays from msgArgs ("idx1", "idx2",
// "val1", "val2"), sums the two sparse operands with sparseSumHelper, stores
// the resulting index and value arrays under fresh names, and replies with
// "created <idx attrib>+created <val attrib>".
//
// Supported dtype combinations: int64 indices with int64 values, and int64
// indices with uint64 values. Anything else yields an ERROR MsgTuple.
proc sparseSumHelpMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  param pn = Reflection.getRoutineName();

  const gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(msgArgs.getValueOf("idx1"), st);
  const gEnt2: borrowed GenSymEntry = getGenericTypedArrayEntry(msgArgs.getValueOf("idx2"), st);
  const gEnt3: borrowed GenSymEntry = getGenericTypedArrayEntry(msgArgs.getValueOf("val1"), st);
  const gEnt4: borrowed GenSymEntry = getGenericTypedArrayEntry(msgArgs.getValueOf("val2"), st);

  // Each index array must line up with its value array; otherwise the
  // permutation applied in sparseSumHelper would read out of bounds.
  if gEnt.size != gEnt3.size || gEnt2.size != gEnt4.size {
    const errorMsg = "Error: " + pn + ": index and value arrays must have matching sizes (idx1="
                     + gEnt.size:string + ", val1=" + gEnt3.size:string
                     + ", idx2=" + gEnt2.size:string + ", val2=" + gEnt4.size:string + ")";
    asLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
    return new MsgTuple(errorMsg, MsgType.ERROR);
  }

  // The sort runs over the concatenation of both index arrays, so estimate
  // memory for the combined length rather than for each input separately.
  const sortMem = radixSortLSD_memEst(gEnt.size + gEnt2.size, max(gEnt.itemsize, gEnt2.itemsize));
  overMemLimit(sortMem);

  const iname = st.nextName();
  const vname = st.nextName();

  select(gEnt.dtype, gEnt2.dtype, gEnt3.dtype, gEnt4.dtype) {
    when (DType.Int64, DType.Int64, DType.Int64, DType.Int64) {
      const e = toSymEntry(gEnt,int);
      const f = toSymEntry(gEnt2,int);
      const g = toSymEntry(gEnt3,int);
      const h = toSymEntry(gEnt4,int);
      const ref ea = e.a;
      const ref fa = f.a;
      const ref ga = g.a;
      const ref ha = h.a;

      const (retIdx, retVals) = sparseSumHelper(ea, fa, ga, ha);
      st.addEntry(iname, createSymEntry(retIdx));
      st.addEntry(vname, createSymEntry(retVals));

      // the client splits this reply on the first "+"
      const repMsg = "created " + st.attrib(iname) + "+created " + st.attrib(vname);
      asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
      return new MsgTuple(repMsg, MsgType.NORMAL);
    }
    when (DType.Int64, DType.Int64, DType.UInt64, DType.UInt64) {
      const e = toSymEntry(gEnt,int);
      const f = toSymEntry(gEnt2,int);
      const g = toSymEntry(gEnt3,uint);
      const h = toSymEntry(gEnt4,uint);
      const ref ea = e.a;
      const ref fa = f.a;
      const ref ga = g.a;
      const ref ha = h.a;

      const (retIdx, retVals) = sparseSumHelper(ea, fa, ga, ha);
      st.addEntry(iname, createSymEntry(retIdx));
      st.addEntry(vname, createSymEntry(retVals));

      // the client splits this reply on the first "+"
      const repMsg = "created " + st.attrib(iname) + "+created " + st.attrib(vname);
      asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
      return new MsgTuple(repMsg, MsgType.NORMAL);
    }
    otherwise {
      // NOTE(review): only the dtype of idx1 is reported, even when one of the
      // other three arrays is the unsupported one.
      var errorMsg = notImplementedError("sparseSumHelper",gEnt.dtype);
      asLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
      return new MsgTuple(errorMsg, MsgType.ERROR);
    }
  }
}

use CommandMap;
// Register this module's message handlers under their command names.
// NOTE(review): registerFunction comes from CommandMap; presumably it maps the
// command string to the handler proc for server dispatch — confirm there.
registerFunction("intersect1d", intersect1dMsg, getModuleName());
registerFunction("setdiff1d", setdiff1dMsg, getModuleName());
registerFunction("setxor1d", setxor1dMsg, getModuleName());
registerFunction("union1d", union1dMsg, getModuleName());
// "sparseSumHelp" is the cmd string sent by the Python helper sparse_sum_help.
registerFunction("sparseSumHelp", sparseSumHelpMsg, getModuleName());
}
Loading