fix: descend logic and lru_cache support
blepabyte committed Jun 17, 2024
1 parent 0b4846a commit 4be7e71
Showing 1 changed file with 51 additions and 10 deletions.
61 changes: 51 additions & 10 deletions maxray/__init__.py
@@ -2,8 +2,9 @@

import inspect
from contextvars import ContextVar
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from functools import wraps, _lru_cache_wrapper
from typing import Any, Callable
from result import Result, Ok, Err

@@ -13,6 +14,22 @@
logger.disable("maxray")


def _set_logging(enabled: bool):
    if enabled:
        logger.enable("maxray")
    else:
        logger.disable("maxray")


@contextmanager
def _with_logging():
    try:
        _set_logging(True)
        yield
    finally:
        _set_logging(False)


def transform(writer):
    """
    Decorator that rewrites the source code of a function to wrap every single expression by passing it through the `writer(expr, ctx)` callable.
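
For orientation, a minimal sketch of the decorator contract described in that docstring: the writer receives each evaluated expression plus its NodeContext and returns the (possibly substituted) value. `show_expr` and `add` are hypothetical names, and the import assumes both symbols are exposed from the package root, as this __init__.py suggests:

from maxray import transform, NodeContext

def show_expr(expr, ctx: NodeContext):
    # Inspect (or replace) every evaluated expression; returning it unchanged is a pure trace.
    print(f"{ctx}: {expr!r}")
    return expr

@transform(show_expr)
def add(a, b):
    return a + b

add(1, 2)  # every sub-expression of add's body passes through show_expr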
@@ -37,14 +54,19 @@ def xray(walker, **kwargs):


_GLOBAL_SKIP_MODULES = {
    "builtins",  # duh... object.__init__ is a wrapper_descriptor
    "abc",  # excessive inheritance and super calls in scikit-learn
    "inspect",  # don't want to screw this module up
    "pathlib",  # internally used in transform for checking source file exists
    "re",  # internals of regexp have a lot of uninteresting step methods
    "copy",  # pytorch spends so much time in here
    "functools",  # partialmethod causes mis-bound `self` in TQDM
    "typing",
    "importlib",
    "ctypes",
    "loguru",  # used internally in transform - accidental patching will cause inf recursion
    "uuid",  # used for generating _MAXRAY_TRANSFORM_ID
    "maxray",
}
# TODO: probably just skip the entire Python standard library...

@@ -65,6 +87,17 @@ class W_erHook:

_MAXRAY_REGISTERED_HOOKS: list[W_erHook] = []


def descend_allowed(x, ctx: NodeContext):
    # Every *active* hook must allow descending into this node, and at least one hook must be active
    num_active_hooks = 0
    for hook in _MAXRAY_REGISTERED_HOOKS:
        if hook.active_call_state.get():
            num_active_hooks += 1
            if not hook.descend_predicate(x, ctx):
                return False
    return num_active_hooks > 0


_GLOBAL_WRITER_ACTIVE_FLAG = ContextVar("writer_active (global)", default=False)

# We don't want to recompile the same function over and over - so our cache needs to be global
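
The behavioural change behind the new `descend_allowed`, compared with the inline `any(...)` check it replaces further down: previously a node was descended into if any single active hook's predicate allowed it; now every active hook gets a veto. A toy illustration, using `SimpleNamespace` stand-ins for `W_erHook` since only the two attributes consulted by `descend_allowed` matter here:

from contextvars import ContextVar
from types import SimpleNamespace

permissive = SimpleNamespace(
    active_call_state=ContextVar("p", default=True),
    descend_predicate=lambda x, ctx: True,
)
vetoing = SimpleNamespace(
    active_call_state=ContextVar("v", default=True),
    descend_predicate=lambda x, ctx: False,
)

# Old logic: any(active and predicate) -> True (the permissive hook wins).
# New logic: descend_allowed -> False (the active vetoing hook blocks descent).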
@@ -84,7 +117,11 @@ def callable_allowed_for_transform(x, ctx: NodeContext):
        and callable(x)
        and callable(getattr(x, "__hash__", None))
        and getattr(type(x), "__module__", None) not in {"ctypes"}
        and (inspect.isfunction(x) or inspect.ismethod(x))
        and (
            inspect.isfunction(x)
            or inspect.ismethod(x)
            or isinstance(x, _lru_cache_wrapper)
        )
    )
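
The new `isinstance(x, _lru_cache_wrapper)` arm is what makes `lru_cache`-decorated functions transformable: on CPython, `functools.lru_cache` returns a C-implemented wrapper object that fails both the `inspect.isfunction` and `inspect.ismethod` checks, so cached functions were previously filtered out. For example:

import functools
import inspect

@functools.lru_cache(maxsize=None)
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(inspect.isfunction(fib))                        # False
print(isinstance(fib, functools._lru_cache_wrapper))  # True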


@@ -94,6 +131,7 @@ def instance_init_allowed_for_transform(x, ctx: NodeContext):
"""
return (
type(x) is type
and getattr(x, "__module__", None) not in {"ctypes"}
and hasattr(x, "__init__")
and not hasattr(x, "_MAXRAY_TRANSFORMED")
)
@@ -103,18 +141,18 @@ def instance_call_allowed_for_transform(x, ctx: NodeContext):
"""
Decides whether the __call__ method can be transformed.
"""
if getattr(x, "__module__", None) in _GLOBAL_SKIP_MODULES:
return False

return (
type(x) is type
and getattr(x, "__module__", None) not in {"ctypes"}
and hasattr(x, "__call__")
and not hasattr(x, "_MAXRAY_TRANSFORMED")
)


def _maxray_walker_handler(x, ctx: NodeContext):
    # We ignore writer calls triggered by code execution in other writers to prevent easily getting stuck in recursive hell
    if _GLOBAL_WRITER_ACTIVE_FLAG.get():
        return x

    # 1. logic to recursively patch callables
    # 1a. special-case callables: __init__ and __call__
    if instance_init_allowed_for_transform(x, ctx):
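
The `_GLOBAL_WRITER_ACTIVE_FLAG` early-return above is a standard ContextVar re-entrancy guard: any expression evaluated inside a writer would itself trigger the handler, so the flag short-circuits those nested calls. A stand-alone sketch of the pattern (names are illustrative, not maxray's):

from contextvars import ContextVar

_inside = ContextVar("inside_writer", default=False)

def handler(x):
    if _inside.get():        # writer code triggered us recursively: do nothing
        return x
    token = _inside.set(True)
    try:
        print(f"saw {x!r}")  # arbitrary writer logic, safely non-reentrant
    finally:
        _inside.reset(token)
    return x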
@@ -125,6 +163,8 @@ def _maxray_walker_handler(x, ctx: NodeContext):
            case Ok(init_patch):
                logger.info(f"Patching __init__ for class {x}")
                setattr(x, "__init__", init_patch)
            # case Err(bad):
            #     logger.error(bad)

    elif instance_call_allowed_for_transform(x, ctx):
        # TODO: should we somehow delay doing this until before an actual call?
@@ -134,16 +174,15 @@ def _maxray_walker_handler(x, ctx: NodeContext):
            case Ok(call_patch):
                logger.info(f"Patching __call__ for class {x}")
                setattr(x, "__call__", call_patch)
            # case Err(bad):
            #     logger.error(bad)

    # 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod
    elif callable_allowed_for_transform(x, ctx):
        # We can only cache functions - as caching invokes __hash__, which may fail badly on incompletely-initialised class instances w/ __call__ methods, like torch._ops.OpOverload
        if x in _MAXRAY_FN_CACHE:
            x = _MAXRAY_FN_CACHE[x]
        elif not any(
            hook.active_call_state.get() and hook.descend_predicate(x, ctx)
            for hook in _MAXRAY_REGISTERED_HOOKS
        ):
        elif not descend_allowed(x, ctx):
            # user-defined filters for which nodes (not) to descend into
            pass
        else:
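
The caching caveat in the comment above is easy to reproduce: a dict membership test invokes `__hash__`, which can blow up on a partially initialised instance. This toy class is illustrative; the comment cites torch._ops.OpOverload as a real-world case:

class Lazy:
    def __init__(self):
        self._key = (1, 2)

    def __hash__(self):
        return hash(self._key)   # fails if __init__ never ran

cache = {}
obj = Lazy.__new__(Lazy)         # incompletely initialised: no _key attribute yet
obj in cache                     # raises AttributeError from inside __hash__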
@@ -288,6 +327,8 @@ def fn_with_context_update(*args, **kwargs):
            ACTIVE_FLAG.reset(prev_token)

        fn_with_context_update._MAXRAY_TRANSFORMED = True
        # TODO: set correctly everywhere
        # fn_with_context_update._MAXRAY_TRANSFORM_ID = ...
        return fn_with_context_update

    return recursive_transform
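
The `_MAXRAY_TRANSFORMED` attribute set above is the sentinel that the `*_allowed_for_transform` predicates check via `not hasattr(x, "_MAXRAY_TRANSFORMED")`, keeping the recursive patching idempotent. A minimal sketch of the pattern, with a hypothetical `wrap` standing in for the real transform:

def wrap(fn):
    if getattr(fn, "_MAXRAY_TRANSFORMED", False):
        return fn  # already patched: never wrap twice

    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper._MAXRAY_TRANSFORMED = True
    return wrapper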
