diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index 08b7425240d..31bb7457227 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -83,7 +83,7 @@ public class PreparedScript implements ConfigurableAPI private PreparedScript(PreparedScript that) { //shallow copy, except for a separate symbol table //and related meta data of reused inputs - _prog = that._prog.clone(false); + _prog = (Program)that._prog.clone(); _vars = new LocalVariableMap(); for(Entry e : that._vars.entrySet()) _vars.put(e.getKey(), e.getValue()); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java index aa6184a5f45..8d27a6e8f29 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java @@ -47,7 +47,7 @@ public class LocalVariableMap implements Cloneable //variable map data and id private final ConcurrentHashMap localMap; - private final long localID; + private long localID; //optional set of registered outputs private HashSet outputs = null; @@ -61,6 +61,10 @@ public LocalVariableMap(LocalVariableMap vars) { localMap = new ConcurrentHashMap<>(vars.localMap); localID = _seq.getNextID(); } + + public void setID(long ID) { + localID = ID; + } public Set keySet() { return localMap.keySet(); @@ -154,12 +158,7 @@ public double getPinnedDataSize() { return total; } - public long countPinnedData() { - return localMap.values().stream() - .filter(d -> (d instanceof CacheableData)).count(); - } - - public void releasePinnedData() { + public void releaseAcquiredData() { localMap.values().stream() .filter(d -> (d instanceof CacheableData)) .map(d -> (CacheableData) d) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java index 73ed572114d..e79e48433f2 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java @@ -23,7 +23,6 @@ import java.util.HashMap; import java.util.Map.Entry; -import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.parser.FunctionDictionary; import org.apache.sysds.runtime.DMLRuntimeException; @@ -165,9 +164,8 @@ public void execute(ExecutionContext ec) { } } - public Program clone(boolean deep) { - if( deep ) - throw new NotImplementedException(); + @Override + public Object clone() { Program ret = new Program(_prog); //shallow copy of all program blocks ret._programBlocks.addAll(_programBlocks); @@ -179,11 +177,6 @@ public Program clone(boolean deep) { return ret; } - @Override - public Object clone() { - return clone(true); - } - private static String getSafeNamespace(String namespace) { return (namespace == null) ? DMLProgram.DEFAULT_NAMESPACE : namespace; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java index 07393346803..aee08516db6 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java @@ -41,12 +41,9 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.Instruction; -import org.apache.sysds.runtime.instructions.cp.BooleanObject; import org.apache.sysds.runtime.instructions.cp.Data; -import org.apache.sysds.runtime.instructions.cp.DoubleObject; -import org.apache.sysds.runtime.instructions.cp.IntObject; import org.apache.sysds.runtime.instructions.cp.ScalarObject; -import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysds.runtime.lineage.LineageCache; import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType; @@ -60,7 +57,7 @@ public abstract class ProgramBlock implements ParseInfo { public static final String PRED_VAR = "__pred"; protected static final Log LOG = LogFactory.getLog(ProgramBlock.class.getName()); - private static final boolean CHECK_MATRIX_PROPERTIES = false; + public static boolean CHECK_MATRIX_PROPERTIES = false; protected Program _prog; // pointer to Program this ProgramBlock is part of @@ -84,10 +81,6 @@ public Program getProgram() { return _prog; } - public void setProgram(Program prog) { - _prog = prog; - } - public StatementBlock getStatementBlock() { return _sb; } @@ -216,22 +209,7 @@ protected ScalarObject executePredicateInstructions(ArrayList inst, // check and correct scalar ret type (incl save double to int) if(retType != null && retType != ret.getValueType()) - switch(retType) { - case BOOLEAN: - ret = new BooleanObject(ret.getBooleanValue()); - break; - case INT64: - ret = new IntObject(ret.getLongValue()); - break; - case FP64: - ret = new DoubleObject(ret.getDoubleValue()); - break; - case STRING: - ret = new StringObject(ret.getStringValue()); - break; - default: - // do nothing - } + ret = ScalarObjectFactory.createScalarObject(retType, ret); // remove predicate variable ec.removeVariable(PRED_VAR); @@ -350,12 +328,10 @@ private static void checkSparsity(Instruction lastInst, LocalVariableMap vars, E synchronized(mb) { // potential state change mb.recomputeNonZeros(); mb.examSparsity(); - } if(mb.isInSparseFormat() && mb.isAllocated()) { mb.getSparseBlock().checkValidity(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), true); } - boolean sparse2 = mb.isInSparseFormat(); long nnz2 = mb.getNonZeros(); mo.release(); @@ -473,11 +449,11 @@ public String printBlockErrorLocation() { * position, ending column position, text, and filename */ public void setParseInfo(ParseInfo parseInfo) { - _beginLine = parseInfo.getBeginLine(); - _beginColumn = parseInfo.getBeginColumn(); - _endLine = parseInfo.getEndLine(); - _endColumn = parseInfo.getEndColumn(); - _text = parseInfo.getText(); - _filename = parseInfo.getFilename(); + setBeginLine(parseInfo.getBeginLine()); + setBeginColumn(parseInfo.getBeginColumn()); + setEndLine(parseInfo.getEndLine()); + setEndColumn(parseInfo.getEndColumn()); + setText(parseInfo.getText()); + setFilename(parseInfo.getFilename()); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index d8d3c262efe..ceaf61c225c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -619,9 +619,9 @@ private static void exec(ExecutionContext ec, Instruction ins){ pb.execute(ec); } catch(Exception ex) { - // ensure all variables are properly unpinned, even in case + // ensure all variables are properly released, even in case // of failures because federated workers are stateful servers - ec.getVariables().releasePinnedData(); + ec.getVariables().releaseAcquiredData(); throw ex; } } diff --git a/src/main/java/org/apache/sysds/runtime/meta/MetaData.java b/src/main/java/org/apache/sysds/runtime/meta/MetaData.java index aa820b5cb4d..925fe7f1867 100644 --- a/src/main/java/org/apache/sysds/runtime/meta/MetaData.java +++ b/src/main/java/org/apache/sysds/runtime/meta/MetaData.java @@ -27,6 +27,10 @@ public class MetaData { protected final DataCharacteristics _dc; + public MetaData() { + this(new MatrixCharacteristics()); + } + public MetaData(DataCharacteristics dc) { _dc = dc; } diff --git a/src/test/java/org/apache/sysds/test/component/cp/VariableMapTest.java b/src/test/java/org/apache/sysds/test/component/cp/VariableMapTest.java new file mode 100644 index 00000000000..74263fff803 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/cp/VariableMapTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.cp; + +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.controlprogram.LocalVariableMap; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.junit.Assert; +import org.junit.Test; + +public class VariableMapTest { + + @Test + public void testPinnedMethods() { + LocalVariableMap vars = createSymbolTable(); + Assert.assertTrue(vars.getPinnedDataSize() > 2e5); + vars.releaseAcquiredData(); //no impact on pinned status + Assert.assertTrue(vars.getPinnedDataSize() > 2e5); + vars.removeAll(); + } + + @Test + public void testSerializeDeserialize() { + LocalVariableMap vars = createSymbolTable(); + LocalVariableMap vars2 = LocalVariableMap.deserialize(vars.serialize()); + vars2.setID(1); + Assert.assertEquals(vars.toString(), vars2.toString()); + LocalVariableMap vars3 = (LocalVariableMap) vars2.clone(); + vars3.setID(1); + Assert.assertEquals(vars.toString(), vars3.toString()); + } + + private LocalVariableMap createSymbolTable() { + LocalVariableMap vars = new LocalVariableMap(); + vars.put("a", createPinnedMatrixObject(1)); + vars.put("b", createPinnedMatrixObject(2)); + return vars; + } + + private MatrixObject createPinnedMatrixObject(int seed) { + MatrixBlock mb1 = MatrixBlock.randOperations(150, 167, 0.3, 1, 1, "uniform", seed); + MatrixObject mo = new MatrixObject(ValueType.FP64, "./tmp", + new MetaDataFormat(new MatrixCharacteristics(), FileFormat.BINARY)); + mo.acquireModify(mb1); + mo.release(); + mo.enableCleanup(false); + mo.setDirty(false); + return mo; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UltraSparseMRMatrixMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UltraSparseMRMatrixMultiplicationTest.java index 6c712a8ad2b..0fbe20a3201 100644 --- a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UltraSparseMRMatrixMultiplicationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UltraSparseMRMatrixMultiplicationTest.java @@ -27,6 +27,7 @@ import org.apache.sysds.hops.AggBinaryOp; import org.apache.sysds.hops.AggBinaryOp.MMultMethod; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.runtime.controlprogram.ProgramBlock; import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; @@ -116,12 +117,13 @@ private void runMatrixMatrixMultiplicationTest( boolean sparseM1, boolean sparse if( rtplatform == ExecMode.SPARK ) DMLScript.USE_LOCAL_SPARK_CONFIG = true; + ProgramBlock.CHECK_MATRIX_PROPERTIES = true; + if(forcePMMJ) AggBinaryOp.FORCED_MMULT_METHOD = MMultMethod.PMM; try { - setOutputBuffering(true); String TEST_NAME = (rowwise) ? TEST_NAME1 : TEST_NAME2; getAndLoadTestConfiguration(TEST_NAME); @@ -154,6 +156,7 @@ private void runMatrixMatrixMultiplicationTest( boolean sparseM1, boolean sparse rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; AggBinaryOp.FORCED_MMULT_METHOD = null; + ProgramBlock.CHECK_MATRIX_PROPERTIES = false; } } } \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java index 8e45c7347df..46bf7a45650 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java @@ -25,6 +25,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.controlprogram.ProgramBlock; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.test.AutomatedTestBase; @@ -116,7 +117,7 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); Process t4 = startLocalFedWorker(port4); - + ProgramBlock.CHECK_MATRIX_PROPERTIES = true; try { if(!isAlive(t1, t2, t3, t4)) throw new RuntimeException("Failed starting federated worker"); @@ -162,6 +163,7 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.FEDERATED_COMPILATION = false; + ProgramBlock.CHECK_MATRIX_PROPERTIES = false; } } }