From 3b5e0bc1ef27349f5a1a03c66bbe5ee416b73334 Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Tue, 10 Dec 2024 19:18:13 +0100 Subject: [PATCH] [SYSTEMDS-3805] Fix scalar right indexing (only for valid indices) In order to ensure consistent error handling, we now only use the scalar right indexing if the index-range is within the matrix dims. --- .../cp/MatrixIndexingCPInstruction.java | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java index 26e4d3f45af..afbf7724ab0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java @@ -50,7 +50,7 @@ protected MatrixIndexingCPInstruction(CPOperand lhsInput, CPOperand rhsInput, CP @Override public void processInstruction(ExecutionContext ec) { String opcode = getOpcode(); - IndexRange ixrange = getIndexRange(ec); + IndexRange ix = getIndexRange(ec); //get original matrix MatrixObject mo = ec.getMatrixObject(input1.getName()); @@ -61,19 +61,19 @@ public void processInstruction(ExecutionContext ec) { MatrixBlock resultBlock = null; if( mo.isPartitioned() ) //via data partitioning - resultBlock = mo.readMatrixPartition(ixrange.add(1)); - else if( ixrange.isScalar() ){ + resultBlock = mo.readMatrixPartition(ix.add(1)); + else if( ix.isScalar() && ix.rowStart < mo.getNumRows() && ix.colStart < mo.getNumColumns() ) { MatrixBlock matBlock = mo.acquireReadAndRelease(); resultBlock = new MatrixBlock( - matBlock.get((int)ixrange.rowStart, (int)ixrange.colStart)); + matBlock.get((int)ix.rowStart, (int)ix.colStart)); } else //via slicing the in-memory matrix { //execute right indexing operation (with shallow row copies for range //of entire sparse rows, which is safe due to copy on update) MatrixBlock matBlock = mo.acquireRead(); - resultBlock = matBlock.slice((int)ixrange.rowStart, (int)ixrange.rowEnd, - (int)ixrange.colStart, (int)ixrange.colEnd, false, new MatrixBlock()); + resultBlock = matBlock.slice((int)ix.rowStart, (int)ix.rowEnd, + (int)ix.colStart, (int)ix.colEnd, false, new MatrixBlock()); //unpin rhs input ec.releaseMatrixInput(input1.getName()); @@ -101,15 +101,15 @@ else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE)) if(input2.getDataType() == DataType.MATRIX) { //MATRIX<-MATRIX MatrixBlock rhsMatBlock = ec.getMatrixInput(input2.getName()); - resultBlock = matBlock.leftIndexingOperations(rhsMatBlock, ixrange, new MatrixBlock(), updateType); + resultBlock = matBlock.leftIndexingOperations(rhsMatBlock, ix, new MatrixBlock(), updateType); ec.releaseMatrixInput(input2.getName()); } else { //MATRIX<-SCALAR - if(!ixrange.isScalar()) - throw new DMLRuntimeException("Invalid index range of scalar leftindexing: "+ixrange.toString()+"." ); + if(!ix.isScalar()) + throw new DMLRuntimeException("Invalid index range of scalar leftindexing: "+ix.toString()+"." ); ScalarObject scalar = ec.getScalarInput(input2.getName(), ValueType.FP64, input2.isLiteral()); resultBlock = matBlock.leftIndexingOperations(scalar, - (int)ixrange.rowStart, (int)ixrange.colStart, new MatrixBlock(), updateType); + (int)ix.rowStart, (int)ix.colStart, new MatrixBlock(), updateType); } //unpin lhs input