From c3db2239f412880e406451e9e3a2662d4ead0266 Mon Sep 17 00:00:00 2001 From: Jamison Date: Mon, 16 Dec 2024 18:20:04 -0800 Subject: [PATCH] make feature flag thread safe --- src/snowflake/snowpark/_internal/type_utils.py | 4 ++-- src/snowflake/snowpark/context.py | 16 ++++++++++++++-- src/snowflake/snowpark/types.py | 16 ++++++++-------- tests/integ/scala/test_datatype_suite.py | 11 ++++++----- 4 files changed, 30 insertions(+), 17 deletions(-) diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 1a0c17ee3af..0910a2a4aae 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -159,7 +159,7 @@ def convert_metadata_to_sp_type( [ StructField( field.name - if context._should_use_structured_type_semantics + if context._should_use_structured_type_semantics() else quote_name(field.name, keep_case=True), convert_metadata_to_sp_type(field, max_string_size), nullable=field.is_nullable, @@ -188,7 +188,7 @@ def convert_sf_to_sp_type( ) -> DataType: """Convert the Snowflake logical type to the Snowpark type.""" semi_structured_fill = ( - None if context._should_use_structured_type_semantics else StringType() + None if context._should_use_structured_type_semantics() else StringType() ) if column_type_name == "ARRAY": return ArrayType(semi_structured_fill) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index 8bc86f928a1..a975a53a394 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -7,6 +7,7 @@ from typing import Callable, Optional import snowflake.snowpark +import threading _use_scoped_temp_objects = True @@ -21,8 +22,19 @@ _should_continue_registration: Optional[Callable[..., bool]] = None -# Global flag that determines if structured type semantics should be used -_should_use_structured_type_semantics = False +# Internal-only global flag that determines if structured type semantics should be used +_use_structured_type_semantics = False +_use_structured_type_semantics_lock = None + + +def _should_use_structured_type_semantics(): + global _use_structured_type_semantics + global _use_structured_type_semantics_lock + if _use_structured_type_semantics_lock is None: + _use_structured_type_semantics_lock = threading.RLock() + + with _use_structured_type_semantics_lock: + return _use_structured_type_semantics def get_active_session() -> "snowflake.snowpark.Session": diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index f78b35e1f97..6cf80757135 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -336,7 +336,7 @@ def __init__( element_type: Optional[DataType] = None, structured: Optional[bool] = None, ) -> None: - if context._should_use_structured_type_semantics: + if context._should_use_structured_type_semantics(): self.structured = ( structured if structured is not None else element_type is not None ) @@ -349,7 +349,7 @@ def __repr__(self) -> str: return f"ArrayType({repr(self.element_type) if self.element_type else ''})" def _as_nested(self) -> "ArrayType": - if not context._should_use_structured_type_semantics: + if not context._should_use_structured_type_semantics(): return self element_type = self.element_type if isinstance(element_type, (ArrayType, MapType, StructType)): @@ -396,7 +396,7 @@ def __init__( value_type: Optional[DataType] = None, structured: Optional[bool] = None, ) -> None: - if context._should_use_structured_type_semantics: + if context._should_use_structured_type_semantics(): if (key_type is None and value_type is not None) or ( key_type is not None and value_type is None ): @@ -423,7 +423,7 @@ def is_primitive(self): return False def _as_nested(self) -> "MapType": - if not context._should_use_structured_type_semantics: + if not context._should_use_structured_type_semantics(): return self value_type = self.value_type if isinstance(value_type, (ArrayType, MapType, StructType)): @@ -600,7 +600,7 @@ def __init__( @property def name(self) -> str: - if self._is_column or not context._should_use_structured_type_semantics: + if self._is_column or not context._should_use_structured_type_semantics(): return self.column_identifier.name else: return self._name @@ -615,7 +615,7 @@ def name(self, n: Union[ColumnIdentifier, str]) -> None: self.column_identifier = ColumnIdentifier(n) def _as_nested(self) -> "StructField": - if not context._should_use_structured_type_semantics: + if not context._should_use_structured_type_semantics(): return self datatype = self.datatype if isinstance(datatype, (ArrayType, MapType, StructType)): @@ -677,7 +677,7 @@ def __init__( fields: Optional[List["StructField"]] = None, structured: Optional[bool] = False, ) -> None: - if context._should_use_structured_type_semantics: + if context._should_use_structured_type_semantics(): self.structured = ( structured if structured is not None else fields is not None ) @@ -713,7 +713,7 @@ def add( return self def _as_nested(self) -> "StructType": - if not context._should_use_structured_type_semantics: + if not context._should_use_structured_type_semantics(): return self return StructType( [field._as_nested() for field in self.fields], self.structured diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index 1004190b00e..48f64638b7c 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -167,10 +167,11 @@ def examples(structured_type_support): def structured_type_session(session, structured_type_support): if structured_type_support: with structured_types_enabled_session(session) as sess: - semantics_enabled = context._should_use_structured_type_semantics - context._should_use_structured_type_semantics = True - yield sess - context._should_use_structured_type_semantics = semantics_enabled + semantics_enabled = context._should_use_structured_type_semantics() + with context._use_structured_type_semantics_lock(): + context._use_structured_type_semantics = True + yield sess + context._use_structured_type_semantics = semantics_enabled else: yield session @@ -399,7 +400,7 @@ def test_structured_dtypes_select( ): query, expected_dtypes, expected_schema = examples df = _create_test_dataframe(structured_type_session, structured_type_support) - nested_field_name = "b" if context._should_use_structured_type_semantics else "B" + nested_field_name = "b" if context._should_use_structured_type_semantics() else "B" flattened_df = df.select( df.map["k1"].alias("value1"), df.obj["A"].alias("a"),