Merge pull request #1 from blepabyte/06-10-fix_descend_logic_and_lru_cache_support

fix: descend logic and lru_cache support
blepabyte authored Jun 17, 2024
2 parents 0b4846a + 87ce1ba commit 9d0fbde
Showing 1 changed file with 59 additions and 14 deletions.
73 changes: 59 additions & 14 deletions maxray/__init__.py
@@ -2,8 +2,9 @@

import inspect
from contextvars import ContextVar
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from functools import wraps, _lru_cache_wrapper
from typing import Any, Callable
from result import Result, Ok, Err

@@ -13,6 +14,22 @@
logger.disable("maxray")


def _set_logging(enabled: bool):
    if enabled:
        logger.enable("maxray")
    else:
        logger.disable("maxray")


@contextmanager
def _with_logging():
    try:
        _set_logging(True)
        yield
    finally:
        _set_logging(False)
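
For context, a minimal sketch of how this new helper could be used to surface maxray's log output around a single call. This is illustrative only; `run_traced_code` is a made-up stand-in and not part of maxray:

from maxray import _with_logging  # internal helper added in this commit

def run_traced_code():
    return sum(range(10))  # stand-in for any code path that emits maxray logs

# Hypothetical: temporarily re-enable maxray's loguru output around one call.
with _with_logging():
    run_traced_code()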


def transform(writer):
"""
Decorator that rewrites the source code of a function to wrap every single expression by passing it through the `writer(expr, ctx)` callable.
@@ -37,14 +54,19 @@ def xray(walker, **kwargs):


_GLOBAL_SKIP_MODULES = {
"builtins", # duh... object.__init__ is a wrapper_descriptor
"abc", # excessive inheritance and super calls in scikit-learn
"inspect", # don't want to screw this module up
"pathlib", # internally used in transform for checking source file exists
"re", # internals of regexp have a lot of uninteresting step methods
"copy", # pytorch spends so much time in here
"functools", # partialmethod causes mis-bound `self` in TQDM
"typing",
"importlib",
"ctypes",
"loguru", # used internally in transform - accidental patching will cause inf recursion
"uuid", # used for generating _MAXRAY_TRANSFORM_ID
"maxray",
}
# TODO: probably just skip the entire Python standard library...

@@ -65,6 +87,17 @@ class W_erHook:

_MAXRAY_REGISTERED_HOOKS: list[W_erHook] = []


def descend_allowed(x, ctx: NodeContext):
    num_active_hooks = 0
    for hook in _MAXRAY_REGISTERED_HOOKS:
        if hook.active_call_state.get():
            num_active_hooks += 1
            if not hook.descend_predicate(x, ctx):
                return False
    return num_active_hooks > 0
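
This helper replaces the inline `any(...)` check further down: the old code descended whenever any single active hook's predicate allowed it, while the new rule requires at least one active hook and agreement from every active hook. A self-contained restatement of the new rule with stub types, purely for illustration (`_StubHook` and `_descend_allowed` are hypothetical, not maxray API):

from dataclasses import dataclass
from typing import Callable

@dataclass
class _StubHook:
    active: bool
    allow: Callable[[object], bool]

def _descend_allowed(x, hooks):
    # Descend only if at least one hook is active and no active hook vetoes x.
    active = [h for h in hooks if h.active]
    return bool(active) and all(h.allow(x) for h in active)

assert _descend_allowed("fn", []) is False                                  # no active hooks
assert _descend_allowed("fn", [_StubHook(True, lambda _: True)]) is True
assert _descend_allowed("fn", [_StubHook(True, lambda _: False)]) is False  # one veto blocks descent

The same truth table is what `descend_allowed` enforces, except against `_MAXRAY_REGISTERED_HOOKS` and with the `NodeContext` passed through to each predicate.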


_GLOBAL_WRITER_ACTIVE_FLAG = ContextVar("writer_active (global)", default=False)

# We don't want to recompile the same function over and over - so our cache needs to be global
@@ -84,7 +117,11 @@ def callable_allowed_for_transform(x, ctx: NodeContext):
        and callable(x)
        and callable(getattr(x, "__hash__", None))
        and getattr(type(x), "__module__", None) not in {"ctypes"}
        and (inspect.isfunction(x) or inspect.ismethod(x))
        and (
            inspect.isfunction(x)
            or inspect.ismethod(x)
            or isinstance(x, _lru_cache_wrapper)
        )
    )
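
The extra `isinstance` branch is needed because `functools.lru_cache` returns a C-implemented wrapper object rather than a plain Python function, so the previous `isfunction`/`ismethod` check rejected it. A quick standalone illustration (not part of the diff):

import inspect
from functools import lru_cache, _lru_cache_wrapper

@lru_cache(maxsize=None)
def cached_double(n):
    return n * 2

assert not inspect.isfunction(cached_double)           # the wrapper is not a plain function
assert isinstance(cached_double, _lru_cache_wrapper)   # so it needs its own check
assert inspect.isfunction(cached_double.__wrapped__)   # the original function is still reachable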


@@ -94,6 +131,7 @@ def instance_init_allowed_for_transform(x, ctx: NodeContext):
"""
return (
type(x) is type
and getattr(x, "__module__", None) not in {"ctypes"}
and hasattr(x, "__init__")
and not hasattr(x, "_MAXRAY_TRANSFORMED")
)
@@ -103,47 +141,47 @@ def instance_call_allowed_for_transform(x, ctx: NodeContext):
"""
Decides whether the __call__ method can be transformed.
"""
if getattr(x, "__module__", None) in _GLOBAL_SKIP_MODULES:
return False

return (
type(x) is type
and getattr(x, "__module__", None) not in {"ctypes"}
and hasattr(x, "__call__")
and not hasattr(x, "_MAXRAY_TRANSFORMED")
)
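
The new early return rejects classes defined in any module on the shared skip list before the other attribute checks run. A tiny illustration of the lookup it performs (hypothetical, not part of the diff):

import re
from maxray import _GLOBAL_SKIP_MODULES

# re.Pattern is defined in the "re" module, which is on the skip list,
# so instance_call_allowed_for_transform bails out for it immediately.
assert getattr(re.Pattern, "__module__", None) in _GLOBAL_SKIP_MODULES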


def _maxray_walker_handler(x, ctx: NodeContext):
    # We ignore writer calls triggered by code execution in other writers to prevent easily getting stuck in recursive hell
    if _GLOBAL_WRITER_ACTIVE_FLAG.get():
        return x

    # 1. logic to recursively patch callables
    # 1a. special-case callables: __init__ and __call__
    if instance_init_allowed_for_transform(x, ctx):
        # TODO: should we somehow delay doing this until before an actual call?
        match recompile_fn_with_transform(
            x.__init__, _maxray_walker_handler, special_use_instance_type=x
        ):
            case Ok(init_patch):
                logger.info(f"Patching __init__ for class {x}")
                logger.debug(f"Patching __init__ for class {x}")
                setattr(x, "__init__", init_patch)
            # TODO: consolidate error handling and reporting
            # case Err(bad):
            #     logger.error(bad)

    elif instance_call_allowed_for_transform(x, ctx):
        # TODO: should we somehow delay doing this until before an actual call?
        match recompile_fn_with_transform(
            x.__call__, _maxray_walker_handler, special_use_instance_type=x
        ):
            case Ok(call_patch):
                logger.info(f"Patching __call__ for class {x}")
                logger.debug(f"Patching __call__ for class {x}")
                setattr(x, "__call__", call_patch)
            # case Err(bad):
            #     logger.error(bad)

    # 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod
    elif callable_allowed_for_transform(x, ctx):
        # We can only cache functions - as caching invokes __hash__, which may fail badly on incompletely-initialised class instances w/ __call__ methods, like torch._ops.OpOverload
        if x in _MAXRAY_FN_CACHE:
            x = _MAXRAY_FN_CACHE[x]
        elif not any(
            hook.active_call_state.get() and hook.descend_predicate(x, ctx)
            for hook in _MAXRAY_REGISTERED_HOOKS
        ):
        elif not descend_allowed(x, ctx):
            # user-defined filters for which nodes (not) to descend into
            pass
        else:
@@ -185,6 +223,11 @@ def _maxray_walker_handler(x, ctx: NodeContext):
                    # Errors in functions that have been recursively compiled are less important
                    logger.warning(f"Failed to transform in walker handler: {e}")

    # We ignore writer calls triggered by code execution in other writers to prevent easily getting stuck in recursive hell
    # This happens *after* checking and patching callables to still allow for explicitly patching a callable/method by calling this handler
    if _GLOBAL_WRITER_ACTIVE_FLAG.get():
        return x

    # 2. run the active hooks
    global_write_active_token = _GLOBAL_WRITER_ACTIVE_FLAG.set(True)
    try:
@@ -288,6 +331,8 @@ def fn_with_context_update(*args, **kwargs):
                ACTIVE_FLAG.reset(prev_token)

        fn_with_context_update._MAXRAY_TRANSFORMED = True
        # TODO: set correctly everywhere
        # fn_with_context_update._MAXRAY_TRANSFORM_ID = ...
        return fn_with_context_update

    return recursive_transform
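
Taken together, the two fixes mean an `@lru_cache`-decorated function reached from traced code can now be descended into, subject to every active hook's descend predicate. A rough end-to-end sketch under stated assumptions: `@xray(...)` is usable as a decorator, the walker receives `(x, ctx)` as in `_maxray_walker_handler` above, and `record`, `seen`, and `cached_fib` are made-up names:

from functools import lru_cache
from maxray import xray

seen = []

def record(x, ctx):
    # Observe every evaluated expression; return it unchanged.
    seen.append(x)
    return x

@lru_cache(maxsize=None)
def cached_fib(n):
    return n if n < 2 else cached_fib(n - 1) + cached_fib(n - 2)

@xray(record)
def main():
    # Before this commit, cached_fib (a functools._lru_cache_wrapper, not a plain
    # function) failed callable_allowed_for_transform and was never descended into.
    return cached_fib(10)

print(main(), len(seen))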
