Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup how args and kwargs are passed in _fast_slow_function_call #16266

Draft
wants to merge 2 commits into
base: branch-25.02
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions python/cudf/cudf/pandas/fast_slow_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,6 @@ def _fast_slow_function_call(
func: Callable,
/,
*args,
**kwargs,
) -> Any:
"""
Call `func` with all `args` and `kwargs` converted to their
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you update this docstring to note that args contains func's args and kwargs?

Expand All @@ -893,8 +892,8 @@ def _fast_slow_function_call(
color=_CUDF_PANDAS_NVTX_COLORS["EXECUTE_FAST"],
domain="cudf_pandas",
):
fast_args, fast_kwargs = _fast_arg(args), _fast_arg(kwargs)
result = func(*fast_args, **fast_kwargs)
fast_args = _fast_arg(args)
result = func(*fast_args)
if result is NotImplemented:
# try slow path
raise Exception()
Expand All @@ -906,12 +905,9 @@ def _fast_slow_function_call(
color=_CUDF_PANDAS_NVTX_COLORS["EXECUTE_SLOW"],
domain="cudf_pandas",
):
slow_args, slow_kwargs = (
_slow_arg(args),
_slow_arg(kwargs),
)
slow_args = (_slow_arg(args),)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
slow_args = (_slow_arg(args),)
slow_args = _slow_arg(args)

with disable_module_accelerator():
slow_result = func(*slow_args, **slow_kwargs)
slow_result = func(*slow_args)
except Exception as e:
warnings.warn(
"The result from pandas could not be computed. "
Expand All @@ -936,10 +932,10 @@ def _fast_slow_function_call(
color=_CUDF_PANDAS_NVTX_COLORS["EXECUTE_SLOW"],
domain="cudf_pandas",
):
slow_args, slow_kwargs = _slow_arg(args), _slow_arg(kwargs)
slow_args = _slow_arg(args)
with disable_module_accelerator():
result = func(*slow_args, **slow_kwargs)
return _maybe_wrap_result(result, func, *args, **kwargs), fast
result = func(*slow_args)
return _maybe_wrap_result(result, func, *args), fast


def _transform_arg(
Expand Down Expand Up @@ -1054,7 +1050,7 @@ def _slow_arg(arg: Any) -> Any:
return _transform_arg(arg, "_fsproxy_slow", seen)


def _maybe_wrap_result(result: Any, func: Callable, /, *args, **kwargs) -> Any:
def _maybe_wrap_result(result: Any, func: Callable, /, *args) -> Any:
"""
Wraps "result" in a fast-slow proxy if is a "proxiable" object.
"""
Expand All @@ -1063,7 +1059,7 @@ def _maybe_wrap_result(result: Any, func: Callable, /, *args, **kwargs) -> Any:
return typ._fsproxy_wrap(result, func)
elif _is_intermediate_type(result):
typ = get_intermediate_type_map()[type(result)]
return typ._fsproxy_wrap(result, method_chain=(func, args, kwargs))
return typ._fsproxy_wrap(result, method_chain=(func, args))
elif _is_final_class(result):
return get_final_type_map()[result]
elif isinstance(result, list):
Expand Down
Loading