[MINOR] Improved code coverage of control program and symbol table
This patch adds a couple of tests to systematically cover previously
uncovered code. Furthermore, it removes incorrect methods and renames
misleading methods on "pinned objects" that did not actually deal with
our notion of pinned (i.e., cleanup-disabled) data objects.
mboehm7 committed Dec 10, 2024
1 parent 082cf89 commit 0242557
Showing 9 changed files with 102 additions and 54 deletions.
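
To illustrate the renamed symbol-table API described in the commit message, here is a minimal sketch (not part of the commit) that mirrors the new VariableMapTest below: it pins a matrix object by disabling cleanup, queries the pinned size, and then releases acquired data via the renamed releaseAcquiredData() (formerly releasePinnedData()). The wrapper class name and seed are illustrative; the SystemDS classes and methods are taken from the diff below.

import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;

// Illustrative sketch, not part of the patch.
public class PinnedDataSketch {
	public static void main(String[] args) {
		LocalVariableMap vars = new LocalVariableMap();
		MatrixBlock mb = MatrixBlock.randOperations(150, 167, 0.3, 1, 1, "uniform", 7);
		MatrixObject mo = new MatrixObject(ValueType.FP64, "./tmp",
			new MetaDataFormat(new MatrixCharacteristics(), FileFormat.BINARY));
		mo.acquireModify(mb);
		mo.release();
		mo.enableCleanup(false); // "pinned" here means cleanup is disabled
		vars.put("a", mo);
		// size of pinned (cleanup-disabled) objects in the symbol table
		System.out.println("pinned bytes: " + vars.getPinnedDataSize());
		// formerly releasePinnedData(): releases acquired cacheable data,
		// without changing the pinned (cleanup-disabled) status
		vars.releaseAcquiredData();
	}
}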
@@ -83,7 +83,7 @@ public class PreparedScript implements ConfigurableAPI
private PreparedScript(PreparedScript that) {
//shallow copy, except for a separate symbol table
//and related meta data of reused inputs
_prog = that._prog.clone(false);
_prog = (Program)that._prog.clone();
_vars = new LocalVariableMap();
for(Entry<String, Data> e : that._vars.entrySet())
_vars.put(e.getKey(), e.getValue());
@@ -47,7 +47,7 @@ public class LocalVariableMap implements Cloneable

//variable map data and id
private final ConcurrentHashMap<String, Data> localMap;
private final long localID;
private long localID;

//optional set of registered outputs
private HashSet<String> outputs = null;
@@ -61,6 +61,10 @@ public LocalVariableMap(LocalVariableMap vars) {
localMap = new ConcurrentHashMap<>(vars.localMap);
localID = _seq.getNextID();
}

public void setID(long ID) {
localID = ID;
}

public Set<String> keySet() {
return localMap.keySet();
@@ -154,12 +158,7 @@ public double getPinnedDataSize() {
return total;
}

public long countPinnedData() {
return localMap.values().stream()
.filter(d -> (d instanceof CacheableData)).count();
}

public void releasePinnedData() {
public void releaseAcquiredData() {
localMap.values().stream()
.filter(d -> (d instanceof CacheableData))
.map(d -> (CacheableData<?>) d)
@@ -23,7 +23,6 @@
import java.util.HashMap;
import java.util.Map.Entry;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.FunctionDictionary;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -165,9 +164,8 @@ public void execute(ExecutionContext ec) {
}
}

public Program clone(boolean deep) {
if( deep )
throw new NotImplementedException();
@Override
public Object clone() {
Program ret = new Program(_prog);
//shallow copy of all program blocks
ret._programBlocks.addAll(_programBlocks);
@@ -179,11 +177,6 @@ public Program clone(boolean deep) {
return ret;
}

@Override
public Object clone() {
return clone(true);
}

private static String getSafeNamespace(String namespace) {
return (namespace == null) ? DMLProgram.DEFAULT_NAMESPACE : namespace;
}
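
For context, a minimal sketch (not part of the patch) of the cloning contract after this refactor: the deep-clone variant clone(boolean) is gone, and Program exposes only the shallow Object.clone() override, which callers cast, as the PreparedScript change above does. The helper and class names are illustrative.

import org.apache.sysds.runtime.controlprogram.Program;

// Illustrative helper, not part of the patch.
public class ProgramCloneSketch {
	// Program.clone() now always returns a shallow copy (program blocks and
	// function dictionaries are shared), so callers cast the Object.clone()
	// result, as the PreparedScript change above does.
	public static Program shallowCopy(Program prog) {
		return (Program) prog.clone();
	}
}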
@@ -41,12 +41,9 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
@@ -60,7 +57,7 @@ public abstract class ProgramBlock implements ParseInfo {
public static final String PRED_VAR = "__pred";

protected static final Log LOG = LogFactory.getLog(ProgramBlock.class.getName());
private static final boolean CHECK_MATRIX_PROPERTIES = false;
public static boolean CHECK_MATRIX_PROPERTIES = false;

protected Program _prog; // pointer to Program this ProgramBlock is part of

@@ -84,10 +81,6 @@ public Program getProgram() {
return _prog;
}

public void setProgram(Program prog) {
_prog = prog;
}

public StatementBlock getStatementBlock() {
return _sb;
}
@@ -216,22 +209,7 @@ protected ScalarObject executePredicateInstructions(ArrayList<Instruction> inst,

// check and correct scalar ret type (incl save double to int)
if(retType != null && retType != ret.getValueType())
switch(retType) {
case BOOLEAN:
ret = new BooleanObject(ret.getBooleanValue());
break;
case INT64:
ret = new IntObject(ret.getLongValue());
break;
case FP64:
ret = new DoubleObject(ret.getDoubleValue());
break;
case STRING:
ret = new StringObject(ret.getStringValue());
break;
default:
// do nothing
}
ret = ScalarObjectFactory.createScalarObject(retType, ret);

// remove predicate variable
ec.removeVariable(PRED_VAR);
@@ -350,12 +328,10 @@ private static void checkSparsity(Instruction lastInst, LocalVariableMap vars, E
synchronized(mb) { // potential state change
mb.recomputeNonZeros();
mb.examSparsity();

}
if(mb.isInSparseFormat() && mb.isAllocated()) {
mb.getSparseBlock().checkValidity(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), true);
}

boolean sparse2 = mb.isInSparseFormat();
long nnz2 = mb.getNonZeros();
mo.release();
@@ -473,11 +449,11 @@ public String printBlockErrorLocation() {
* position, ending column position, text, and filename
*/
public void setParseInfo(ParseInfo parseInfo) {
_beginLine = parseInfo.getBeginLine();
_beginColumn = parseInfo.getBeginColumn();
_endLine = parseInfo.getEndLine();
_endColumn = parseInfo.getEndColumn();
_text = parseInfo.getText();
_filename = parseInfo.getFilename();
setBeginLine(parseInfo.getBeginLine());
setBeginColumn(parseInfo.getBeginColumn());
setEndLine(parseInfo.getEndLine());
setEndColumn(parseInfo.getEndColumn());
setText(parseInfo.getText());
setFilename(parseInfo.getFilename());
}
}
@@ -619,9 +619,9 @@ private static void exec(ExecutionContext ec, Instruction ins){
pb.execute(ec);
}
catch(Exception ex) {
// ensure all variables are properly unpinned, even in case
// ensure all variables are properly released, even in case
// of failures because federated workers are stateful servers
ec.getVariables().releasePinnedData();
ec.getVariables().releaseAcquiredData();
throw ex;
}
}
4 changes: 4 additions & 0 deletions src/main/java/org/apache/sysds/runtime/meta/MetaData.java
@@ -27,6 +27,10 @@ public class MetaData
{
protected final DataCharacteristics _dc;

public MetaData() {
this(new MatrixCharacteristics());
}

public MetaData(DataCharacteristics dc) {
_dc = dc;
}
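
A small sketch (not part of the patch) of what the added default constructor provides; the wrapper class and variable names are illustrative.

import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaData;

// Illustrative only, not part of the patch.
public class MetaDataSketch {
	public static void main(String[] args) {
		// the new no-argument constructor delegates to new MatrixCharacteristics(),
		// so both objects below carry equivalent (empty) data characteristics
		MetaData mdDefault = new MetaData();
		MetaData mdExplicit = new MetaData(new MatrixCharacteristics());
		System.out.println(mdDefault + " vs " + mdExplicit);
	}
}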
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.component.cp;

import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.junit.Assert;
import org.junit.Test;

public class VariableMapTest {

@Test
public void testPinnedMethods() {
LocalVariableMap vars = createSymbolTable();
Assert.assertTrue(vars.getPinnedDataSize() > 2e5);
vars.releaseAcquiredData(); //no impact on pinned status
Assert.assertTrue(vars.getPinnedDataSize() > 2e5);
vars.removeAll();
}

@Test
public void testSerializeDeserialize() {
LocalVariableMap vars = createSymbolTable();
LocalVariableMap vars2 = LocalVariableMap.deserialize(vars.serialize());
vars2.setID(1);
Assert.assertEquals(vars.toString(), vars2.toString());
LocalVariableMap vars3 = (LocalVariableMap) vars2.clone();
vars3.setID(1);
Assert.assertEquals(vars.toString(), vars3.toString());
}

private LocalVariableMap createSymbolTable() {
LocalVariableMap vars = new LocalVariableMap();
vars.put("a", createPinnedMatrixObject(1));
vars.put("b", createPinnedMatrixObject(2));
return vars;
}

private MatrixObject createPinnedMatrixObject(int seed) {
MatrixBlock mb1 = MatrixBlock.randOperations(150, 167, 0.3, 1, 1, "uniform", seed);
MatrixObject mo = new MatrixObject(ValueType.FP64, "./tmp",
new MetaDataFormat(new MatrixCharacteristics(), FileFormat.BINARY));
mo.acquireModify(mb1);
mo.release();
mo.enableCleanup(false);
mo.setDirty(false);
return mo;
}
}
@@ -27,6 +27,7 @@
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggBinaryOp.MMultMethod;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -116,12 +117,13 @@ private void runMatrixMatrixMultiplicationTest( boolean sparseM1, boolean sparse
if( rtplatform == ExecMode.SPARK )
DMLScript.USE_LOCAL_SPARK_CONFIG = true;

ProgramBlock.CHECK_MATRIX_PROPERTIES = true;

if(forcePMMJ)
AggBinaryOp.FORCED_MMULT_METHOD = MMultMethod.PMM;

try
{
setOutputBuffering(true);
String TEST_NAME = (rowwise) ? TEST_NAME1 : TEST_NAME2;
getAndLoadTestConfiguration(TEST_NAME);

@@ -154,6 +156,7 @@ private void runMatrixMatrixMultiplicationTest( boolean sparseM1, boolean sparse
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
AggBinaryOp.FORCED_MMULT_METHOD = null;
ProgramBlock.CHECK_MATRIX_PROPERTIES = false;
}
}
}
@@ -25,6 +25,7 @@
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
@@ -116,7 +117,7 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
Process t4 = startLocalFedWorker(port4);


ProgramBlock.CHECK_MATRIX_PROPERTIES = true;
try {
if(!isAlive(t1, t2, t3, t4))
throw new RuntimeException("Failed starting federated worker");
@@ -162,6 +163,7 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
ProgramBlock.CHECK_MATRIX_PROPERTIES = false;
}
}
}
