From 8afe9adcf9ddd186d9978cd35ba95ab05dbd3b26 Mon Sep 17 00:00:00 2001 From: blepabyte <255@blepabyte.me> Date: Tue, 14 May 2024 22:13:28 +1200 Subject: [PATCH] feat: track return nodes, optionally collecting the local scope on exit --- maxray/__init__.py | 18 +++++++-- maxray/transforms.py | 81 ++++++++++++++++++++++++++++------------ tests/test_transforms.py | 53 ++++++++++++++++++++------ 3 files changed, 112 insertions(+), 40 deletions(-) diff --git a/maxray/__init__.py b/maxray/__init__.py index f74aa5f..f11c778 100644 --- a/maxray/__init__.py +++ b/maxray/__init__.py @@ -26,11 +26,11 @@ def inner(fn): return inner -def xray(walker): +def xray(walker, **kwargs): """ Immutable version of `maxray` - expressions are passed to `walker` but its return value is ignored and the original code execution is left unchanged. """ - return maxray(walker, mutable=False) + return maxray(walker, **kwargs, mutable=False) _GLOBAL_SKIP_MODULES = { @@ -65,6 +65,9 @@ class W_erHook: def callable_allowed_for_transform(x, ctx: NodeContext): + if getattr(x, "__module__", None) in _GLOBAL_SKIP_MODULES: + return False + module_path = ctx.fn_context.module.split(".") if module_path[0] in _GLOBAL_SKIP_MODULES: return False @@ -155,7 +158,11 @@ def _maxray_walker_handler(x, ctx: NodeContext): def maxray( - writer: Callable[[Any, NodeContext], Any], skip_modules=frozenset(), *, mutable=True + writer: Callable[[Any, NodeContext], Any], + skip_modules=frozenset(), + *, + mutable=True, + pass_scope=False, ): """ A transform that recursively hooks into all further calls made within the function, so that `writer` will (in theory) observe every single expression evaluated by the Python interpreter occurring as part of the decorated function call. @@ -198,7 +205,10 @@ def recursive_transform(fn): fn_transform = fn else: match recompile_fn_with_transform( - fn, _maxray_walker_handler, initial_scope=caller_locals + fn, + _maxray_walker_handler, + initial_scope=caller_locals, + pass_scope=pass_scope, ): case Ok(fn_transform): pass diff --git a/maxray/transforms.py b/maxray/transforms.py index 588f85b..ccb0011 100644 --- a/maxray/transforms.py +++ b/maxray/transforms.py @@ -36,7 +36,8 @@ class FnContext: call_count: ContextVar[int] def __repr__(self): - return f"{self.module}/{self.name}/{self.call_count.get()}" + # Call count not included in repr so the same source location can be "grouped by" over multiple calls + return f"{self.module}/{self.name}" @dataclass @@ -57,6 +58,8 @@ class NodeContext: location: tuple[int, int, int, int] + local_scope: Any = None + def __repr__(self): return f"{self.fn_context}/{self.id}" @@ -70,6 +73,7 @@ def __init__( instance_type: str | None, dedent_chars: int = 0, record_call_counts: bool = True, + pass_locals_on_return: bool = False, ): """ If we're transforming a method, instance type should be the __name__ of the class. Otherwise, None. @@ -80,6 +84,7 @@ def __init__( self.instance_type = instance_type self.dedent_chars = dedent_chars self.record_call_counts = record_call_counts + self.pass_locals_on_return = pass_locals_on_return # the first `def` we encounter is the one that we're transforming. Subsequent ones will be nested/within class definitions. self.fn_count = 0 @@ -113,7 +118,7 @@ def recover_source(self, pre_node): return self.safe_unparse(pre_node) return segment - def build_transform_node(self, node, label, node_source=None): + def build_transform_node(self, node, label, node_source=None, pass_locals=False): """ Builds the "inspection" node that wraps the original source node - passing the (value, context) pair to `transform_fn`. """ @@ -122,22 +127,33 @@ def build_transform_node(self, node, label, node_source=None): line_offset = self.fn_context.impl_fn.__code__.co_firstlineno - 2 col_offset = self.dedent_chars + + context_args = [ + ast.Constant(label), + ast.Constant(node_source), + # Name is injected into the exec scope by `recompile_fn_with_transform` + ast.Name(id="_MAXRAY_FN_CONTEXT", ctx=ast.Load()), + ast.Constant( + ( + line_offset + node.lineno, + line_offset + node.end_lineno, + node.col_offset + col_offset, + node.end_col_offset + col_offset, + ) + ), + ] + + if pass_locals: + context_args.append( + ast.Call( + func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()), + args=[], + keywords=[], + ), + ) context_node = ast.Call( func=ast.Name(id=NodeContext.__name__, ctx=ast.Load()), - args=[ - ast.Constant(label), - ast.Constant(node_source), - # Name is injected into the exec scope by `recompile_fn_with_transform` - ast.Name(id="_MAXRAY_FN_CONTEXT", ctx=ast.Load()), - ast.Constant( - ( - line_offset + node.lineno, - line_offset + node.end_lineno, - node.col_offset + col_offset, - node.end_col_offset + col_offset, - ) - ), - ], + args=context_args, keywords=[], ) @@ -196,6 +212,28 @@ def visit_Assign(self, node: ast.Assign) -> Any: # node.value = self.build_transform_node(new_node, f"assign/(multiple)") return node + def visit_Return(self, node: ast.Return) -> Any: + node_pre = deepcopy(node) + + if node.value is None: + node.value = ast.Constant(None) + + # Note: For a plain `return` statement, there's no source for a thing that *isn't* returned + value_source_pre = self.recover_source(node.value) + + node = self.generic_visit(node) + + # TODO: Check source locations are correct here + ast.fix_missing_locations(node) + node.value = self.build_transform_node( + node.value, + f"return/{value_source_pre}", + node_source=value_source_pre, + pass_locals=self.pass_locals_on_return, + ) + + return ast.copy_location(node, node_pre) + def visit_Call(self, node): source_pre = self.recover_source(node) @@ -204,14 +242,6 @@ def visit_Call(self, node): node = self.generic_visit(node) # mutates # the function/callable instance itself is observed by Name/Attribute/... nodes - - target = node.func - match target: - case ast.Name(): - logger.debug(f"Visiting call to function {target.id}") - case ast.Attribute(): - logger.debug(f"Visiting call to attribute {target.attr}") - return ast.copy_location( self.build_transform_node( node, f"call/{source_pre}", node_source=source_pre @@ -338,6 +368,7 @@ def recompile_fn_with_transform( ast_pre_callback=None, ast_post_callback=None, initial_scope={}, + pass_scope=False, ) -> Result[Callable, str]: """ Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code. @@ -421,6 +452,7 @@ def recompile_fn_with_transform( fn_context, instance_type=parent_cls.__name__ if fn_is_method else None, dedent_chars=dedent_chars, + pass_locals_on_return=pass_scope, ).visit(fn_ast) ast.fix_missing_locations(transformed_fn_ast) @@ -433,6 +465,7 @@ def recompile_fn_with_transform( "_MAXRAY_FN_CONTEXT": fn_context, "_MAXRAY_CALL_COUNTER": fn_call_counter, "_MAXRAY_DECORATE_WITH_COUNTER": count_calls_with, + "_MAXRAY_BUILTINS_LOCALS": locals, } scope.update(initial_scope) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 1ec8c0e..fa8a8a1 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -19,7 +19,7 @@ def test_basic(): def f(x): return x - assert f(3) == 4 + assert f(3) == 5 def test_type_hints(): @@ -29,7 +29,7 @@ def test_type_hints(): def f(x: Any): return x - assert f(3) == 4 + assert f(3) == 5 def test_closure_capture(): @@ -39,7 +39,7 @@ def test_closure_capture(): def f(x): return x + z - assert f(3) == 6 + assert f(3) == 7 def test_closure_capture_mutate(): @@ -62,7 +62,7 @@ def test_global_capture(): def g(x): return x + GLOB_CONST - assert g(3) == 10 + assert g(3) == 11 def test_nested_def(): @@ -76,8 +76,8 @@ def g(x): return g - assert outer()(3) == 4 - assert outer()(3) == 4 + assert outer()(3) == 5 + assert outer()(3) == 5 def test_recursive(): @@ -150,7 +150,7 @@ def f(): pass return x - assert f() == 4 + assert f() == 8 def test_property_access(): @@ -165,7 +165,7 @@ class A: def g(): return obj.x - assert g() == 2 + assert g() == 3 def test_method(): @@ -377,7 +377,7 @@ def dec(f): def f(x): return x - assert f(2) == 3 + assert f(2) == 4 assert len(decor_count) == 1 # Works properly when applied last: is wiped for the transform, but is subsequently applied properly to the transformed function @@ -386,7 +386,7 @@ def f(x): def f(x): return x - assert f(2) == 1 + assert f(2) == 0 assert len(decor_count) == 2 @@ -402,7 +402,7 @@ def uh(): z = X() return z() - assert uh() == 2 + assert uh() == 3 def test_junk_annotations(): @@ -413,7 +413,7 @@ def inner(x: ASDF = 0, *, y: SDFSDF = 100) -> AAAAAAAAAAA: return inner(2) - assert outer() == 105 + assert outer() == 107 def test_call_counts(): @@ -454,6 +454,35 @@ def f(x): assert calls == [1, 2, 3, 3, 2, 1] +def test_empty_return(): + @xray(dbg) + def empty_returns(): + return + + assert empty_returns() is None + + +def test_scope_passed(): + found_scope = None + + def get_scope(x, ctx): + nonlocal found_scope + if ctx.local_scope is not None: + assert found_scope is None + found_scope = ctx.local_scope + return x + + @xray(get_scope, pass_scope=True) + def f(n): + z = 3 + return n + + assert f(1) == 1 + + assert "z" in found_scope + assert found_scope["z"] == 3 + + def test_wrap_unsound(): # TODO pass