diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index 049e91a638e..8e67f283d07 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -20,10 +20,10 @@ package org.apache.sysds.hops.fedplanner; import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -139,19 +139,68 @@ private void printFedPlanTreeRecursive(FedPlan plan, Set visited, int d visited.add(plan); - // Create indentation and connectors for tree visualization - String indent = " ".repeat(depth); - String prefix = depth == 0 ? "└──" : - isLast ? "└─" : "├─"; - - // Print plan information - System.out.printf("%s%sHop %d [%s] (Total: %.3f, Self: %.3f, Net: %.3f)%n", - indent, prefix, - plan.getHopRef().getHopID(), - plan.getFedOutType(), - plan.getTotalCost(), - plan.getSelfCost(), - plan.getNetTransferCost()); + Hop hop = plan.getHopRef(); + StringBuilder sb = new StringBuilder(); + + // Add FedPlan information + sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) + .append(plan.getHopRef().getOpString()) + .append(" [") + .append(plan.getFedOutType()) + .append("]"); + + StringBuilder childs = new StringBuilder(); + childs.append(" ("); + boolean childAdded = false; + for( Hop input : hop.getInput()){ + childs.append(childAdded?",":""); + childs.append(input.getHopID()); + childAdded = true; + } + childs.append(")"); + if( childAdded ) + sb.append(childs.toString()); + + + sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", + plan.getTotalCost(), + plan.getSelfCost(), + plan.getNetTransferCost())); + + // Add matrix characteristics + sb.append(" [") + .append(hop.getDim1()).append(", ") + .append(hop.getDim2()).append(", ") + .append(hop.getBlocksize()).append(", ") + .append(hop.getNnz()); + + if (hop.getUpdateType().isInPlace()) { + sb.append(", ").append(hop.getUpdateType().toString().toLowerCase()); + } + sb.append("]"); + + // Add memory estimates + sb.append(" [") + .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ") + .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]"); + + // Add reblock and checkpoint requirements + if (hop.requiresReblock() && hop.requiresCheckpoint()) { + sb.append(" [rblk, chkpt]"); + } else if (hop.requiresReblock()) { + sb.append(" [rblk]"); + } else if (hop.requiresCheckpoint()) { + sb.append(" [chkpt]"); + } + + // Add execution type + if (hop.getExecType() != null) { + sb.append(", ").append(hop.getExecType()); + } + + System.out.println(sb); // Process child nodes List> childRefs = plan.getChildFedPlans(); diff --git a/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java index 80c8d47f435..57ecac158a1 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java @@ -39,7 +39,7 @@ public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase { - private static final String TEST_DIR = "component/parfor/"; + private static final String TEST_DIR = "functions/federated/privacy/"; private static final String HOME = SCRIPT_DIR + TEST_DIR; private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; @@ -47,20 +47,7 @@ public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase public void setUp() {} @Test - public void testDependencyAnalysis1() { runTest("parfor1.dml"); } - - @Test - public void testDependencyAnalysis3() { runTest("parfor3.dml"); } - - @Test - public void testDependencyAnalysis4() { runTest("parfor4.dml"); } - - @Test - public void testDependencyAnalysis6() { runTest("parfor6.dml"); } - - @Test - public void testDependencyAnalysis7() { runTest("parfor7.dml"); } - + public void testDependencyAnalysis1() { runTest("cost.dml"); } private void runTest( String scriptFilename ) { int index = scriptFilename.lastIndexOf(".dml"); @@ -83,7 +70,8 @@ private void runTest( String scriptFilename ) { dmlt.liveVariableAnalysis(prog); dmlt.validateParseTree(prog); dmlt.constructHops(prog); - + dmlt.rewriteHopsDAG(prog); + dmlt.constructLops(prog); /* TODO) In the current DAG, Hop's _outputMemEstimate is not initialized // This leads to incorrect fedplan generation, so test code needs to be modified // If needed, modify costEstimator to handle cases where _outputMemEstimate is not initialized diff --git a/src/test/scripts/functions/federated/privacy/cost.dml b/src/test/scripts/functions/federated/privacy/cost.dml new file mode 100644 index 00000000000..ec34d45bb65 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/cost.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +a = matrix(7,10,10); +b = a + a^2; +c = sqrt(b); +print(sum(c)); \ No newline at end of file