Skip to content

Commit

Permalink
fix: super() and method descriptor patching
Browse files Browse the repository at this point in the history
  • Loading branch information
blepabyte committed May 28, 2024
1 parent de4e1f0 commit cbd7b48
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 39 deletions.
21 changes: 15 additions & 6 deletions maxray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@

from loguru import logger

# Avoid logspam for users of the library
logger.disable("maxray")


def transform(writer):
"""
Decorator that rewrites the source code of a function to wrap every single expression by passing it through the `writer(expr, ctx)` callable.
"""

def inner(fn):
match recompile_fn_with_transform(fn, writer):
match recompile_fn_with_transform(fn, writer, is_maxray_root=True):
case Ok(trans_fn):
return wraps(fn)(trans_fn)
case Err(err):
Expand All @@ -39,7 +42,11 @@ def xray(walker, **kwargs):
"pathlib", # internally used in transform for checking source file exists
"re", # internals of regexp have a lot of uninteresting step methods
"copy", # pytorch spends so much time in here
"typing",
"importlib",
"loguru", # used internally in transform - accidental patching will cause inf recursion
}
# TODO: probably just skip the entire Python standard library...


@dataclass
Expand Down Expand Up @@ -77,6 +84,7 @@ def callable_allowed_for_transform(x, ctx: NodeContext):
and callable(x)
and callable(getattr(x, "__hash__", None))
and getattr(type(x), "__module__", None) not in {"ctypes"}
and (inspect.isfunction(x) or inspect.ismethod(x))
)


Expand Down Expand Up @@ -118,7 +126,7 @@ def _maxray_walker_handler(x, ctx: NodeContext):
logger.info(f"Patching __init__ for class {x}")
setattr(x, "__init__", init_patch)

