Skip to content

Commit

Permalink
[SYSTEMDS-3806] Robustness simplifyDotProductSum rewrite
Browse files Browse the repository at this point in the history
This patch fixes an issue of incorrect application of the
simplifyDotProductSum rewrite. Specifically, sum(s*V) was rewritten to
t(s) %*% V because s was assumed to be a vector of equal size than V
but was a scalar. The root cause of an incorrect size propagation for
the new scalar right indexing, but for robustness we now also check
that both inputs are actually matrices.
  • Loading branch information
mboehm7 committed Dec 13, 2024
1 parent 0743613 commit 1e86da3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/main/java/org/apache/sysds/hops/IndexingOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,14 @@ private static IndexingMethod optFindIndexingMethod( boolean singleRow, boolean
@Override
public void refreshSizeInformation()
{
// early abort for scalar right indexing
// (important to prevent incorrect dynamic rewrites)
if( isScalar() ) {
setDim1(0);
setDim2(0);
return;
}

Hop input1 = getInput().get(0); //matrix
Hop input2 = getInput().get(1); //inpRowL
Hop input3 = getInput().get(2); //inpRowU
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2312,6 +2312,7 @@ private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
//check for sum(v1*v2), but prevent to rewrite sum(v1*v2*v3) which is later compiled into a ta+* lop
else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum
&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
&& hi2.getInput().get(0).isMatrix() && hi2.getInput().get(1).isMatrix()
&& !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
&& !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
&& ( !ALLOW_SUM_PRODUCT_REWRITES
Expand Down
2 changes: 0 additions & 2 deletions src/test/scripts/functions/unary/matrix/eigen.dml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ numEval = $2;
D = matrix(1, numEval, 1);
for ( i in 1:numEval ) {
Av = A %*% evec[,i];
while(FALSE){} #fix incorrect rewrite sequence
rhs = as.scalar(eval[i,1]) * evec[,i];
while(FALSE){} #fix incorrect rewrite sequence
diff = sum(Av-rhs);
D[i,1] = diff;
}
Expand Down

0 comments on commit 1e86da3

Please sign in to comment.