diff --git a/maxray/__init__.py b/maxray/__init__.py index 29a5df9..d053a9c 100644 --- a/maxray/__init__.py +++ b/maxray/__init__.py @@ -9,6 +9,9 @@ from loguru import logger +# Avoid logspam for users of the library +logger.disable("maxray") + def transform(writer): """ @@ -16,7 +19,7 @@ def transform(writer): """ def inner(fn): - match recompile_fn_with_transform(fn, writer): + match recompile_fn_with_transform(fn, writer, is_maxray_root=True): case Ok(trans_fn): return wraps(fn)(trans_fn) case Err(err): @@ -39,7 +42,11 @@ def xray(walker, **kwargs): "pathlib", # internally used in transform for checking source file exists "re", # internals of regexp have a lot of uninteresting step methods "copy", # pytorch spends so much time in here + "typing", + "importlib", + "loguru", # used internally in transform - accidental patching will cause inf recursion } +# TODO: probably just skip the entire Python standard library... @dataclass @@ -77,6 +84,7 @@ def callable_allowed_for_transform(x, ctx: NodeContext): and callable(x) and callable(getattr(x, "__hash__", None)) and getattr(type(x), "__module__", None) not in {"ctypes"} + and (inspect.isfunction(x) or inspect.ismethod(x)) ) @@ -118,7 +126,7 @@ def _maxray_walker_handler(x, ctx: NodeContext): logger.info(f"Patching __init__ for class {x}") setattr(x, "__init__", init_patch) - if instance_call_allowed_for_transform(x, ctx): + elif instance_call_allowed_for_transform(x, ctx): # TODO: should we somehow delay doing this until before an actual call? match recompile_fn_with_transform( x.__call__, _maxray_walker_handler, special_use_instance_type=x @@ -128,8 +136,8 @@ def _maxray_walker_handler(x, ctx: NodeContext): setattr(x, "__call__", call_patch) # 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod - if callable_allowed_for_transform(x, ctx): - # TODO: don't cache objects w/ __call__ + elif callable_allowed_for_transform(x, ctx): + # We can only cache functions - as caching invokes __hash__, which may fail badly on incompletely-initialised class instances w/ __call__ methods, like torch._ops.OpOverload if x in _MAXRAY_FN_CACHE: x = _MAXRAY_FN_CACHE[x] elif not any( @@ -174,8 +182,8 @@ def _maxray_walker_handler(x, ctx: NodeContext): case Err(e): # Cache failures _MAXRAY_FN_CACHE[x] = x - # Errors in functions that have been recursively compiled are unimportant - logger.trace(f"Failed to transform in walker handler: {e}") + # Errors in functions that have been recursively compiled are less important + logger.warning(f"Failed to transform in walker handler: {e}") # 2. run the active hooks global_write_active_token = _GLOBAL_WRITER_ACTIVE_FLAG.set(True) @@ -250,6 +258,7 @@ def recursive_transform(fn): _maxray_walker_handler, initial_scope=caller_locals, pass_scope=pass_scope, + is_maxray_root=True, ): case Ok(fn_transform): pass diff --git a/maxray/transforms.py b/maxray/transforms.py index 735a018..faa912c 100644 --- a/maxray/transforms.py +++ b/maxray/transforms.py @@ -74,6 +74,7 @@ def __init__( dedent_chars: int = 0, record_call_counts: bool = True, pass_locals_on_return: bool = False, + is_maxray_root: bool = False, ): """ If we're transforming a method, instance type should be the __name__ of the class. Otherwise, None. @@ -85,6 +86,7 @@ def __init__( self.dedent_chars = dedent_chars self.record_call_counts = record_call_counts self.pass_locals_on_return = pass_locals_on_return + self.is_maxray_root = is_maxray_root # the first `def` we encounter is the one that we're transforming. Subsequent ones will be nested/within class definitions. self.fn_count = 0 @@ -236,9 +238,29 @@ def visit_Return(self, node: ast.Return) -> Any: def visit_Call(self, node): source_pre = self.recover_source(node) - node_pre = deepcopy(node) + match node: + case ast.Call(func=ast.Name(id="super"), args=[]): + # HACK: STUPID HACKY HACK + # TODO: detect @classmethod properly + if "self" in self.fn_context.source: + node.args = [ + ast.Name("__class__", ctx=ast.Load()), + ast.Name("self", ctx=ast.Load()), + ] + else: + node.args = [ + ast.Name("__class__", ctx=ast.Load()), + ast.Name("cls", ctx=ast.Load()), + ] + node = ast.Call( + func=ast.Name("_MAXRAY_PATCH_MRO", ctx=ast.Load()), + args=[node], + keywords=[], + ) + ast.fix_missing_locations(node) + node = self.generic_visit(node) # mutates # the function/callable instance itself is observed by Name/Attribute/... nodes @@ -249,32 +271,16 @@ def visit_Call(self, node): node_pre, ) - id_key = f"{self.context_fn.__name__}/call/{source_pre}" - # Observes the *output* of the function - call_observer = ast.Call( - func=ast.Name(id=self.transform_fn.__name__, ctx=ast.Load()), - args=[node, ast.Constant({"id": id_key})], - keywords=[], - ) - return call_observer - def visit_FunctionDef(self, node: ast.FunctionDef): pre_node = deepcopy(node) self.fn_count += 1 # Only overwrite the name of our "target function" if self.fn_count == 1 and self.is_method(): - node.name = f"{node.name}_{_METHOD_MANGLE_NAME}_{node.name}" - - # TODO: add is_root arg (whether fn has directly been decorated with @*xray) + node.name = f"{node.name}_{self.instance_type}_{node.name}" # Decorators are evaluated sequentially: decorators applied *before* our one (should?) get ignored while decorators applied *after* work correctly - is_transform_root = self.fn_count == 1 and any( - isinstance(decor, ast.Call) - and isinstance(decor.func, ast.Name) - and decor.func.id in {"maxray", "xray", "transform"} - for decor in node.decorator_list - ) + is_transform_root = self.fn_count == 1 and self.is_maxray_root if is_transform_root: logger.info( @@ -319,7 +325,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: self.fn_count += 1 # Only overwrite the name of our "target function" if self.fn_count == 1 and self.is_method(): - node.name = f"{node.name}_{_METHOD_MANGLE_NAME}_{node.name}" + node.name = f"{node.name}_{self.instance_type}_{node.name}" is_transform_root = self.fn_count == 1 and any( isinstance(decor, ast.Call) @@ -370,6 +376,7 @@ def recompile_fn_with_transform( initial_scope={}, pass_scope=False, special_use_instance_type=None, + is_maxray_root=False, ) -> Result[Callable, str]: """ Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code. @@ -415,12 +422,6 @@ def recompile_fn_with_transform( except SyntaxError: return Err(f"Syntax error in function {get_fn_name(source_fn)}") - if "super()" in source: - # TODO: we could replace calls to super() with super(__class__, self)? - return Err( - f"Function {get_fn_name(source_fn)} cannot be transformed because it calls super()" - ) - match fn_ast: case ast.Module(body=[ast.FunctionDef() | ast.AsyncFunctionDef()]): # Good @@ -458,18 +459,30 @@ def recompile_fn_with_transform( sourcefile, fn_call_counter, ) + instance_type = parent_cls.__name__ if fn_is_method else None transformed_fn_ast = FnRewriter( transform_fn, fn_context, - instance_type=parent_cls.__name__ if fn_is_method else None, + instance_type=instance_type, dedent_chars=dedent_chars, pass_locals_on_return=pass_scope, + is_maxray_root=is_maxray_root, ).visit(fn_ast) ast.fix_missing_locations(transformed_fn_ast) if ast_post_callback is not None: ast_post_callback(transformed_fn_ast) + def patch_mro(super_type: super): + for parent_type in super_type.__self_class__.mro(): + # Ok that's weird - this function gets picked up by the maxray decorator and seems to correctly patch the parent types - so despite looking like this function does absolutely nothing, it actually *has* side-effects + if not hasattr(parent_type, "__init__") or hasattr( + parent_type.__init__, "_MAXRAY_TRANSFORMED" + ): + continue + + return super_type + scope = { transform_fn.__name__: transform_fn, NodeContext.__name__: NodeContext, @@ -477,6 +490,7 @@ def recompile_fn_with_transform( "_MAXRAY_CALL_COUNTER": fn_call_counter, "_MAXRAY_DECORATE_WITH_COUNTER": count_calls_with, "_MAXRAY_BUILTINS_LOCALS": locals, + "_MAXRAY_PATCH_MRO": patch_mro, } scope.update(initial_scope) @@ -568,13 +582,13 @@ def extract_cell(cell): if fn_is_method: transformed_fn = scope[ - f"{source_fn.__name__}_{_METHOD_MANGLE_NAME}_{source_fn.__name__}" + f"{source_fn.__name__}_{instance_type}_{source_fn.__name__}" ] else: transformed_fn = scope[source_fn.__name__] # a decorator doesn't actually have to return a function! (could be used solely for side effect) e.g. `@register_backend_lookup_factory` for `find_content_backend` in `awkward/contents/content.py` - if not callable(transformed_fn): + if not callable(transformed_fn) and not inspect.ismethoddescriptor(transformed_fn): return Err( f"Resulting transform of definition of {get_fn_name(source_fn)} is not even callable (got {transform_fn}). Perhaps a decorator that returns None?" ) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index d796bbe..a83e698 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -6,6 +6,12 @@ from dataclasses import dataclass import functools +import pytest + +from loguru import logger + +logger.enable("maxray") + def increment_ints_by_one(x, ctx): if isinstance(x, int): @@ -402,7 +408,7 @@ def uh(): z = X() return z() - assert uh() == 4 + assert uh() == 3 def test_junk_annotations(): @@ -483,6 +489,85 @@ def f(n): assert found_scope["z"] == 3 -def test_wrap_unsound(): - # TODO - pass +def test_class_super(): + class F: + def __init__(self, **kwargs): + super().__init__() + + def f(self): + return 1 + + class G(F): + def __init__(self, x, y): + self.x = x + self.y = y + super().__init__(x=x) + + @maxray(increment_ints_by_one) + def fn(f, g): + # this errors + G(1, 2) + G(1, 2) + return 4 + + @maxray(increment_ints_by_one) + def fn_works(f, g): + # this doesn't + F + G + G(1, 2) + G(1, 2) + return 4 + + # When given an *instance*, the __init__ in F overwrites G.__init__ + # However, given just F, it correctly patches F.__init__ + + f_instance = F() + g_instance = G(1, 2) + assert fn(f_instance, g_instance) == fn_works(f_instance, g_instance) == 5 + + +@pytest.mark.xfail +def test_class_super_explicit(): + class H0: + def __init__(self, **kwargs): + super().__init__() + + def f(self): + return 1 + + class H1(H0): + def __init__(self, x, y): + self.x = x + self.y = y + # This is currently not handled correctly + super(H1, self).__init__(x=x) + + @maxray(increment_ints_by_one) + def fn(): + H1(1, 2) + H1(1, 2) + return 4 + + assert fn() == 5 + + +def test_super_classmethod(): + class S0: + def __init__(self, **kwargs): + super().__init__() + + @classmethod + def foo(cls): + return 1 + + class S1(S0): + @classmethod + def foo(cls): + return super().foo() + + @maxray(increment_ints_by_one) + def fff(): + return S1.foo() + + assert fff() == 6