Skip to content

Commit

Permalink
Merge pull request rapidsai#160 from shwina/fix-getting-attrs-from-mo…
Browse files Browse the repository at this point in the history
…dules-defining-getattr

Always find the fast types/function corresponding to slow types/functions
  • Loading branch information
shwina authored Nov 6, 2023
2 parents 5b32287 + 9a5e7d8 commit f8af756
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 46 deletions.
86 changes: 40 additions & 46 deletions python/cudf/cudf/pandas/module_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
import importlib.abc
import importlib.machinery
import os
import pathlib
import sys
import threading
import warnings
from abc import abstractmethod
from importlib._bootstrap import _ImportLockContext as ImportLock
from types import ModuleType
from typing import Any, ContextManager, Dict, List, NamedTuple, Tuple
from typing import Any, ContextManager, Dict, List, NamedTuple

from typing_extensions import Self

from .fast_slow_proxy import (
_FunctionProxy,
_is_final_class,
_is_function_or_method,
_Unusable,
get_final_type_map,
Expand Down Expand Up @@ -61,24 +61,6 @@ def rename_root_module(module: str, root: str, new_root: str) -> str:
return module


def sorted_module_items(mod: ModuleType) -> List[Tuple[str, Any]]:
    """
    Return the ``(name, value)`` items of a module, submodules last.

    Parameters
    ----------
    mod
        Module whose attributes are collected.

    Returns
    -------
    List of ``(name, value)`` pairs. Every value that is not a
    ``ModuleType`` instance precedes every value that is, and within
    each of the two groups the pairs are ordered by name so that the
    result is the same on every run.
    """
    # Assume __dir__ contains all objects accessible under mod.__getattr__
    # (GH 403), so take the union of the static namespace and __dir__.
    names = set(mod.__dict__).union(mod.__dir__())
    with warnings.catch_warnings():
        # FutureWarnings emitted while reading the attributes are silenced.
        warnings.simplefilter("ignore", FutureWarning)
        items = [(name, getattr(mod, name)) for name in names]
    # It is advantageous to sort the module items so that submodules
    # appear last (GH:127). The name is used as a secondary key because
    # ``names`` is a set: without it the relative order inside each group
    # would follow hash order, which is not stable between runs.
    items.sort(key=lambda item: (isinstance(item[1], ModuleType), item[0]))
    return items


class DeducedMode(NamedTuple):
use_fast_lib: bool
slow_lib: str
Expand Down Expand Up @@ -315,18 +297,23 @@ def _wrap_attribute(
-------
Wrapped attribute
"""
wrapped_attr: Any
# TODO: what else should we make sure not to get from the fast
# library?
if name in {"__all__", "__dir__", "__file__", "__doc__"}:
return slow_attr
if self.fast_lib == self.slow_lib:
wrapped_attr = slow_attr
elif self.fast_lib == self.slow_lib:
# no need to create a fast-slow wrapper
return slow_attr
try:
wrapped_attr = slow_attr
if any(
[
slow_attr in get_registered_functions(),
slow_attr in get_final_type_map(),
slow_attr in get_intermediate_type_map(),
]
):
# attribute already registered in self._wrapped_objs
return self._wrapped_objs[slow_attr]
except (KeyError, TypeError):
pass
wrapped_attr: Any
if isinstance(slow_attr, ModuleType) and slow_attr.__name__.startswith(
self.slow_lib
):
Expand All @@ -335,21 +322,21 @@ def _wrap_attribute(
# name with "{self.mod_name}"
# now, attempt to import the wrapped module, which will
# recursively wrap all of its attributes:
wrapped_attr = importlib.import_module(
return importlib.import_module(
rename_root_module(
slow_attr.__name__, self.slow_lib, self.mod_name
)
)
elif _is_final_class(slow_attr):
wrapped_attr = get_final_type_map()[slow_attr]
elif _is_function_or_method(slow_attr):
if slow_attr in self._wrapped_objs:
if type(fast_attr) is _Unusable:
# we don't want to replace a wrapped object that
# has a usable fast object with a wrapped object
# with an unusable fast object.
return self._wrapped_objs[slow_attr]
if _is_function_or_method(slow_attr):
wrapped_attr = _FunctionProxy(fast_attr, slow_attr)
else:
wrapped_attr = slow_attr
try:
self._wrapped_objs[slow_attr] = wrapped_attr
except TypeError:
pass
return wrapped_attr

@classmethod
Expand Down Expand Up @@ -472,17 +459,18 @@ def _populate_module(self, mod: ModuleType):
# package
real_attributes = {}
# The version that will be used outside denylist packages
wrapped_attributes = {}
for key, _ in sorted_module_items(slow_mod):
# Only copy attributes that don't already exist
for key in slow_mod.__dir__():
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
slow_attr = getattr(slow_mod, key)
fast_attr = getattr(fast_mod, key, _Unusable())
real_attributes[key] = slow_attr
wrapped_attributes[key] = self._wrap_attribute(
slow_attr, fast_attr, key
)
try:
wrapped_attr = self._wrap_attribute(slow_attr, fast_attr, key)
self._wrapped_objs[slow_attr] = wrapped_attr
except TypeError:
# slow_attr is not hashable
pass

# Our module has (basically) no static attributes and instead
# always delivers them dynamically where the behaviour is
Expand All @@ -493,7 +481,7 @@ def _populate_module(self, mod: ModuleType):
functools.partial(
self.getattr_real_or_wrapped,
real=real_attributes,
wrapped=wrapped_attributes,
wrapped_objs=self._wrapped_objs,
loader=self,
),
)
Expand Down Expand Up @@ -539,7 +527,7 @@ def getattr_real_or_wrapped(
name: str,
*,
real: Dict[str, Any],
wrapped: Dict[str, Any],
wrapped_objs,
loader: ModuleAccelerator,
) -> Any:
"""
Expand Down Expand Up @@ -576,13 +564,19 @@ def getattr_real_or_wrapped(
assert frame.f_back
calling_module = frame.f_back.f_code.co_filename
use_real = any(
calling_module.startswith(path) for path in loader._denylist
pathlib.PurePath(calling_module).is_relative_to(path)
for path in loader._denylist
)
location = real if use_real else wrapped
try:
return location[name]
if use_real:
return real[name]
else:
return wrapped_objs[real[name]]
except KeyError:
raise AttributeError(f"No attribute '{name}'")
except TypeError:
# real[name] is an unhashable type
return real[name]

@classmethod
def install(
Expand Down
12 changes: 12 additions & 0 deletions python/cudf/cudf_pandas_tests/test_cudf_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from numba import NumbaDeprecationWarning

from cudf.pandas import LOADED, Profiler
from cudf.pandas.fast_slow_proxy import _Unusable

if not LOADED:
raise ImportError("These tests must be run with cudf.pandas loaded")
Expand Down Expand Up @@ -1221,3 +1222,14 @@ def test_pandas_module_getattr_objects(idx_obj):
# Objects that are behind pandas.__getattr__ (version 1.5 specific)
idx = getattr(xpd, idx_obj)([1, 2, 3])
assert isinstance(idx, xpd.Index)


def test_concat_fast():
    """Check that the fast object behind ``xpd.concat`` is not ``_Unusable``."""
    pytest.importorskip("cudf")

    fast_impl = xpd.concat._fsproxy_fast
    assert type(fast_impl) is not _Unusable


def test_func_namespace():
    """Check that ``concat`` is the same object at both access paths."""
    # note: this test is sensitive to Pandas' internal module layout
    top_level = xpd.concat
    nested = xpd.core.reshape.concat.concat
    assert top_level is nested

0 comments on commit f8af756

Please sign in to comment.