
SNOW-1878372: Fix analyzer access across threads #2912

Merged: 11 commits, Jan 27, 2025
5 changes: 2 additions & 3 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -4,7 +4,7 @@
#
import uuid
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Union

from snowflake.connector import IntegrityError

@@ -168,7 +168,7 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None:
self.plan_builder = SnowflakePlanBuilder(self.session)
self.generated_alias_maps = {}
self.subquery_plans = []
self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None
self.alias_maps_to_use: Dict[uuid.UUID, str] = {}

def analyze(
self,
@@ -368,7 +368,6 @@ def analyze(
return expr.sql

if isinstance(expr, Attribute):
assert self.alias_maps_to_use is not None
name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
return quote_name(name)

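Taken together, the two hunks above simplify the alias map: initializing alias_maps_to_use to an empty dict instead of None makes the deleted assert unnecessary, because dict.get with a default is safe on an empty mapping. A minimal sketch of the difference, with simplified names rather than the actual Analyzer class:

import uuid

expr_id, fallback = uuid.uuid4(), "COL_A"

# Before: the map was Optional, so every use site needed a None guard
# (hence the deleted `assert self.alias_maps_to_use is not None`).
alias_maps = None
# alias_maps.get(expr_id, fallback)  # AttributeError: 'NoneType' has no attribute 'get'

# After: the map is always a dict, possibly empty, so lookups are safe.
alias_maps = {}
name = alias_maps.get(expr_id, fallback)  # -> "COL_A" when no alias is recorded
print(name)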
@@ -197,7 +197,7 @@ def cache_metadata_if_select_statement(

if (
isinstance(source_plan, SelectStatement)
and source_plan.analyzer.session.reduce_describe_query_enabled
and source_plan._session.reduce_describe_query_enabled
):
source_plan._attributes = metadata.attributes
# When source_plan doesn't have a projection, it's a simple `SELECT * from ...`,
46 changes: 36 additions & 10 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -229,7 +229,14 @@ def __init__(
] = None, # Use Any because it's recursive.
) -> None:
super().__init__()
self.analyzer = analyzer
# With multi-threading support, each thread has its own analyzer, which can be
# accessed through the session object. Therefore, we store the session in
# the Selectable object and use it to access the appropriate analyzer
# for the current thread.
self._session = analyzer.session
# We create this internal object to be used for setting query generator during
# the optimization stage
self._analyzer = None
self.pre_actions: Optional[List["Query"]] = None
self.post_actions: Optional[List["Query"]] = None
self.flatten_disabled: bool = False
@@ -243,6 +250,25 @@ def __init__(
self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None
self._encoded_node_id_with_query: Optional[str] = None

@property
def analyzer(self) -> "Analyzer":
"""Get the analyzer for used for the current thread"""
if self._analyzer is None:
sfc-gh-aalam marked this conversation as resolved.
Show resolved Hide resolved
return self._session._analyzer
return self._analyzer

@analyzer.setter
def analyzer(self, value: "Analyzer") -> None:
"""For query optimization stage, we need to replace the analyzer with a query generator which
is aware of schema for the final plan and can compile WithQueryBlocks. Therefore we update the
setter to allow the analyzer to be set externally."""
if not self._is_valid_for_replacement:
raise ValueError(
"Cannot set analyzer for a Selectable that is not valid for replacement"
)

self._analyzer = value

@property
@abstractmethod
def sql_query(self) -> str:
@@ -258,7 +284,7 @@ def encoded_node_id_with_query(self) -> str:
two selectable nodes with the same query. This is currently used by repeated subquery
elimination to detect two nodes with the same query; please use it with care.
"""
with self.analyzer.session._plan_lock:
with self._session._plan_lock:
if self._encoded_node_id_with_query is None:
self._encoded_node_id_with_query = encode_node_id_with_query(self)
return self._encoded_node_id_with_query
@@ -310,7 +336,7 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan:
queries,
schema_query,
post_actions=self.post_actions,
session=self.analyzer.session,
session=self._session,
expr_to_alias=self.expr_to_alias,
df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name,
source_plan=self,
@@ -328,7 +354,7 @@ def plan_state(self) -> Dict[PlanState, Any]:

@property
def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
with self.analyzer.session._plan_lock:
with self._session._plan_lock:
if self._cumulative_node_complexity is None:
self._cumulative_node_complexity = sum_node_complexities(
self.individual_node_complexity,
@@ -361,7 +387,7 @@ def column_states(self) -> ColumnStateDict:
Refer to class ColumnStateDict.
"""
if self._column_states is None:
if self.analyzer.session.reduce_describe_query_enabled:
if self._session.reduce_describe_query_enabled:
# data types are not needed in SQL simplifier, so we
# just create dummy data types here.
column_attrs = [
@@ -512,7 +538,7 @@ def __init__(
self.pre_actions[0].query_id_place_holder
)
self._schema_query = analyzer_utils.schema_value_statement(
analyze_attributes(sql, self.analyzer.session)
analyze_attributes(sql, self._session)
) # Change to subqueryable schema query so downstream query plan can describe the SQL
self._query_param = None
else:
@@ -1165,7 +1191,7 @@ def filter(self, col: Expression) -> "SelectStatement":
new = SelectStatement(
from_=self.to_subqueryable(), where=col, analyzer=self.analyzer
)
if self.analyzer.session.reduce_describe_query_enabled:
if self._session.reduce_describe_query_enabled:
new._attributes = self._attributes

return new
@@ -1200,7 +1226,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
order_by=cols,
analyzer=self.analyzer,
)
if self.analyzer.session.reduce_describe_query_enabled:
if self._session.reduce_describe_query_enabled:
new._attributes = self._attributes

return new
@@ -1284,7 +1310,7 @@ def limit(self, n: int, *, offset: int = 0) -> "SelectStatement":
new.pre_actions = new.from_.pre_actions
new.post_actions = new.from_.post_actions
new._merge_projection_complexity_with_subquery = False
if self.analyzer.session.reduce_describe_query_enabled:
if self._session.reduce_describe_query_enabled:
new._attributes = self._attributes

return new
@@ -1604,7 +1630,7 @@ def can_select_projection_complexity_be_merged(
on top of the subquery.
subquery: the subquery that the current select is performed on top of
"""
if not subquery.analyzer.session._large_query_breakdown_enabled:
if not subquery._session._large_query_breakdown_enabled:
return False

# only merge of nested select statement is supported, and subquery must be
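The property/setter pair added above is the core of this PR: a Selectable no longer pins the analyzer it was built with, but resolves one through its session at access time, while the guarded setter still lets the compiler install a query generator during the optimization stage. A simplified sketch of the pattern, assuming the session keeps one analyzer per thread in a threading.local() (the real Session internals differ):

import threading


class Analyzer:
    def __init__(self, session: "Session") -> None:
        self.session = session


class Session:
    """Hands each thread its own Analyzer instance."""

    def __init__(self) -> None:
        self._tls = threading.local()

    @property
    def _analyzer(self) -> Analyzer:
        # Lazily create one analyzer per thread.
        if not hasattr(self._tls, "analyzer"):
            self._tls.analyzer = Analyzer(self)
        return self._tls.analyzer


class Selectable:
    def __init__(self, analyzer: Analyzer) -> None:
        self._session = analyzer.session  # store the session, not the analyzer
        self._analyzer = None             # only set during plan compilation
        self._is_valid_for_replacement = False

    @property
    def analyzer(self) -> Analyzer:
        # Fall back to the session's thread-local analyzer unless a query
        # generator was explicitly installed for the optimization stage.
        if self._analyzer is None:
            return self._session._analyzer
        return self._analyzer

    @analyzer.setter
    def analyzer(self, value: Analyzer) -> None:
        if not self._is_valid_for_replacement:
            raise ValueError(
                "Cannot set analyzer for a Selectable that is not valid for replacement"
            )
        self._analyzer = value


session = Session()
node = Selectable(session._analyzer)
assert node.analyzer is session._analyzer  # resolved through the session

results = []
t = threading.Thread(target=lambda: results.append(node.analyzer))
t.start()
t.join()
assert results[0] is not session._analyzer  # the worker thread got its own analyzer

This guard is also why reset_selectable in test_query_generator.py below must set _is_valid_for_replacement = True before assigning a new analyzer.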
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/mock/_nop_analyzer.py
@@ -102,7 +102,7 @@ def attributes(self):
class NopSelectableEntity(MockSelectableEntity):
@property
def attributes(self):
return resolve_attributes(self.entity_plan, session=self.analyzer.session)
return resolve_attributes(self.entity_plan, session=self._session)


class NopAnalyzer(MockAnalyzer):
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/mock/_plan.py
@@ -986,15 +986,15 @@ def execute_mock_plan(
res_df = execute_mock_plan(
MockExecutionPlan(
first_operand.selectable,
source_plan.analyzer.session,
source_plan._session,
),
expr_to_alias,
)
for i in range(1, len(source_plan.set_operands)):
operand = source_plan.set_operands[i]
operator = operand.operator
cur_df = execute_mock_plan(
MockExecutionPlan(operand.selectable, source_plan.analyzer.session),
MockExecutionPlan(operand.selectable, source_plan._session),
expr_to_alias,
)
if len(res_df.columns) != len(cur_df.columns):
8 changes: 6 additions & 2 deletions src/snowflake/snowpark/mock/_select_statement.py
@@ -64,7 +64,7 @@ def __init__(
analyzer: "Analyzer",
) -> None:
super().__init__()
self.analyzer = analyzer
self._session = analyzer.session
self.pre_actions = None
self.post_actions = None
self.flatten_disabled: bool = False
@@ -76,6 +76,10 @@ def __init__(
str, Dict[str, str]
] = defaultdict(dict)

@property
def analyzer(self) -> "Analyzer":
return self._session._analyzer

@property
def sql_query(self) -> str:
"""Returns the sql query of this Selectable logical plan."""
@@ -97,7 +101,7 @@ def execution_plan(self):
from snowflake.snowpark.mock._plan import MockExecutionPlan

if self._execution_plan is None:
self._execution_plan = MockExecutionPlan(self, self.analyzer.session)
self._execution_plan = MockExecutionPlan(self, self._session)
return self._execution_plan

@property
1 change: 1 addition & 0 deletions tests/integ/compiler/test_query_generator.py
@@ -70,6 +70,7 @@ def reset_node(node: LogicalPlan, query_generator: QueryGenerator) -> None:
def reset_selectable(selectable_node: Selectable) -> None:
# reset the analyzer to use the current query generator instance to
# ensure the new query generator is used during the resolve process
selectable_node._is_valid_for_replacement = True
selectable_node.analyzer = query_generator
if not isinstance(selectable_node, (SelectSnowflakePlan, SelectSQL)):
selectable_node._snowflake_plan = None
72 changes: 66 additions & 6 deletions tests/integ/test_multithreading.py
@@ -660,11 +660,12 @@ def change_config_value(session_):
session_.conf.set(config, value)

caplog.clear()
change_config_value(threadsafe_session)
assert (
f"You might have more than one threads sharing the Session object trying to update {config}"
not in caplog.text
)
if threading.active_count() == 1:
change_config_value(threadsafe_session)
assert (
f"You might have more than one threads sharing the Session object trying to update {config}"
not in caplog.text
)

with caplog.at_level(logging.WARNING):
with ThreadPoolExecutor(max_workers=5) as executor:
@@ -857,9 +858,13 @@ def process_data(df_, thread_id):
).csv(f"{stage_with_prefix}/{filename}")

with threadsafe_session.query_history() as history:
futures = []
with ThreadPoolExecutor(max_workers=5) as executor:
for i in range(10):
executor.submit(process_data, df, i)
futures.append(executor.submit(process_data, df, i))

for future in as_completed(futures):
future.result()

queries_sent = [query.sql_text for query in history.queries]

@@ -953,3 +958,58 @@ def call_critical_lazy_methods(df_):
# called only once and the cached result should be used for the rest of
# the calls.
mock_find_duplicate_subtrees.assert_called_once()


def create_and_join(_session):
df1 = _session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
df2 = _session.create_dataframe([[1, 7], [3, 8]], schema=["a", "b"])
df3 = df1.join(df2)
expected = [Row(1, 2, 1, 7), Row(1, 2, 3, 8), Row(3, 4, 1, 7), Row(3, 4, 3, 8)]
Utils.check_answer(df3, expected)
return [df1, df2, df3]


def join_again(df1, df2, df3):
df3 = df1.join(df2).select(df1.a)
expected = [Row(1, 2, 1, 7), Row(1, 2, 3, 8), Row(3, 4, 1, 7), Row(3, 4, 3, 8)]
Utils.check_answer(df3, expected)


def create_aliased_df(_session):
df1 = _session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
df2 = df1.join(df1.filter(col("a") == 1)).select(df1.a.alias("a1"))
Utils.check_answer(df2, [Row(A1=1), Row(A1=3)])
return [df2]


def select_aliased_col(df2):
df2 = df2.select(df2.a1)
Utils.check_answer(df2, [Row(A1=1), Row(A1=3)])


@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="SNOW-1373887: Support basic diamond shaped joins in Local Testing",
run=False,
)
@pytest.mark.parametrize(
"f1,f2", [(create_and_join, join_again), (create_aliased_df, select_aliased_col)]
)
def test_SNOW_1878372(threadsafe_session, f1, f2):
class ReturnableThread(threading.Thread):
def __init__(self, target, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._target = target
self.result = None

def run(self):
if self._target is not None:
self.result = self._target(*self._args, **self._kwargs)

t1 = ReturnableThread(target=f1, args=(threadsafe_session,))
t1.start()
t1.join()

t2 = ReturnableThread(target=f2, args=tuple(t1.result))
t2.start()
t2.join()
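The futures change above is the standard way to surface worker exceptions: ThreadPoolExecutor records an exception inside the Future, and it only resurfaces when result() is called, so the earlier fire-and-forget submits could hide failures in process_data. A minimal standalone sketch of the pattern (the work function is hypothetical, not part of the test suite):

from concurrent.futures import ThreadPoolExecutor, as_completed


def work(i):
    if i == 3:
        raise ValueError(f"task {i} failed")
    return i * i


futures = []
with ThreadPoolExecutor(max_workers=5) as executor:
    for i in range(5):
        futures.append(executor.submit(work, i))

for future in as_completed(futures):
    try:
        print(future.result())  # re-raises any exception from the worker
    except ValueError as exc:
        print(f"caught: {exc}")

The ReturnableThread helper in test_SNOW_1878372 solves the related problem of capturing a plain threading.Thread's return value, which Future objects provide for free.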
6 changes: 5 additions & 1 deletion tests/unit/compiler/test_large_query_breakdown.py
@@ -39,10 +39,14 @@
from snowflake.snowpark._internal.compiler.large_query_breakdown import (
LargeQueryBreakdown,
)
from snowflake.snowpark.session import Session

dummy_session = mock.create_autospec(Session)
dummy_analyzer = mock.create_autospec(Analyzer)
dummy_analyzer.session = dummy_session
empty_logical_plan = LogicalPlan()
empty_expression = Expression()
empty_selectable = SelectSQL("dummy_query", analyzer=mock.create_autospec(Analyzer))
empty_selectable = SelectSQL("dummy_query", analyzer=dummy_analyzer)


@pytest.mark.parametrize(
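Because Selectable.__init__ now reads analyzer.session, an autospec'd Analyzer alone no longer satisfies the constructor; the test wires a mocked session onto it first. A sketch of that wiring with stand-in classes (Session and Analyzer here are simplified stand-ins, not the real snowpark types):

from unittest import mock


class Session:
    """Stand-in for snowflake.snowpark.session.Session."""


class Analyzer:
    """Stand-in for the internal Analyzer, which carries a session."""

    def __init__(self, session: Session) -> None:
        self.session = session


dummy_session = mock.create_autospec(Session)
dummy_analyzer = mock.create_autospec(Analyzer)
dummy_analyzer.session = dummy_session  # wire up the attribute Selectable reads


def build_selectable(analyzer):
    # Mimics the new Selectable.__init__: it keeps the session, not the analyzer.
    return analyzer.session


assert build_selectable(dummy_analyzer) is dummy_session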
3 changes: 2 additions & 1 deletion tests/unit/compiler/test_replace_child_and_update_node.py
@@ -455,7 +455,8 @@ def test_select_statement(
new_replaced_plan = plan.children_plan_nodes[0]
assert isinstance(new_replaced_plan, SelectSnowflakePlan)
assert new_replaced_plan._snowflake_plan.source_plan == new_plan
assert new_replaced_plan.analyzer == mock_query_generator
# new_replaced_plan is created with QueryGenerator.to_selectable
assert new_replaced_plan.analyzer == mock_analyzer
Comment on lines -458 to +459

sfc-gh-aling (Contributor) commented on Jan 23, 2025:

is this change to pass the new "setter" check?
is mock_query_generator/mock_analyzer used in any of the test code after?

Author reply:

new_replaced_plan is created within QueryGenerator.to_selectable, which uses self as the analyzer.

Before this change, Selectable.analyzer would have returned the analyzer used to create the selectable, i.e. mock_query_generator.

After this change, Selectable.analyzer returns the appropriate session analyzer. Therefore, we need to update this test.


post_actions = [Query("drop table if exists table_name")]
new_replaced_plan.post_actions = post_actions
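A compressed, runnable illustration of the reply above, using stand-ins for the test's mocks (FakeSession, FakeQueryGenerator, and Node are hypothetical, assuming the session-backed property sketched earlier):

class FakeAnalyzer:
    pass


class FakeSession:
    def __init__(self) -> None:
        self._analyzer = FakeAnalyzer()  # what the session serves to this thread


class FakeQueryGenerator(FakeAnalyzer):
    def __init__(self, session: FakeSession) -> None:
        self.session = session

    def to_selectable(self):
        # Builds a node using `self` as the analyzer, as the reply describes.
        return Node(self)


class Node:
    def __init__(self, analyzer) -> None:
        self._session = analyzer.session
        self._analyzer = None

    @property
    def analyzer(self):
        return self._analyzer if self._analyzer is not None else self._session._analyzer


session = FakeSession()
generator = FakeQueryGenerator(session)
node = generator.to_selectable()
assert node.analyzer is session._analyzer  # new expectation: the session's analyzer
assert node.analyzer is not generator      # old expectation no longer holds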
2 changes: 1 addition & 1 deletion tests/unit/test_cte.py
@@ -84,7 +84,7 @@ def test_find_duplicate_subtrees(test_case):
assert repeated_node_complexity == expected_repeated_node_complexity


def test_encode_node_id_with_query_select_sql(mock_analyzer):
def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer):
Author comment:
same as jamison's question. Linking old response: #2912 (comment)

sql_text = "select 1 as a, 2 as b"
select_sql_node = SelectSQL(
sql=sql_text,
Expand Down