if instance_call_allowed_for_transform(x, ctx):
elif instance_call_allowed_for_transform(x, ctx):
# TODO: should we somehow delay doing this until before an actual call?
match recompile_fn_with_transform(
x.__call__, _maxray_walker_handler, special_use_instance_type=x
Expand All @@ -128,8 +136,8 @@ def _maxray_walker_handler(x, ctx: NodeContext):
setattr(x, "__call__", call_patch)

# 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod
if callable_allowed_for_transform(x, ctx):
# TODO: don't cache objects w/ __call__
elif callable_allowed_for_transform(x, ctx):
# We can only cache functions - as caching invokes __hash__, which may fail badly on incompletely-initialised class instances w/ __call__ methods, like torch._ops.OpOverload
if x in _MAXRAY_FN_CACHE:
x = _MAXRAY_FN_CACHE[x]
elif not any(
Expand Down Expand Up @@ -174,8 +182,8 @@ def _maxray_walker_handler(x, ctx: NodeContext):
case Err(e):
# Cache failures
_MAXRAY_FN_CACHE[x] = x
# Errors in functions that have been recursively compiled are unimportant
logger.trace(f"Failed to transform in walker handler: {e}")
# Errors in functions that have been recursively compiled are less important
logger.warning(f"Failed to transform in walker handler: {e}")

# 2. run the active hooks
global_write_active_token = _GLOBAL_WRITER_ACTIVE_FLAG.set(True)
Expand Down Expand Up @@ -250,6 +258,7 @@ def recursive_transform(fn):
_maxray_walker_handler,
initial_scope=caller_locals,
pass_scope=pass_scope,
is_maxray_root=True,
):
case Ok(fn_transform):
pass
Expand Down
72 changes: 43 additions & 29 deletions maxray/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
dedent_chars: int = 0,
record_call_counts: bool = True,
pass_locals_on_return: bool = False,
is_maxray_root: bool = False,
):
"""
If we're transforming a method, instance type should be the __name__ of the class. Otherwise, None.
Expand All @@ -85,6 +86,7 @@ def __init__(
self.dedent_chars = dedent_chars
self.record_call_counts = record_call_counts
self.pass_locals_on_return = pass_locals_on_return
self.is_maxray_root = is_maxray_root

# the first `def` we encounter is the one that we're transforming. Subsequent ones will be nested/within class definitions.
self.fn_count = 0
Expand Down Expand Up @@ -236,9 +238,29 @@ def visit_Return(self, node: ast.Return) -> Any:

def visit_Call(self, node):
source_pre = self.recover_source(node)

node_pre = deepcopy(node)

match node:
case ast.Call(func=ast.Name(id="super"), args=[]):
# HACK: STUPID HACKY HACK
# TODO: detect @classmethod properly
if "self" in self.fn_context.source:
node.args = [
ast.Name("__class__", ctx=ast.Load()),
ast.Name("self", ctx=ast.Load()),
]
else:
node.args = [
ast.Name("__class__", ctx=ast.Load()),
ast.Name("cls", ctx=ast.Load()),
]
node = ast.Call(
func=ast.Name("_MAXRAY_PATCH_MRO", ctx=ast.Load()),
args=[node],
keywords=[],
)
ast.fix_missing_locations(node)

node = self.generic_visit(node) # mutates

# the function/callable instance itself is observed by Name/Attribute/... nodes
Expand All @@ -249,32 +271,16 @@ def visit_Call(self, node):
node_pre,
)

id_key = f"{self.context_fn.__name__}/call/{source_pre}"
# Observes the *output* of the function
call_observer = ast.Call(
func=ast.Name(id=self.transform_fn.__name__, ctx=ast.Load()),
args=[node, ast.Constant({"id": id_key})],
keywords=[],
)
return call_observer

def visit_FunctionDef(self, node: ast.FunctionDef):
pre_node = deepcopy(node)
self.fn_count += 1

# Only overwrite the name of our "target function"
if self.fn_count == 1 and self.is_method():
node.name = f"{node.name}_{_METHOD_MANGLE_NAME}_{node.name}"

# TODO: add is_root arg (whether fn has directly been decorated with @*xray)
node.name = f"{node.name}_{self.instance_type}_{node.name}"

# Decorators are evaluated sequentially: decorators applied *before* our one (should?) get ignored while decorators applied *after* work correctly
is_transform_root = self.fn_count == 1 and any(
isinstance(decor, ast.Call)
and isinstance(decor.func, ast.Name)
and decor.func.id in {"maxray", "xray", "transform"}
for decor in node.decorator_list
)
is_transform_root = self.fn_count == 1 and self.is_maxray_root

if is_transform_root:
logger.info(
Expand Down Expand Up @@ -319,7 +325,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
self.fn_count += 1
# Only overwrite the name of our "target function"
if self.fn_count == 1 and self.is_method():
node.name = f"{node.name}_{_METHOD_MANGLE_NAME}_{node.name}"
node.name = f"{node.name}_{self.instance_type}_{node.name}"

is_transform_root = self.fn_count == 1 and any(
isinstance(decor, ast.Call)
Expand Down Expand Up @@ -370,6 +376,7 @@ def recompile_fn_with_transform(
initial_scope={},
pass_scope=False,
special_use_instance_type=None,
is_maxray_root=False,
) -> Result[Callable, str]:
"""
Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code.
Expand Down Expand Up @@ -415,12 +422,6 @@ def recompile_fn_with_transform(
except SyntaxError:
return Err(f"Syntax error in function {get_fn_name(source_fn)}")

if "super()" in source:
# TODO: we could replace calls to super() with super(__class__, self)?
return Err(
f"Function {get_fn_name(source_fn)} cannot be transformed because it calls super()"
)

match fn_ast:
case ast.Module(body=[ast.FunctionDef() | ast.AsyncFunctionDef()]):
# Good
Expand Down Expand Up @@ -458,25 +459,38 @@ def recompile_fn_with_transform(
sourcefile,
fn_call_counter,
)
instance_type = parent_cls.__name__ if fn_is_method else None
transformed_fn_ast = FnRewriter(
transform_fn,
fn_context,
instance_type=parent_cls.__name__ if fn_is_method else None,
instance_type=instance_type,
dedent_chars=dedent_chars,
pass_locals_on_return=pass_scope,
is_maxray_root=is_maxray_root,
).visit(fn_ast)
ast.fix_missing_locations(transformed_fn_ast)

if ast_post_callback is not None:
ast_post_callback(transformed_fn_ast)

def patch_mro(super_type: super):
for parent_type in super_type.__self_class__.mro():
# Ok that's weird - this function gets picked up by the maxray decorator and seems to correctly patch the parent types - so despite looking like this function does absolutely nothing, it actually *has* side-effects
if not hasattr(parent_type, "__init__") or hasattr(
parent_type.__init__, "_MAXRAY_TRANSFORMED"
):
continue

return super_type

scope = {
transform_fn.__name__: transform_fn,
NodeContext.__name__: NodeContext,
"_MAXRAY_FN_CONTEXT": fn_context,
"_MAXRAY_CALL_COUNTER": fn_call_counter,
"_MAXRAY_DECORATE_WITH_COUNTER": count_calls_with,
"_MAXRAY_BUILTINS_LOCALS": locals,
"_MAXRAY_PATCH_MRO": patch_mro,
}
scope.update(initial_scope)

Expand Down Expand Up @@ -568,13 +582,13 @@ def extract_cell(cell):

if fn_is_method:
transformed_fn = scope[
f"{source_fn.__name__}_{_METHOD_MANGLE_NAME}_{source_fn.__name__}"
f"{source_fn.__name__}_{instance_type}_{source_fn.__name__}"
]
else:
transformed_fn = scope[source_fn.__name__]

# a decorator doesn't actually have to return a function! (could be used solely for side effect) e.g. `@register_backend_lookup_factory` for `find_content_backend` in `awkward/contents/content.py`
if not callable(transformed_fn):
if not callable(transformed_fn) and not inspect.ismethoddescriptor(transformed_fn):
return Err(
f"Resulting transform of definition of {get_fn_name(source_fn)} is not even callable (got {transform_fn}). Perhaps a decorator that returns None?"
)
Expand Down
93 changes: 89 additions & 4 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from dataclasses import dataclass
import functools

import pytest

from loguru import logger

logger.enable("maxray")


def increment_ints_by_one(x, ctx):
if isinstance(x, int):
Expand Down Expand Up @@ -402,7 +408,7 @@ def uh():
z = X()
return z()

assert uh() == 4
assert uh() == 3


def test_junk_annotations():
Expand Down Expand Up @@ -483,6 +489,85 @@ def f(n):
assert found_scope["z"] == 3


def test_wrap_unsound():
# TODO
pass
def test_class_super():
class F:
def __init__(self, **kwargs):
super().__init__()

def f(self):
return 1

class G(F):
def __init__(self, x, y):
self.x = x
self.y = y
super().__init__(x=x)

@maxray(increment_ints_by_one)
def fn(f, g):
# this errors
G(1, 2)
G(1, 2)
return 4

@maxray(increment_ints_by_one)
def fn_works(f, g):
# this doesn't
F
G
G(1, 2)
G(1, 2)
return 4

# When given an *instance*, the __init__ in F overwrites G.__init__
# However, given just F, it correctly patches F.__init__

f_instance = F()
g_instance = G(1, 2)
assert fn(f_instance, g_instance) == fn_works(f_instance, g_instance) == 5


@pytest.mark.xfail
def test_class_super_explicit():
class H0:
def __init__(self, **kwargs):
super().__init__()

def f(self):
return 1

class H1(H0):
def __init__(self, x, y):
self.x = x
self.y = y
# This is currently not handled correctly
super(H1, self).__init__(x=x)

@maxray(increment_ints_by_one)
def fn():
H1(1, 2)
H1(1, 2)
return 4

assert fn() == 5


def test_super_classmethod():
class S0:
def __init__(self, **kwargs):
super().__init__()

@classmethod
def foo(cls):
return 1

class S1(S0):
@classmethod
def foo(cls):
return super().foo()

@maxray(increment_ints_by_one)
def fff():
return S1.foo()

assert fff() == 6

0 comments on commit cbd7b48

Please sign in to comment.