Skip to content

Commit

Permalink
[SYSTEMDS-3765] Fix time displacement through function hoisting
Browse files Browse the repository at this point in the history
This patch fixes issues with time() functions which are used to
measure execution time of parts of a program. When these functions
were used in expressions (e.g., print string concatenation) the normal
DAG compilation might move them before the operation that was actually
measured. Similar to DML function calls, we now hoist these time
functions out of expressions.
  • Loading branch information
mboehm7 committed Dec 7, 2024
1 parent 34a6571 commit b5b6f37
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
9 changes: 8 additions & 1 deletion src/main/java/org/apache/sysds/parser/StatementBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,13 @@ else if( expr instanceof BuiltinFunctionExpression ) {
Expression[] clexpr = lexpr.getAllExpr();
for( int i=0; i<clexpr.length; i++ )
clexpr[i] = rHoistFunctionCallsFromExpressions(clexpr[i], false, tmp, prog);
if( !root && lexpr.getOpCode()==Builtins.TIME ) { //core time hoisting
String varname = StatementBlockRewriteRule.createCutVarName(true);
DataIdentifier di = new DataIdentifier(varname);
di.setDataType(lexpr.getDataType());
di.setValueType(lexpr.getValueType());
tmp.add(new AssignmentStatement(di, lexpr, di));
}
}
else if( expr instanceof ParameterizedBuiltinFunctionExpression ) {
ParameterizedBuiltinFunctionExpression lexpr = (ParameterizedBuiltinFunctionExpression) expr;
Expand All @@ -612,7 +619,7 @@ else if( expr instanceof FunctionCallIdentifier ) {
FunctionCallIdentifier fexpr = (FunctionCallIdentifier) expr;
for( ParameterExpression pexpr : fexpr.getParamExprs() )
pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp, prog));
if( !root ) { //core hoisting
if( !root ) { //core fcall hoisting
String varname = StatementBlockRewriteRule.createCutVarName(true);
DataIdentifier di = new DataIdentifier(varname);
di.setDataType(fexpr.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.sysds.test.functions.rewrite;

import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
Expand All @@ -43,9 +44,14 @@ public void setUp() {
}

@Test
public void testTimeHoisting() {
public void testTimeHoistingCP() {
test(TEST_NAME1, ExecType.CP);
}

@Test
public void testTimeHoistingSpark() {
test(TEST_NAME1, ExecType.SPARK);
}

private void test(String testname, ExecType et)
{
Expand All @@ -58,11 +64,15 @@ private void test(String testname, ExecType et)

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[] { "-explain", "-args",
programArgs = new String[] {"-args",
String.valueOf(rows), String.valueOf(cols) };

//FIXME need to hoist time() out of expression similar to function calls
runTest(true, false, null, -1);

//test that time is not executed before 1k-by-1k rand
setOutputBuffering(true);
String out = runTest(true, false, null, -1).toString();
double time = Double.parseDouble(out.split(";")[1]);
System.out.println("Time = "+time+"s");
Assert.assertTrue(time>0.001);
}
finally {
resetExecMode(platformOld);
Expand Down
2 changes: 1 addition & 1 deletion src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
t1 = time();
X = rand(rows=$1, cols=$2);

print("time = "+(time()-t1)/1e9+"s"+" "+sum(X));
print(";"+(time()-t1)/1e9+";"+" "+sum(X));

0 comments on commit b5b6f37

Please sign in to comment.