diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 721d5bf6570..a9953c42820 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -2148,6 +2148,7 @@ def add_df_one(df: pd.DataFrame) -> pd.Series: replace=replace, parallel=parallel, max_batch_size=max_batch_size, + _from_pandas_udf_function=True, ) else: return session.udf.register( @@ -2162,6 +2163,7 @@ def add_df_one(df: pd.DataFrame) -> pd.Series: replace=replace, parallel=parallel, max_batch_size=max_batch_size, + _from_pandas_udf_function=True, ) diff --git a/src/snowflake/snowpark/udf.py b/src/snowflake/snowpark/udf.py index 755002efab5..8c10a84cc35 100644 --- a/src/snowflake/snowpark/udf.py +++ b/src/snowflake/snowpark/udf.py @@ -4,7 +4,6 @@ # Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. # """User-defined functions (UDFs) in Snowpark.""" -import inspect from types import ModuleType from typing import Callable, Iterable, List, Optional, Tuple, Union @@ -429,12 +428,6 @@ def register( TempObjectType.FUNCTION, name, is_permanent, stage_location, parallel ) - # whether called from pandas_udf - caller_frame = inspect.getouterframes(inspect.currentframe()) - from_pandas_udf_function = ( - len(caller_frame) > 1 and caller_frame[1].function == "_pandas_udf" - ) - # register udf return self.__do_register_udf( func, @@ -447,7 +440,7 @@ def register( replace, parallel, kwargs.get("max_batch_size"), - from_pandas_udf_function, + kwargs.get("_from_pandas_udf_function", False), ) def register_from_file(