-
Notifications
You must be signed in to change notification settings - Fork 119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SNOW-1869362: Plan plotter improvements #2813
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
SelectSnowflakePlan, | ||
SelectStatement, | ||
SelectTableFunction, | ||
SelectableEntity, | ||
SetStatement, | ||
) | ||
from snowflake.snowpark._internal.analyzer.snowflake_plan import ( | ||
|
@@ -28,6 +29,7 @@ | |
LogicalPlan, | ||
SnowflakeCreateTable, | ||
TableCreationSource, | ||
WithQueryBlock, | ||
) | ||
from snowflake.snowpark._internal.analyzer.table_merge_expression import ( | ||
TableDelete, | ||
|
@@ -381,15 +383,29 @@ def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None: | |
): | ||
return | ||
|
||
if int( | ||
os.environ.get("SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD", 0) | ||
) > get_complexity_score(root): | ||
return | ||
|
||
import graphviz # pyright: ignore[reportMissingImports] | ||
|
||
def get_stat(node: LogicalPlan): | ||
def get_name(node: Optional[LogicalPlan]) -> str: | ||
def get_name(node: Optional[LogicalPlan]) -> str: # pragma: no cover | ||
if node is None: | ||
return "EMPTY_SOURCE_PLAN" # pragma: no cover | ||
addr = hex(id(node)) | ||
name = str(type(node)).split(".")[-1].split("'")[0] | ||
return f"{name}({addr})" | ||
suffix = "" | ||
if isinstance(node, SnowflakeCreateTable): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a comment here about what are the different printing used here |
||
# get the table name from the full qualified name | ||
table_name = node.table_name[-1].split(".")[-1] # pyright: ignore | ||
suffix = f" :: {table_name}" | ||
if isinstance(node, WithQueryBlock): | ||
# get the CTE identifier excluding SNOWPARK_TEMP_CTE_ | ||
suffix = f" :: {node.name[18:]}" | ||
|
||
return f"{name}({addr}){suffix}" | ||
|
||
name = get_name(node) | ||
if isinstance(node, SnowflakePlan): | ||
|
@@ -411,20 +427,44 @@ def get_name(node: Optional[LogicalPlan]) -> str: | |
if node.offset: | ||
properties.append("Offset") # pragma: no cover | ||
name = f"{name} :: ({'| '.join(properties)})" | ||
elif isinstance(node, SelectableEntity): | ||
# get the table name from the full qualified name | ||
name = f"{name} :: ({node.entity.name.split('.')[-1]})" | ||
|
||
def get_sql_text(node: LogicalPlan) -> str: # pragma: no cover | ||
if isinstance(node, Selectable): | ||
return node.sql_query | ||
if isinstance(node, SnowflakePlan): | ||
return node.queries[-1].sql | ||
return "" | ||
|
||
score = get_complexity_score(node) | ||
num_ref_ctes = "nil" | ||
if isinstance(node, (SnowflakePlan, Selectable)): | ||
num_ref_ctes = len(node.referenced_ctes) | ||
sql_text = "" | ||
if isinstance(node, Selectable): | ||
sql_text = node.sql_query | ||
elif isinstance(node, SnowflakePlan): | ||
sql_text = node.queries[-1].sql | ||
sql_text = get_sql_text(node) | ||
sql_size = len(sql_text) | ||
ref_ctes = None | ||
if isinstance(node, (SnowflakePlan, Selectable)): | ||
ref_ctes = list( | ||
map( | ||
lambda node, cnt: f"{node.name[18:]}:{cnt}", | ||
node.referenced_ctes.keys(), | ||
node.referenced_ctes.values(), | ||
) | ||
) | ||
for with_query_block in node.referenced_ctes: # pragma: no cover | ||
sql_size += len(get_sql_text(with_query_block.children[0])) | ||
sql_preview = sql_text[:50] | ||
|
||
return f"{name=}\n{score=}, {num_ref_ctes=}, {sql_size=}\n{sql_preview=}" | ||
return f"{name=}\n{score=}, {ref_ctes=}, {sql_size=}\n{sql_preview=}" | ||
|
||
def is_with_query_block(node: Optional[LogicalPlan]) -> bool: # pragma: no cover | ||
if isinstance(node, WithQueryBlock): | ||
return True | ||
if isinstance(node, SnowflakePlan): | ||
return is_with_query_block(node.source_plan) | ||
if isinstance(node, SelectSnowflakePlan): | ||
return is_with_query_block(node.snowflake_plan) | ||
|
||
return False | ||
|
||
g = graphviz.Graph(format="png") | ||
|
||
|
@@ -435,11 +475,18 @@ def get_name(node: Optional[LogicalPlan]) -> str: | |
for node in curr_level: | ||
node_id = hex(id(node)) | ||
color = "lightblue" if node._is_valid_for_replacement else "red" | ||
g.node(node_id, get_stat(node), color=color) | ||
fillcolor = "lightgray" if is_with_query_block(node) else "white" | ||
g.node( | ||
node_id, | ||
get_stat(node), | ||
color=color, | ||
style="filled", | ||
fillcolor=fillcolor, | ||
) | ||
if isinstance(node, (Selectable, SnowflakePlan)): | ||
children = node.children_plan_nodes | ||
else: | ||
children = node.children | ||
children = node.children # pragma: no cover | ||
for child in children: | ||
child_id = hex(id(child)) | ||
edges.add((node_id, child_id)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -788,16 +788,24 @@ def test_large_query_breakdown_enabled_parameter(session, caplog): | |
|
||
@pytest.mark.skipif(IS_IN_STORED_PROC, reason="requires graphviz") | ||
@pytest.mark.parametrize("enabled", [False, True]) | ||
def test_plotter(session, large_query_df, enabled): | ||
@pytest.mark.parametrize("plotting_score_threshold", [0, 10_000_000]) | ||
def test_plotter(large_query_df, enabled, plotting_score_threshold): | ||
original_plotter_enabled = os.environ.get("ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING") | ||
original_score_threshold = os.environ.get( | ||
"SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD" | ||
) | ||
try: | ||
os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"] = str(enabled) | ||
os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"] = str( | ||
plotting_score_threshold | ||
) | ||
tmp_dir = tempfile.gettempdir() | ||
|
||
with patch("graphviz.Graph.render") as mock_render: | ||
large_query_df.collect() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we perhaps add a comment explaining that the actual complexity for |
||
assert mock_render.called == enabled | ||
if not enabled: | ||
should_plot = enabled and (plotting_score_threshold == 0) | ||
assert mock_render.called == should_plot | ||
if not should_plot: | ||
return | ||
|
||
assert mock_render.call_count == 5 | ||
|
@@ -819,3 +827,9 @@ def test_plotter(session, large_query_df, enabled): | |
] = original_plotter_enabled | ||
else: | ||
del os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"] | ||
if original_score_threshold is not None: | ||
os.environ[ | ||
"SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD" | ||
] = original_score_threshold | ||
else: | ||
del os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is this Plotting threshold used for? seems it is used for restricting the complexity score? maybe call this SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD to be more clear