Skip to content

Commit

Permalink
feat: more context props (call, iterate, return)
Browse files Browse the repository at this point in the history
  • Loading branch information
blepabyte committed Sep 11, 2024
1 parent 7e61361 commit 7b51b32
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 49 deletions.
119 changes: 83 additions & 36 deletions maxray/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ def _set_assigned(self, targets: list[str]):
self.props["assigned"] = {"targets": targets}
return self

def _set_iterated(self, target: str):
self.props["iterated"] = {"target": target}
return self

def _set_returned(self, value_source: str):
self.props["returned"] = {"value_source": value_source}
return self

def _set_called(self, call_args, call_kwargs):
# TODO: extract args and kwargs for call matching
# TODO: bind and pass TRANSFORM_ID if found
self.props["called"] = {}
return self


class RewriteRuntimeHelper:
"""
Expand Down Expand Up @@ -175,6 +189,27 @@ def assigned(self, assign_targets: list[str]):
keywords=[],
)

def iterated(self, iter_target: str):
self.args[1] = ast.Call(
ast.Attribute(self.args[1], "_set_iterated", ctx=ast.Load()),
args=[ast.Constant(iter_target)],
keywords=[],
)

def returned(self, return_target: str):
self.args[1] = ast.Call(
ast.Attribute(self.args[1], "_set_returned", ctx=ast.Load()),
args=[ast.Constant(return_target)],
keywords=[],
)

def called(self, call_args, call_kwargs):
self.args[1] = ast.Call(
ast.Attribute(self.args[1], "_set_called", ctx=ast.Load()),
args=[ast.Constant(None), ast.Constant(None)],
keywords=[],
)


class FnRewriter(ast.NodeTransformer):
def __init__(
Expand Down Expand Up @@ -353,22 +388,6 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
)
return node

def visit_match_case(self, node: ast.match_case) -> Any:
# leave node.pattern unchanged because the rules of match patterns are different from the rest of Python
# throws "ValueError: MatchClass cls field can only contain Name or Attribute nodes." in compile because `case _wrap(str()):` doesn't work
node.body = [self.generic_visit(child) for child in node.body]
return node

def visit_Assign(self, node: ast.Assign) -> Any:
node = deepcopy(node)
new_node = self.generic_visit(node)
match new_node:
case ast.Assign(targets=targets, value=RewriteTransformCall() as rtc):
target_reprs = [self.recover_source(t) for t in targets]
rtc.assigned(target_reprs)

return new_node

def visit_Subscript(self, node: ast.Subscript) -> Any:
if isinstance(node.ctx, ast.Load):
source_pre = self.recover_source(node)
Expand All @@ -378,6 +397,15 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
return ast.copy_location(node, node_pre)
return node

def visit_Constant(self, node: ast.Constant) -> Any:
source_pre = self.recover_source(node)
node_pre = deepcopy(node)
new_node = self.generic_visit(node)
new_node = self.build_transform_node(
new_node, "constant", node_source=source_pre
)
return ast.copy_location(new_node, node_pre)

def visit_BinOp(self, node: ast.BinOp) -> Any:
source_pre = self.recover_source(node)
node_pre = deepcopy(node)
Expand All @@ -389,34 +417,48 @@ def visit_BinOp(self, node: ast.BinOp) -> Any:
)
return ast.copy_location(node, node_pre)

def visit_match_case(self, node: ast.match_case) -> Any:
# leave node.pattern unchanged because the rules of match patterns are different from the rest of Python
# throws "ValueError: MatchClass cls field can only contain Name or Attribute nodes." in compile because `case _wrap(str()):` doesn't work
node.body = [self.generic_visit(child) for child in node.body]
return node

# Non-expression nodes

def visit_Assign(self, node: ast.Assign) -> Any:
node = deepcopy(node)
new_node = self.generic_visit(node)
match new_node:
case ast.Assign(targets=targets, value=RewriteTransformCall() as rtc):
target_reprs = [self.recover_source(t) for t in targets]
rtc.assigned(target_reprs)

return new_node

def visit_For(self, node: ast.For) -> Any:
node = deepcopy(node)
new_node = self.generic_visit(node)
match new_node:
case ast.For(target=target, iter=RewriteTransformCall() as rtc):
rtc.iterated(self.recover_source(target))

return new_node

def visit_Return(self, node: ast.Return) -> Any:
"""
`return` is a non-expression node. Though adding an event on this node is redundant (callback would already be invoked on the expression to be returned), it's still useful to track what was returned or override it.
"""
source_pre = self.recover_source(node)
node_pre = deepcopy(node)
node = deepcopy(node)

# Note: For a plain `return` statement, there's no source for a thing that *isn't* returned
value_source_pre = (
"None" if node.value is None else self.recover_source(node.value)
)

if node.value is None:
# # Don't want to invoke a callback on a node that doesn't actually exist
node.value = ast.Constant(None)
value_source = ""
else:
node = self.generic_visit(node)
value_source = self.recover_source(node.value)

# TODO: Check source locations are correct here
ast.fix_missing_locations(node)
node.value = self.build_transform_node(
node.value,
f"return/{value_source_pre}",
node_source=value_source_pre,
)
new_node = self.generic_visit(node)
match new_node:
case ast.Return(value=RewriteTransformCall() as rtc):
rtc.returned(value_source)

return ast.copy_location(node, node_pre)
return new_node

@staticmethod
def temp_binding(node):
Expand Down Expand Up @@ -480,6 +522,11 @@ def visit_Call(self, node):
ast.fix_missing_locations(node)

node = self.generic_visit(node)

match node:
case ast.Call(func=RewriteTransformCall() as rtc):
rtc.called(node.args, node.keywords)

# Want to keep track of which function we're calling
# Can't do ops on `node_pre.func` beacuse evaluating it has side effects
# `node.func` is likely a call to `_maxray_walker_handler`
Expand Down
2 changes: 1 addition & 1 deletion tests/test_script_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def nonzero(x, ctx):
if isinstance(x, int) and x == 0:
if ctx.id != "constant" and isinstance(x, int) and x == 0:
return 1
return x

Expand Down
22 changes: 10 additions & 12 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_basic():
def f(x):
return x

assert f(3) == 5
assert f(3) == 4


def test_type_hints():
Expand All @@ -43,7 +43,7 @@ def test_type_hints():
def f(x: Any):
return x

assert f(3) == 5
assert f(3) == 4


def test_closure_capture():
Expand Down Expand Up @@ -90,8 +90,8 @@ def g(x):

return g

assert outer()(3) == 5
assert outer()(3) == 5
assert outer()(3) == 4
assert outer()(3) == 4


def test_recursive():
Expand Down Expand Up @@ -179,7 +179,7 @@ class A:
def g():
return obj.x

assert g() == 3
assert g() == 2


def test_method():
Expand All @@ -193,7 +193,7 @@ def g():
a = A()
return a.foo()

assert g() == "2"
assert g() == "3"


def test_recursive_self_repr():
Expand Down Expand Up @@ -385,12 +385,11 @@ def test_xray_immutable():
@maxray(lambda x, ctx: x * 10 if isinstance(x, float) else x)
@xray(increment_ints_by_one)
def foo():
# Currently assumes that literals/constants are not wrapped (they're uninteresting anyways)
x = 1
y = 2.0
return x, y

assert foo() == (1, 20.0)
assert foo() == (1, 200.0)


def test_walk_callable_side_effects():
Expand Down Expand Up @@ -440,7 +439,7 @@ def dec(f):
def f(x):
return x

assert f(2) == 4
assert f(2) == 3
assert len(decor_count) == 1

# Works properly when applied last: is wiped for the transform, but is subsequently applied properly to the transformed function
Expand All @@ -449,7 +448,7 @@ def f(x):
def f(x):
return x

assert f(2) == 0
assert f(2) == 1
assert len(decor_count) == 2


Expand Down Expand Up @@ -624,10 +623,9 @@ def foo(cls):

@maxray(increment_ints_by_one)
def fff():
# S0.foo
return S1.foo()

assert fff() == 6
assert fff() == 4


def test_partialmethod():
Expand Down

0 comments on commit 7b51b32

Please sign in to comment.