Skip to content

Commit

Permalink
feat: track return nodes, optionally collecting the local scope on exit
Browse files Browse the repository at this point in the history
  • Loading branch information
blepabyte committed May 14, 2024
1 parent e842dfe commit 8afe9ad
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 40 deletions.
18 changes: 14 additions & 4 deletions maxray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ def inner(fn):
return inner


def xray(walker):
def xray(walker, **kwargs):
"""
Immutable version of `maxray` - expressions are passed to `walker` but its return value is ignored and the original code execution is left unchanged.
"""
return maxray(walker, mutable=False)
return maxray(walker, **kwargs, mutable=False)


_GLOBAL_SKIP_MODULES = {
Expand Down Expand Up @@ -65,6 +65,9 @@ class W_erHook:


def callable_allowed_for_transform(x, ctx: NodeContext):
if getattr(x, "__module__", None) in _GLOBAL_SKIP_MODULES:
return False

module_path = ctx.fn_context.module.split(".")
if module_path[0] in _GLOBAL_SKIP_MODULES:
return False
Expand Down Expand Up @@ -155,7 +158,11 @@ def _maxray_walker_handler(x, ctx: NodeContext):


def maxray(
writer: Callable[[Any, NodeContext], Any], skip_modules=frozenset(), *, mutable=True
writer: Callable[[Any, NodeContext], Any],
skip_modules=frozenset(),
*,
mutable=True,
pass_scope=False,
):
"""
A transform that recursively hooks into all further calls made within the function, so that `writer` will (in theory) observe every single expression evaluated by the Python interpreter occurring as part of the decorated function call.
Expand Down Expand Up @@ -198,7 +205,10 @@ def recursive_transform(fn):
fn_transform = fn
else:
match recompile_fn_with_transform(
fn, _maxray_walker_handler, initial_scope=caller_locals
fn,
_maxray_walker_handler,
initial_scope=caller_locals,
pass_scope=pass_scope,
):
case Ok(fn_transform):
pass
Expand Down
81 changes: 57 additions & 24 deletions maxray/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class FnContext:
call_count: ContextVar[int]

def __repr__(self):
return f"{self.module}/{self.name}/{self.call_count.get()}"
# Call count not included in repr so the same source location can be "grouped by" over multiple calls
return f"{self.module}/{self.name}"


@dataclass
Expand All @@ -57,6 +58,8 @@ class NodeContext:

location: tuple[int, int, int, int]

local_scope: Any = None

def __repr__(self):
return f"{self.fn_context}/{self.id}"

Expand All @@ -70,6 +73,7 @@ def __init__(
instance_type: str | None,
dedent_chars: int = 0,
record_call_counts: bool = True,
pass_locals_on_return: bool = False,
):
"""
If we're transforming a method, instance type should be the __name__ of the class. Otherwise, None.
Expand All @@ -80,6 +84,7 @@ def __init__(
self.instance_type = instance_type
self.dedent_chars = dedent_chars
self.record_call_counts = record_call_counts
self.pass_locals_on_return = pass_locals_on_return

# the first `def` we encounter is the one that we're transforming. Subsequent ones will be nested/within class definitions.
self.fn_count = 0
Expand Down Expand Up @@ -113,7 +118,7 @@ def recover_source(self, pre_node):
return self.safe_unparse(pre_node)
return segment

def build_transform_node(self, node, label, node_source=None):
def build_transform_node(self, node, label, node_source=None, pass_locals=False):
"""
Builds the "inspection" node that wraps the original source node - passing the (value, context) pair to `transform_fn`.
"""
Expand All @@ -122,22 +127,33 @@ def build_transform_node(self, node, label, node_source=None):

line_offset = self.fn_context.impl_fn.__code__.co_firstlineno - 2
col_offset = self.dedent_chars

context_args = [
ast.Constant(label),
ast.Constant(node_source),
# Name is injected into the exec scope by `recompile_fn_with_transform`
ast.Name(id="_MAXRAY_FN_CONTEXT", ctx=ast.Load()),
ast.Constant(
(
line_offset + node.lineno,
line_offset + node.end_lineno,
node.col_offset + col_offset,
node.end_col_offset + col_offset,
)
),
]

if pass_locals:
context_args.append(
ast.Call(
func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
args=[],
keywords=[],
),
)
context_node = ast.Call(
func=ast.Name(id=NodeContext.__name__, ctx=ast.Load()),
args=[
ast.Constant(label),
ast.Constant(node_source),
# Name is injected into the exec scope by `recompile_fn_with_transform`
ast.Name(id="_MAXRAY_FN_CONTEXT", ctx=ast.Load()),
ast.Constant(
(
line_offset + node.lineno,
line_offset + node.end_lineno,
node.col_offset + col_offset,
node.end_col_offset + col_offset,
)
),
],
args=context_args,
keywords=[],
)

Expand Down Expand Up @@ -196,6 +212,28 @@ def visit_Assign(self, node: ast.Assign) -> Any:
# node.value = self.build_transform_node(new_node, f"assign/(multiple)")
return node

def visit_Return(self, node: ast.Return) -> Any:
node_pre = deepcopy(node)

if node.value is None:
node.value = ast.Constant(None)

# Note: For a plain `return` statement, there's no source for a thing that *isn't* returned
value_source_pre = self.recover_source(node.value)

node = self.generic_visit(node)

# TODO: Check source locations are correct here
ast.fix_missing_locations(node)
node.value = self.build_transform_node(
node.value,
f"return/{value_source_pre}",
node_source=value_source_pre,
pass_locals=self.pass_locals_on_return,
)

return ast.copy_location(node, node_pre)

def visit_Call(self, node):
source_pre = self.recover_source(node)

Expand All @@ -204,14 +242,6 @@ def visit_Call(self, node):
node = self.generic_visit(node) # mutates

# the function/callable instance itself is observed by Name/Attribute/... nodes

target = node.func
match target:
case ast.Name():
logger.debug(f"Visiting call to function {target.id}")
case ast.Attribute():
logger.debug(f"Visiting call to attribute {target.attr}")

return ast.copy_location(
self.build_transform_node(
node, f"call/{source_pre}", node_source=source_pre
Expand Down Expand Up @@ -338,6 +368,7 @@ def recompile_fn_with_transform(
ast_pre_callback=None,
ast_post_callback=None,
initial_scope={},
pass_scope=False,
) -> Result[Callable, str]:
"""
Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code.
Expand Down Expand Up @@ -421,6 +452,7 @@ def recompile_fn_with_transform(
fn_context,
instance_type=parent_cls.__name__ if fn_is_method else None,
dedent_chars=dedent_chars,
pass_locals_on_return=pass_scope,
).visit(fn_ast)
ast.fix_missing_locations(transformed_fn_ast)

Expand All @@ -433,6 +465,7 @@ def recompile_fn_with_transform(
"_MAXRAY_FN_CONTEXT": fn_context,
"_MAXRAY_CALL_COUNTER": fn_call_counter,
"_MAXRAY_DECORATE_WITH_COUNTER": count_calls_with,
"_MAXRAY_BUILTINS_LOCALS": locals,
}
scope.update(initial_scope)

Expand Down
53 changes: 41 additions & 12 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_basic():
def f(x):
return x

assert f(3) == 4
assert f(3) == 5


def test_type_hints():
Expand All @@ -29,7 +29,7 @@ def test_type_hints():
def f(x: Any):
return x

assert f(3) == 4
assert f(3) == 5


def test_closure_capture():
Expand All @@ -39,7 +39,7 @@ def test_closure_capture():
def f(x):
return x + z

assert f(3) == 6
assert f(3) == 7


def test_closure_capture_mutate():
Expand All @@ -62,7 +62,7 @@ def test_global_capture():
def g(x):
return x + GLOB_CONST

assert g(3) == 10
assert g(3) == 11


def test_nested_def():
Expand All @@ -76,8 +76,8 @@ def g(x):

return g

assert outer()(3) == 4
assert outer()(3) == 4
assert outer()(3) == 5
assert outer()(3) == 5


def test_recursive():
Expand Down Expand Up @@ -150,7 +150,7 @@ def f():
pass
return x

assert f() == 4
assert f() == 8


def test_property_access():
Expand All @@ -165,7 +165,7 @@ class A:
def g():
return obj.x

assert g() == 2
assert g() == 3


def test_method():
Expand Down Expand Up @@ -377,7 +377,7 @@ def dec(f):
def f(x):
return x

assert f(2) == 3
assert f(2) == 4
assert len(decor_count) == 1

# Works properly when applied last: is wiped for the transform, but is subsequently applied properly to the transformed function
Expand All @@ -386,7 +386,7 @@ def f(x):
def f(x):
return x

assert f(2) == 1
assert f(2) == 0
assert len(decor_count) == 2


Expand All @@ -402,7 +402,7 @@ def uh():
z = X()
return z()

assert uh() == 2
assert uh() == 3


def test_junk_annotations():
Expand All @@ -413,7 +413,7 @@ def inner(x: ASDF = 0, *, y: SDFSDF = 100) -> AAAAAAAAAAA:

return inner(2)

assert outer() == 105
assert outer() == 107


def test_call_counts():
Expand Down Expand Up @@ -454,6 +454,35 @@ def f(x):
assert calls == [1, 2, 3, 3, 2, 1]


def test_empty_return():
@xray(dbg)
def empty_returns():
return

assert empty_returns() is None


def test_scope_passed():
found_scope = None

def get_scope(x, ctx):
nonlocal found_scope
if ctx.local_scope is not None:
assert found_scope is None
found_scope = ctx.local_scope
return x

@xray(get_scope, pass_scope=True)
def f(n):
z = 3
return n

assert f(1) == 1

assert "z" in found_scope
assert found_scope["z"] == 3


def test_wrap_unsound():
# TODO
pass

0 comments on commit 8afe9ad

Please sign in to comment.