Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1869362: Plan plotter improvements #2813

Merged
merged 4 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions src/snowflake/snowpark/_internal/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SelectSnowflakePlan,
SelectStatement,
SelectTableFunction,
SelectableEntity,
SetStatement,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
Expand All @@ -28,6 +29,7 @@
LogicalPlan,
SnowflakeCreateTable,
TableCreationSource,
WithQueryBlock,
)
from snowflake.snowpark._internal.analyzer.table_merge_expression import (
TableDelete,
Expand Down Expand Up @@ -381,15 +383,29 @@ def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
):
return

if int(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this plotting threshold used for? It seems to be used for restricting the complexity score? Maybe call this SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD to be clearer.

os.environ.get("SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD", 0)
) > get_complexity_score(root):
return

import graphviz # pyright: ignore[reportMissingImports]

def get_stat(node: LogicalPlan):
def get_name(node: Optional[LogicalPlan]) -> str:
def get_name(node: Optional[LogicalPlan]) -> str: # pragma: no cover
if node is None:
return "EMPTY_SOURCE_PLAN" # pragma: no cover
addr = hex(id(node))
name = str(type(node)).split(".")[-1].split("'")[0]
return f"{name}({addr})"
suffix = ""
if isinstance(node, SnowflakeCreateTable):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here explaining the different kinds of printing used here.

# get the table name from the full qualified name
table_name = node.table_name[-1].split(".")[-1] # pyright: ignore
suffix = f" :: {table_name}"
if isinstance(node, WithQueryBlock):
# get the CTE identifier excluding SNOWPARK_TEMP_CTE_
suffix = f" :: {node.name[18:]}"

return f"{name}({addr}){suffix}"

name = get_name(node)
if isinstance(node, SnowflakePlan):
Expand All @@ -411,20 +427,44 @@ def get_name(node: Optional[LogicalPlan]) -> str:
if node.offset:
properties.append("Offset") # pragma: no cover
name = f"{name} :: ({'| '.join(properties)})"
elif isinstance(node, SelectableEntity):
# get the table name from the full qualified name
name = f"{name} :: ({node.entity.name.split('.')[-1]})"

def get_sql_text(node: LogicalPlan) -> str:  # pragma: no cover
    """Return the SQL text associated with ``node``.

    A ``Selectable`` yields its ``sql_query``; a ``SnowflakePlan`` yields the
    ``sql`` of the last entry in its ``queries``; any other node yields an
    empty string.
    """
    if isinstance(node, Selectable):
        text = node.sql_query
    elif isinstance(node, SnowflakePlan):
        last_query = node.queries[-1]
        text = last_query.sql
    else:
        text = ""
    return text

score = get_complexity_score(node)
num_ref_ctes = "nil"
if isinstance(node, (SnowflakePlan, Selectable)):
num_ref_ctes = len(node.referenced_ctes)
sql_text = ""
if isinstance(node, Selectable):
sql_text = node.sql_query
elif isinstance(node, SnowflakePlan):
sql_text = node.queries[-1].sql
sql_text = get_sql_text(node)
sql_size = len(sql_text)
ref_ctes = None
if isinstance(node, (SnowflakePlan, Selectable)):
ref_ctes = list(
map(
lambda node, cnt: f"{node.name[18:]}:{cnt}",
node.referenced_ctes.keys(),
node.referenced_ctes.values(),
)
)
for with_query_block in node.referenced_ctes: # pragma: no cover
sql_size += len(get_sql_text(with_query_block.children[0]))
sql_preview = sql_text[:50]

return f"{name=}\n{score=}, {num_ref_ctes=}, {sql_size=}\n{sql_preview=}"
return f"{name=}\n{score=}, {ref_ctes=}, {sql_size=}\n{sql_preview=}"

def is_with_query_block(node: Optional[LogicalPlan]) -> bool:  # pragma: no cover
    """Tell whether ``node`` is, or wraps, a ``WithQueryBlock``.

    The check unwraps ``SnowflakePlan`` nodes through ``source_plan`` and
    ``SelectSnowflakePlan`` nodes through ``snowflake_plan`` until it either
    finds a ``WithQueryBlock`` (True) or hits any other kind of node,
    including ``None`` (False).
    """
    current = node
    while True:
        if isinstance(current, WithQueryBlock):
            return True
        if isinstance(current, SnowflakePlan):
            # Keep unwrapping: look at the plan this SnowflakePlan was built from.
            current = current.source_plan
        elif isinstance(current, SelectSnowflakePlan):
            # Keep unwrapping: look at the SnowflakePlan held by the selectable.
            current = current.snowflake_plan
        else:
            return False

g = graphviz.Graph(format="png")

Expand All @@ -435,11 +475,18 @@ def get_name(node: Optional[LogicalPlan]) -> str:
for node in curr_level:
node_id = hex(id(node))
color = "lightblue" if node._is_valid_for_replacement else "red"
g.node(node_id, get_stat(node), color=color)
fillcolor = "lightgray" if is_with_query_block(node) else "white"
g.node(
node_id,
get_stat(node),
color=color,
style="filled",
fillcolor=fillcolor,
)
if isinstance(node, (Selectable, SnowflakePlan)):
children = node.children_plan_nodes
else:
children = node.children
children = node.children # pragma: no cover
for child in children:
child_id = hex(id(child))
edges.add((node_id, child_id))
Expand Down
20 changes: 17 additions & 3 deletions tests/integ/test_large_query_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,16 +788,24 @@ def test_large_query_breakdown_enabled_parameter(session, caplog):

@pytest.mark.skipif(IS_IN_STORED_PROC, reason="requires graphviz")
@pytest.mark.parametrize("enabled", [False, True])
def test_plotter(session, large_query_df, enabled):
@pytest.mark.parametrize("plotting_score_threshold", [0, 10_000_000])
def test_plotter(large_query_df, enabled, plotting_score_threshold):
original_plotter_enabled = os.environ.get("ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING")
original_score_threshold = os.environ.get(
"SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"
)
try:
os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"] = str(enabled)
os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"] = str(
plotting_score_threshold
)
tmp_dir = tempfile.gettempdir()

with patch("graphviz.Graph.render") as mock_render:
large_query_df.collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we perhaps add a comment explaining that the actual complexity for large_query_df falls somewhere between 0 and 10M?

assert mock_render.called == enabled
if not enabled:
should_plot = enabled and (plotting_score_threshold == 0)
assert mock_render.called == should_plot
if not should_plot:
return

assert mock_render.call_count == 5
Expand All @@ -819,3 +827,9 @@ def test_plotter(session, large_query_df, enabled):
] = original_plotter_enabled
else:
del os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"]
if original_score_threshold is not None:
os.environ[
"SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"
] = original_score_threshold
else:
del os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"]
Loading