Skip to content

Commit

Permalink
Truncate long plan labels and refer to "print-plans" (#2821)
Browse files Browse the repository at this point in the history
* Truncate and link plan labels to "print-plans"

Properly handle Graphviz builds compiled with the default 16K maximum label length

Closes #2710.

Signed-off-by: Gera Shegalov <[email protected]>

* test and review
  • Loading branch information
gerashegalov authored Jun 28, 2021
1 parent 5a92298 commit 13b740c
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,19 +155,11 @@ case class SparkPlanGraph(
physicalPlan: String) {

def makeDotFile(metrics: Map[Long, Long]): String = {
val leftAlignedLabel =
s"""
|Application: $appId
|Query: $sqlId
|
|$physicalPlan"""
.stripMargin
.replace("\n", "\\l")

val queryLabel = SparkPlanGraph.makeDotLabel(appId, sqlId, physicalPlan)

val dotFile = new StringBuilder
dotFile.append("digraph G {\n")
dotFile.append(s"""label="$leftAlignedLabel"\n""")
dotFile.append(s"label=$queryLabel\n")
dotFile.append("labelloc=b\n")
dotFile.append("fontname=Courier\n")
dotFile.append(s"""tooltip="APP: $appId Query: $sqlId"\n""")
Expand Down Expand Up @@ -330,6 +322,36 @@ object SparkPlanGraph {
exchanges, stageIdToStageMetrics))
}
}

/** Markup substituted for every newline of the plan so lines stay left-aligned in the label. */
val htmlLineBreak = """<br align="left"/>""" + "\n"

/**
 * Builds the HTML-like Graphviz label for a query: a borderless table holding the
 * application/query ids, the (possibly truncated) physical plan, and a pointer to the
 * `--print-plans` output where the complete plan can be found.
 *
 * The plan is cut so that the returned label never exceeds `maxLength` characters
 * (unless the fixed markup alone is already longer, in which case no plan text is emitted).
 *
 * @param appId        application id shown in the first table row
 * @param sqlId        SQL query id shown in the first row and in the `--print-plans` caption
 * @param physicalPlan plan text; newlines become [[htmlLineBreak]] and `&`, `<`, `>` are escaped
 * @param maxLength    upper bound for the label length; defaults to 16384
 * @return the label, starting with `<<table ` and ending with `</table>>`
 */
def makeDotLabel(
    appId: String,
    sqlId: String,
    physicalPlan: String,
    maxLength: Int = 16384
): String = {
  // The label is assembled by concatenation instead of String.format so that a '%'
  // inside appId or sqlId cannot be misread as a format specifier.
  val labelHead =
    s"""<<table border="0">
       |<tr><td>Application: $appId, Query: $sqlId</td></tr>
       |<tr><td>""".stripMargin
  val labelTail =
    s"""</td></tr>
       |<tr><td>Large physical plans may be truncated. See output from
       |--print-plans captioned "Plan for SQL ID : $sqlId"
       |</td></tr>
       |</table>>""".stripMargin

  // Characters left for the plan once the fixed markup is accounted for.
  val planBudget = maxLength - labelHead.length - labelTail.length

  // HTML-like labels must be legal XML, so markup characters in the plan are escaped.
  def render(c: Char): String = c match {
    case '\n' => htmlLineBreak
    case '&' => "&amp;"
    case '<' => "&lt;"
    case '>' => "&gt;"
    case other => other.toString
  }

  // Charge each plan character for its rendered size and stop at the first one that
  // does not fit. Only line breaks that are actually kept consume budget, so a plan
  // with many lines is truncated rather than dropped, and no entity is cut in half.
  val planMarkup = new StringBuilder
  val planChars = physicalPlan.iterator
  var hasRoom = true
  while (hasRoom && planChars.hasNext) {
    val piece = render(planChars.next())
    hasRoom = planMarkup.length + piece.length <= planBudget
    if (hasRoom) {
      planMarkup.append(piece)
    }
  }

  labelHead + planMarkup.toString + labelTail
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
package com.nvidia.spark.rapids.tool.profiling

import java.io.File
import java.security.SecureRandom

import scala.collection.mutable
import scala.io.Source

import com.nvidia.spark.rapids.tool.ToolTestUtils
Expand Down Expand Up @@ -67,21 +69,78 @@ class GenerateDotSuite extends FunSuite with BeforeAndAfterAll with Logging {
for (file <- dotDirs) {
assert(file.getAbsolutePath.endsWith(".dot"))
val source = Source.fromFile(file)
try {
val lines = source.getLines().toArray
assert(lines.head === "digraph G {")
assert(lines.last === "}")
hashAggCount += lines.count(_.contains("HashAggregate"))
stageCount += lines.count(_.contains("STAGE "))
} finally {
source.close()
}
val dotFileStr = source.mkString
source.close()
assert(dotFileStr.startsWith("digraph G {"))
assert(dotFileStr.endsWith("}"))
val hashAggr = "HashAggregate"
val stageWord = "STAGE"
hashAggCount += dotFileStr.sliding(hashAggr.length).count(_ == hashAggr)
stageCount += dotFileStr.sliding(stageWord.length).count(_ == stageWord)
}
// 2 node labels + 1 graph label
assert(hashAggCount === 3)
// Initial Aggregation, Final Aggregation, Sorting final output
assert(stageCount === 3)

assert(hashAggCount === 8, "Expected: 4 in node labels + 4 in graph label")
assert(stageCount === 4, "Expected: UNKNOWN Stage, Initial Aggregation, " +
"Final Aggregation, Sorting final output")
}
}
}

