Skip to content

Commit

Permalink
Fix inspect issue of pandas udf in SP by using kwargs (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu authored Mar 23, 2022
1 parent 95ab331 commit c8d5e2a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,7 @@ def add_df_one(df: pd.DataFrame) -> pd.Series:
replace=replace,
parallel=parallel,
max_batch_size=max_batch_size,
_from_pandas_udf_function=True,
)
else:
return session.udf.register(
Expand All @@ -2162,6 +2163,7 @@ def add_df_one(df: pd.DataFrame) -> pd.Series:
replace=replace,
parallel=parallel,
max_batch_size=max_batch_size,
_from_pandas_udf_function=True,
)


Expand Down
9 changes: 1 addition & 8 deletions src/snowflake/snowpark/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
#
"""User-defined functions (UDFs) in Snowpark."""
import inspect
from types import ModuleType
from typing import Callable, Iterable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -429,12 +428,6 @@ def register(
TempObjectType.FUNCTION, name, is_permanent, stage_location, parallel
)

# whether called from pandas_udf
caller_frame = inspect.getouterframes(inspect.currentframe())
from_pandas_udf_function = (
len(caller_frame) > 1 and caller_frame[1].function == "_pandas_udf"
)

# register udf
return self.__do_register_udf(
func,
Expand All @@ -447,7 +440,7 @@ def register(
replace,
parallel,
kwargs.get("max_batch_size"),
from_pandas_udf_function,
kwargs.get("_from_pandas_udf_function", False),
)

def register_from_file(
Expand Down

0 comments on commit c8d5e2a

Please sign in to comment.