Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize creation of sparrays from pdarrays #3877

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 63 additions & 16 deletions src/SparseMatrix.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -471,32 +471,79 @@ module SparseMatrix {
import SymArrayDmap.makeSparseDomain;
var (SD, dense) = makeSparseDomain(shape, layout);

for i in 0..<rows.size {
if SD.contains((rows[i], cols[i])) then
const minRow = min reduce rows;
const maxRow = max reduce rows;
const minCol = min reduce cols;
const maxCol = max reduce cols;

if minRow < 0 || minCol < 0 then
throw getErrorWithContext(
msg="Duplicate index (%i, %i) in sparse matrix".format(rows[i], cols[i]),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="InvalidArgumentError"
);
if rows[i] < 1 || rows[i] > shape[0] || cols[i] < 1 || cols[i] > shape[1] then
msg="Sparse matrix indices must be greater than 0; got (%i, %i)".format(minRow, minCol), // TODO, change this when we start matrix from 0 instead of 1
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="InvalidArgumentError"
);
if maxRow >= shape[0] || maxCol >= shape[1] then
throw getErrorWithContext(
bmcdonald3 marked this conversation as resolved.
Show resolved Hide resolved
msg="Sparse matrix indices must be less than the shape; got (%i, %i) >= (%i, %i)".format(maxRow, maxCol, shape[0], shape[1]),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="InvalidArgumentError"
);
bmcdonald3 marked this conversation as resolved.
Show resolved Hide resolved

var A: [SD] eltType;
addElementsToSparseArray(A, SD, rows, cols, vals);

return A;
}

proc addElementsToSparseArray(ref A, ref SD, const ref rows, const ref cols, const ref vals) throws where A.chpl_isNonDistributedArray() {
for (r,c,v) in zip(rows, cols, vals) {
if A.domain.contains(r,c) then
throw getErrorWithContext(
msg="Index (%i, %i) out of bounds for sparse matrix of shape (%i, %i)".format(rows[i], cols[i], shape[0], shape[1]),
msg="Duplicate index (%i, %i) in sparse matrix".format(r, c),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="InvalidArgumentError"
);
SD += (r,c);
A[r,c] = v;
}
}


proc addElementsToSparseArray(ref A, ref SD, const ref rows, const ref cols, const ref vals) throws where !A.chpl_isNonDistributedArray() {
coforall (loc, locDom) in zip(getGrid(A),
SD._value.locDoms) {
bmcdonald3 marked this conversation as resolved.
Show resolved Hide resolved
on loc {
for _srcLocId in loc.id..#numLocales {
const srcLocId = _srcLocId % numLocales;
var rowChunk = rows[rows.localSubdomain(Locales[srcLocId])];
var colChunk = cols[rows.localSubdomain(Locales[srcLocId])];
var valChunk = vals[rows.localSubdomain(Locales[srcLocId])];
for (r,c,v) in zip(rowChunk, colChunk, valChunk) {
if locDom!.parentDom.contains(r,c) {
if locDom!.mySparseBlock.contains(r,c) then
throw getErrorWithContext(
bmcdonald3 marked this conversation as resolved.
Show resolved Hide resolved
msg="Duplicate index (%i, %i) in sparse matrix".format(r, c),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="InvalidArgumentError"
);
SD += (rows[i], cols[i]);
}

var A: [SD] eltType;
for i in 0..<rows.size {
A[rows[i], cols[i]] = vals[i];

locDom!.mySparseBlock += (r,c);
A[r,c] = v;
}
}
}
}
}

return A;
}

module SpsMatUtil {
Expand Down
Loading