test("Empty physical plan") {
  // Even with no plan text the label must remain a well-formed table that
  // carries the application id and the --print-plans caption.
  val appName = "local-12345-1"
  val queryId = "120"
  val emptyPlanLabel = SparkPlanGraph.makeDotLabel(appName, queryId, physicalPlan = "")

  planLabelChecks(emptyPlanLabel)
}

test("Long physical plan") {
  // Generates plans whose size hovers around the 16K label limit and checks that
  // the produced label always stays within that limit while keeping plan content.
  val random = new SecureRandom()
  val seed = System.currentTimeMillis()
  // NOTE(review): per the SecureRandom Javadoc, setSeed supplements rather than
  // replaces the existing seed, so the logged value may not reproduce a failing
  // run - confirm whether reproducibility is required here.
  random.setSeed(seed)
  info("Seeding test with: " + seed)

  val numTests = 100
  val lineLengthRange = 50 until 200
  val labelLimit = 16 * 1024

  val planLengthSeq = mutable.ArrayBuffer.empty[Int]
  val labelLengthSeq = mutable.ArrayBuffer.empty[Int]

  // some imperfect randomness for edge cases
  for (_ <- 1 to numTests) {
    val lineLength = lineLengthRange.start +
      random.nextInt(lineLengthRange.length) -
      SparkPlanGraph.htmlLineBreak.length()

    // Land a few lines above or below the limit.
    val sign = if (random.nextBoolean()) 1 else -1
    val planLength = labelLimit + sign * lineLength * (1 + random.nextInt(5))
    val planStr = Seq.fill(planLength / lineLength + 1)("a" * lineLength).mkString("\n")

    planLengthSeq += planStr.length()

    val planLabel = SparkPlanGraph.makeDotLabel(
      appId = "local-12345-1",
      sqlId = "120",
      physicalPlan = planStr)

    labelLengthSeq += planLabel.length()

    planLabelChecks(planLabel)
    assert(planLabel.length() <= labelLimit)
    assert(planLabel.contains("a" * lineLength))
    assert(planLabel.contains(SparkPlanGraph.htmlLineBreak))
  }

  // Each summary reports the sequence its caption names (they were previously swapped).
  info(s"Plan length summary: min=${planLengthSeq.min} max=${planLengthSeq.max}")
  info(s"Plan label summary: min=${labelLengthSeq.min} max=${labelLengthSeq.max}")
}

/**
 * Structural checks shared by the label tests: the label must be wrapped in the
 * HTML-like table delimiters and mention both the fixture application id and the
 * `--print-plans` caption for SQL id 120.
 *
 * @param planLabel label produced for appId "local-12345-1" and sqlId "120"
 */
private def planLabelChecks(planLabel: String): Unit = {
  assert(planLabel.startsWith("<<table "))
  assert(planLabel.endsWith("</table>>"))
  assert(planLabel.contains("local-12345-1"))
  assert(planLabel.contains("Plan for SQL ID : 120"))
}
}

0 comments on commit 13b740c

Please sign in to comment.