From 6427aef210e118944c83dc89556f9a8ba9b72650 Mon Sep 17 00:00:00 2001 From: blepabyte <255@blepabyte.me> Date: Sun, 7 Apr 2024 17:31:03 +1200 Subject: [PATCH] feat: implement transform, xray, maxray --- README.md | 0 maxray/__init__.py | 199 ++++++++++++++++++ maxray/transforms.py | 421 +++++++++++++++++++++++++++++++++++++++ maxray/walkers.py | 11 + pyproject.toml | 16 ++ tests/__init__.py | 0 tests/test_transforms.py | 322 ++++++++++++++++++++++++++++++ 7 files changed, 969 insertions(+) create mode 100644 README.md create mode 100644 maxray/__init__.py create mode 100644 maxray/transforms.py create mode 100644 maxray/walkers.py create mode 100644 pyproject.toml create mode 100644 tests/__init__.py create mode 100644 tests/test_transforms.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/maxray/__init__.py b/maxray/__init__.py new file mode 100644 index 0000000..09fe8a8 --- /dev/null +++ b/maxray/__init__.py @@ -0,0 +1,199 @@ +from .transforms import recompile_fn_with_transform, NodeContext + +import inspect +from contextvars import ContextVar +from dataclasses import dataclass +from functools import wraps +from typing import Any, Callable +from result import Result, Ok, Err + +from loguru import logger + + +def transform(writer): + """ + Decorator that rewrites the source code of a function to wrap every single expression by passing it through the `writer(expr, ctx)` callable. + """ + + def inner(fn): + match recompile_fn_with_transform(fn, writer): + case Ok(trans_fn): + return wraps(fn)(trans_fn) + case Err(err): + logger.error(err) + return fn + + return inner + + +def xray(walker): + """ + Immutable version of `maxray` - expressions are passed to `walker` but its return value is ignored and the original code execution is left unchanged. + """ + return maxray(walker, mutable=False) + + +_GLOBAL_SKIP_MODULES = { + "abc", # excessive inheritance and super calls in scikit-learn + "inspect", # don't want to screw this module up + "pathlib", # internally used in transform for checking source file exists + "re", # internals of regexp have a lot of uninteresting step methods + "copy", # pytorch spends so much time in here +} + + +@dataclass +class W_erHook: + impl_fn: Callable + active_call_state: ContextVar[bool] + writer_active_call_state: ContextVar[bool] + skip_modules: set + only_modules: set + mutable: bool + + # each walker defines names to skip and we skip recursive transform if *any* walker asks to skip + + +_MAXRAY_REGISTERED_HOOKS: list[W_erHook] = [] + +_GLOBAL_WRITER_ACTIVE_FLAG = ContextVar("writer_active (global)", default=False) + +# We don't want to recompile the same function over and over - so our cache needs to be global +_MAXRAY_FN_CACHE = dict() + + +def callable_allowed_for_transform(x, ctx: NodeContext): + module_path = ctx.fn_context.module.split(".") + if module_path[0] in _GLOBAL_SKIP_MODULES: + return False + return not hasattr(x, "_MAXRAY_TRANSFORMED") and callable(x) + + +def _maxray_walker_handler(x, ctx): + # We ignore writer calls triggered by code execution in other writers to prevent easily getting stuck in recursive hell + if _GLOBAL_WRITER_ACTIVE_FLAG.get(): + return x + + # 1. logic to recursively patch callables + if callable_allowed_for_transform(x, ctx): + if x in _MAXRAY_FN_CACHE: + return _MAXRAY_FN_CACHE[x] + + # Our recompiled fn sets and unsets a contextvar whenever it is active + match recompile_fn_with_transform(x, _maxray_walker_handler): + case Ok(x_trans): + # NOTE: x_trans now has _MAXRAY_TRANSFORMED field to True + if inspect.ismethod(x): + # Two cases: descriptor vs bound method + # TODO: handle callables and .__call__ patching + match x.__self__: + case type(): + # Descriptor + logger.warning( + f"monkey-patching descriptor method {x.__name__} on type {x.__self__}" + ) + parent_cls = x.__self__ + case _: + # Bound method + logger.warning( + f"monkey-patching bound method {x.__name__} on type {type(x.__self__)}" + ) + parent_cls = type(x.__self__) + + # Monkey-patching the methods. Probably unsafe and unsound + setattr(parent_cls, x.__name__, x_trans) + x_patched = getattr( + x.__self__, x.__name__ + ) # getattr turns class descriptors (@classmethod) into bound methods + + # We don't bother caching methods as they're monkey-patched + # SOUNDNESS: a package might manually keep references to __init__ around to later call them - but we'd just end up recompiling those as well + else: + x_patched = x_trans + _MAXRAY_FN_CACHE[x] = x_patched + x = x_patched + + case Err(e): + # Cache failures + _MAXRAY_FN_CACHE[x] = x + # Errors in functions that have been recursively compiled are unimportant + logger.trace(f"Failed to transform in walker handler: {e}") + + # 2. run the active hooks + global_write_active_token = _GLOBAL_WRITER_ACTIVE_FLAG.set(True) + try: + for walk_hook in _MAXRAY_REGISTERED_HOOKS: + if not walk_hook.active_call_state.get(): + continue + + if ctx.fn_context.module in walk_hook.skip_modules: + continue + + # Set the writer active flag + write_active_token = walk_hook.writer_active_call_state.set(True) + if walk_hook.mutable: + x = walk_hook.impl_fn(x, ctx) + else: + walk_hook.impl_fn(x, ctx) + walk_hook.writer_active_call_state.reset(write_active_token) + finally: + _GLOBAL_WRITER_ACTIVE_FLAG.reset(global_write_active_token) + + return x + + +def maxray( + writer: Callable[[Any, NodeContext], Any], skip_modules=frozenset(), *, mutable=True +): + """ + A transform that recursively hooks into all further calls made within the function, so that `writer` will (in theory) observe every single expression evaluated by the Python interpreter occurring as part of the decorated function call. + + There are some limitations to be aware of: + - Be careful to avoid infinite recursion: the source code of the writer will not be transformed but it may call methods that have been monkey-patched that result in more calls to the writer function. + - Objects that are not yet fully initialised may not behave as expected - e.g. repr may throw an error because of a missing property + """ + + ACTIVE_FLAG = ContextVar(f"maxray_active for <{writer}>", default=False) + WRITER_ACTIVE_FLAG = ContextVar(f"writer_active for <{writer}>", default=False) + + def recursive_transform(fn): + _MAXRAY_REGISTERED_HOOKS.append( + W_erHook( + writer, + ACTIVE_FLAG, + WRITER_ACTIVE_FLAG, + skip_modules, + set(), + mutable=mutable, + ) + ) + + # Fixes `test_double_decorators_with_locals`: repeated transforms are broken because stuffing closures into locals doesn't work the second time around + if hasattr(fn, "_MAXRAY_TRANSFORMED"): + fn_transform = fn + else: + match recompile_fn_with_transform(fn, _maxray_walker_handler): + case Ok(fn_transform): + pass + case Err(err): + # Errors are only displayed at top-level, when the user has manually annotated a function with @xray or the like + logger.error(err) + return fn + + # BUG: We can't do @wraps if it's a callable instance, right? + @wraps(fn) + def fn_with_context_update(*args, **kwargs): + # already active on stack + if ACTIVE_FLAG.get(): + return fn_transform(*args, **kwargs) + + ACTIVE_FLAG.set(True) + try: + return fn_transform(*args, **kwargs) + finally: + ACTIVE_FLAG.set(False) + + fn_with_context_update._MAXRAY_TRANSFORMED = True + return fn_with_context_update + + return recursive_transform diff --git a/maxray/transforms.py b/maxray/transforms.py new file mode 100644 index 0000000..63e8e1b --- /dev/null +++ b/maxray/transforms.py @@ -0,0 +1,421 @@ +import ast +import inspect +import sys + +from textwrap import dedent +from pathlib import Path +from dataclasses import dataclass +from copy import deepcopy +from result import Result, Ok, Err + +from typing import Any, Callable + +from loguru import logger + + +_METHOD_MANGLE_NAME = "vdxivosjdovs_method" + + +def mangle_name(identifier): + raise NotImplementedError() + + +def unmangle_name(identifier): + raise NotImplementedError() + + +@dataclass +class FnContext: + impl_fn: Callable + name: str + module: str + source: str + source_file: str + # TODO: add location as well + + +@dataclass +class NodeContext: + id: str + """ + Identifier for the type of syntax node this event came from. For example: + - name/x + - call/foo + """ + + source: str + + fn_context: FnContext + """ + Properties of the function containing this node. + """ + + location: tuple[int, int, int, int] + + def __repr__(self): + return f"{self.fn_context.module}/{self.fn_context.name}/{self.id}" + + +class FnRewriter(ast.NodeTransformer): + def __init__( + self, transform_fn, fn_context: FnContext, *, instance_type: str | None + ): + """ + If we're transforming a method, instance type should be the __name__ of the class. Otherwise, None. + """ + + self.transform_fn = transform_fn + self.fn_context = fn_context + self.instance_type = instance_type + + # the first `def` we encounter is the one that we're transforming. Subsequent ones will be nested/within class definitions. + self.fn_count = 0 + + def is_method(self): + return self.instance_type is not None + + @staticmethod + def safe_unparse(node): + # workaround for https://github.com/python/cpython/issues/108469 (fixed in python 3.12) + try: + return ast.unparse(node) + except ValueError as e: + return "" + + @staticmethod + def is_private_class_name(identifier_name: str): + return ( + identifier_name.startswith("__") + and not identifier_name.endswith("__") + and identifier_name.strip("_") + ) + + def build_transform_node(self, node, label, node_source=None): + """ + Builds the "inspection" node that wraps the original source node - passing the (value, context) pair to `transform_fn`. + """ + if node_source is None: + node_source = self.safe_unparse(node) + + line_offset = self.fn_context.impl_fn.__code__.co_firstlineno - 2 + col_offset = 4 + context_node = ast.Call( + func=ast.Name(id=NodeContext.__name__, ctx=ast.Load()), + args=[ + ast.Constant(label), + ast.Constant(node_source), + # Name is injected into the exec scope by `recompile_fn_with_transform` + ast.Name(id="_MAXRAY_FN_CONTEXT", ctx=ast.Load()), + ast.Constant( + ( + line_offset + node.lineno, + line_offset + node.end_lineno, + node.col_offset + col_offset, + node.end_col_offset + col_offset, + ) + ), + ], + keywords=[], + ) + + return ast.Call( + func=ast.Name(id=self.transform_fn.__name__, ctx=ast.Load()), + args=[node, context_node], + keywords=[], + ) + + def visit_Name(self, node): + match node.ctx: + case ast.Load(): + # Variable is accessed + new_node = self.generic_visit(node) + case ast.Store(): + # Variable is assigned to + return node + case _: + logger.error(f"Unknown context {node.ctx}") + return node + + return self.build_transform_node(new_node, f"name/{node.id}") + + def visit_Attribute(self, node: ast.Attribute) -> Any: + """ + https://docs.python.org/3/reference/expressions.html#atom-identifiers + > Private name mangling: When an identifier that textually occurs in a class definition begins with two or more underscore characters and does not end in two or more underscores, it is considered a private name of that class. Private names are transformed to a longer form before code is generated for them. The transformation inserts the class name, with leading underscores removed and a single underscore inserted, in front of the name. For example, the identifier __spam occurring in a class named Ham will be transformed to _Ham__spam. This transformation is independent of the syntactical context in which the identifier is used. If the transformed name is extremely long (longer than 255 characters), implementation defined truncation may happen. If the class name consists only of underscores, no transformation is done. + """ + source_pre = self.safe_unparse(node) + + if self.is_method() and self.is_private_class_name(node.attr): + node.attr = f"_{self.instance_type}{node.attr}" + logger.warning("Replaced with mangled private name") + + if isinstance(node.ctx, ast.Load): + node = self.generic_visit(node) + node = self.build_transform_node( + node, f"attr/{node.attr}", node_source=source_pre + ) + return node + + def visit_Assign(self, node: ast.Assign) -> Any: + new_node = self.generic_visit(node) + assert isinstance(new_node, ast.Assign) + # node = new_node + # node.value = self.build_transform_node(new_node, f"assign/(multiple)") + return node + + def visit_Call(self, node): + node_pre = deepcopy(node) + source_pre = self.safe_unparse(node_pre) + + node = self.generic_visit(node) # mutates + + # the function/callable instance itself is observed by Name/Attribute/... nodes + + target = node.func + match target: + case ast.Name(): + logger.debug(f"Visiting call to function {target.id}") + case ast.Attribute(): + logger.debug(f"Visiting call to attribute {target.attr}") + + return ast.copy_location( + self.build_transform_node( + node, f"call/{source_pre}", node_source=source_pre + ), + node_pre, + ) + + id_key = f"{self.context_fn.__name__}/call/{source_pre}" + # Observes the *output* of the function + call_observer = ast.Call( + func=ast.Name(id=self.transform_fn.__name__, ctx=ast.Load()), + args=[node, ast.Constant({"id": id_key})], + keywords=[], + ) + return call_observer + + def visit_FunctionDef(self, node: ast.FunctionDef): + pre_node = deepcopy(node) + + self.fn_count += 1 + # Only overwrite the name of our "target function" + if self.fn_count == 1 and self.is_method(): + node.name = f"{node.name}_{_METHOD_MANGLE_NAME}_{node.name}" + + # TODO: we should replace decorators with a dynamic check for not belonging to our own `maxray` module + BANNED_DECORATIONS = {"maxray", "xray", "transform"} + node.decorator_list = [ + decor + for decor in node.decorator_list + if not ( + isinstance(decor, ast.Call) + and isinstance(decor.func, ast.Name) + and decor.func.id in BANNED_DECORATIONS + ) + ] + + # Removes type annotations from the call for safety as they're evaluated at definition-time rather than call-time + # This may not be needed now that locals are (usually) captured properly + for arg in node.args.args: + arg.annotation = None + + out = ast.copy_location(self.generic_visit(node), pre_node) + return out + + +_TRANSFORM_CACHE = {} + + +def get_fn_name(fn): + """ + Get a printable representation of the function for human-readable errors + """ + if hasattr(fn, "__name__"): + name = fn.__name__ + else: + try: + name = repr(fn) + except Exception: + name = "" + + return f"{name} @ {id(fn)}" + + +def recompile_fn_with_transform( + source_fn, transform_fn, ast_pre_callback=None, ast_post_callback=None +) -> Result[Callable, str]: + """ + Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code. + """ + try: + source = inspect.getsource(source_fn) + + # nested functions have excess indentation preventing compile; inspect.cleandoc(source) is an alternative + source = dedent(source) + + sourcefile = inspect.getsourcefile(source_fn) + module = inspect.getmodule(source_fn) + + # the way numpy implements its array hooks means it does its own voodoo code generation resulting in functions that have source code, but no corresponding source file + # e.g. the source file of `np.unique` is <__array_function__ internals> + if sourcefile is None or not Path(sourcefile).exists(): + return Err( + f"Non-existent source file ({sourcefile}) for function {get_fn_name(source_fn)}" + ) + + fn_ast = ast.parse(source) + except OSError: + return Err(f"No source code for function {get_fn_name(source_fn)}") + except TypeError: + return Err( + f"No source code for probable built-in function {get_fn_name(source_fn)}" + ) + + # TODO: use non-overridable __getattribute__ instead? + if not hasattr(source_fn, "__name__"): # Safety check against weird functions + return Err(f"There is no __name__ for function {get_fn_name(source_fn)}") + + if "super()" in source: + # TODO: we could replace calls to super() with super(__class__, self)? + return Err( + f"Function {get_fn_name(source_fn)} cannot be transformed because it calls super()" + ) + + match fn_ast: + case ast.Module(body=[ast.FunctionDef()]): + # Good + pass + case _: + return Err( + f"The targeted function {get_fn_name(source_fn)} does not correspond to a single `def` block so cannot be transformed safely!" + ) + + if ast_pre_callback is not None: + ast_pre_callback(fn_ast) + + fn_is_method = inspect.ismethod(source_fn) + if fn_is_method: + # Many potential unintended side-effects + match source_fn.__self__: + case type(): + # Descriptor + parent_cls = source_fn.__self__ + case _: + # Bound method + parent_cls = type(source_fn.__self__) + + fn_context = FnContext( + source_fn, source_fn.__name__, module.__name__, source, sourcefile + ) + transformed_fn_ast = FnRewriter( + transform_fn, + fn_context, + instance_type=parent_cls.__name__ if fn_is_method else None, + ).visit(fn_ast) + ast.fix_missing_locations(transformed_fn_ast) + + if ast_post_callback is not None: + ast_post_callback(transformed_fn_ast) + + scope = { + transform_fn.__name__: transform_fn, + NodeContext.__name__: NodeContext, + "_MAXRAY_FN_CONTEXT": fn_context, + } + + # Add class-private names to scope (though only should be usable as a default argument) + # TODO: should apply to all definitions within a class scope - so @staticmethod descriptors as well... + if fn_is_method: + scope.update( + { + name: val + for name, val in parent_cls.__dict__.items() + # TODO: BUG: ah... this excludes torch modules, right? + if not callable(val) + } + ) + + def extract_cell(cell): + try: + return cell.cell_contents + except ValueError: + # Cell is empty + logger.warning( + f"No contents for closure cell in function {get_fn_name(source_fn)} - this can happen with recursion" + ) + return None + + scope.update(vars(module)) + + if hasattr(source_fn, "__closure__") and source_fn.__closure__ is not None: + scope.update( + { + name: extract_cell(cell) + for name, cell in zip( + source_fn.__code__.co_freevars, source_fn.__closure__ + ) + } + ) + + if not fn_is_method and source_fn.__name__ in scope: + logger.warning( + f"Name {source_fn.__name__} already exists in scope for non-method" + ) + + try: + exec( + compile( + transformed_fn_ast, filename=f"<{source_fn.__name__}>", mode="exec" + ), + scope, + scope, + ) + except Exception as e: + logger.exception(e) + logger.error( + f"Failed to compile function {source_fn.__name__} in its module {module}" + ) + + # FALLBACK: in numpy.core.numeric, they define `@set_module` that rewrites __module__ so inspect gives us the wrong module to correctly re-execute the def in + # sourcefile is still correct so let's try use `sys.modules` + + file_to_modules = { + getattr(mod, "__file__", None): mod for mod in sys.modules.values() + } + if sourcefile in file_to_modules: + scope.update(vars(file_to_modules[sourcefile])) + try: + exec( + compile( + transformed_fn_ast, + filename=f"<{source_fn.__name__}>", + mode="exec", + ), + scope, + scope, + # closure=fn.__closure__, + ) + except Exception as e: + logger.exception(e) + return Err( + f"Re-def of function {get_fn_name(source_fn)} in its source file module at {sourcefile} also failed" + ) + else: + return Err( + f"Failed to re-def function {get_fn_name(source_fn)} and its source file {sourcefile} was not found in sys.modules" + ) + + if fn_is_method: + transformed_fn = scope[ + f"{source_fn.__name__}_{_METHOD_MANGLE_NAME}_{source_fn.__name__}" + ] + else: + transformed_fn = scope[source_fn.__name__] + + # unmangle the name again - it's possible some packages might use __name__ internally for registries and whatnot + transformed_fn.__name__ = source_fn.__name__ + + # way to keep track of which functions we've already transformed + transformed_fn._MAXRAY_TRANSFORMED = True + + return Ok(transformed_fn) diff --git a/maxray/walkers.py b/maxray/walkers.py new file mode 100644 index 0000000..aaa6ced --- /dev/null +++ b/maxray/walkers.py @@ -0,0 +1,11 @@ +from loguru import logger + + +def dbg(x, ctx): + try: + x_repr = repr(x) + except Exception: + x_repr = "" + + logger.debug(f"{ctx} :: {x_repr}") + return x diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f0847ef --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[tool.poetry] +name = "maxray" +version = "0.1.0" +description = "" +authors = ["blepabyte <255@blepabyte.me>"] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.11" +result = "^0.16.1" +loguru = "^0.7.2" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 0000000..25e76d2 --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,322 @@ +from maxray import transform, xray, maxray +from maxray.walkers import dbg + +from contextlib import contextmanager +from dataclasses import dataclass +import functools + + +def test_basic(): + @transform(lambda x, ctx: x + 1 if isinstance(x, int) else x) + def f(x): + return x + + assert f(3) == 4 + + +def test_type_hints(): + from typing import Any + + @transform(lambda x, ctx: x + 1 if isinstance(x, int) else x) + def f(x: Any): + return x + + assert f(3) == 4 + + +def test_closure_capture(): + z = 1 + + @transform(lambda x, ctx: x + 1 if isinstance(x, int) else x) + def f(x): + return x + z + + assert f(3) == 6 + + +def test_closure_capture_mutate(): + z = [] + + @transform(lambda x, ctx: x + 1 if isinstance(x, int) else x) + def f(x): + z.append(x) + + f(1) + f(2) + assert z == [2, 3] + + +GLOB_CONST = 5 + + +def test_global_capture(): + @transform(lambda x, ctx: x + 1 if isinstance(x, int) else x) + def g(x): + return x + GLOB_CONST + + assert g(3) == 10 + + +def test_nested_def(): + def outer(): + z = [] + + @transform(lambda x, ctx: x + 1 if isinstance(x, int) else x) + def g(x): + z.append(x) + return x + + return g + + assert outer()(3) == 4 + assert outer()(3) == 4 + + +def test_recursive(): + @transform(dbg) + def countdown(n): + if n > 0: + return 1 + countdown(n - 1) + else: + return 0 + + assert countdown(5) == 5 + + +def test_fib(): + @transform(dbg) + def fib(n: int): + a, b = 1, 1 + for i in range(n): + a, b = b, a + b + return b + + fib(10) + + +def test_decorated(): + def outer(z): + def wrapper(f): + @functools.wraps(f) + def inner(x): + print(z) + return f(x) + 1 + + return inner + + return wrapper + + def middle(x): + @outer("nope") + def f(x): + return x + + return f(x) + + @xray(dbg) + def f(x): + return middle(x) + + assert f(3) == 4 + + +def test_inspect_signature(): + from inspect import signature + + sss = xray(dbg)(signature) + print(sss(lambda x: x)) + + +@contextmanager +def oop(): + try: + yield "foo" + finally: + pass + + +def test_contextmanager(): + @xray(dbg) + def f(): + with oop() as x: + pass + return x + + assert f() == "foo" + + +def test_property_access(): + @dataclass + class A: + x: int + y: float + + obj = A(1, 2.3) + + @maxray(lambda x, ctx: x + 1 if isinstance(x, int) else x) + def g(): + return obj.x + + assert g() == 2 + + +def test_method(): + class A: + def foo(self): + x = 1 + return str(x) + + @maxray(lambda x, ctx: (print(x), x + 1)[-1] if isinstance(x, int) else x) + def g(): + a = A() + return a.foo() + + assert g() == "2" + + +def test_recursive_self_repr(): + """ + RecursionError: maximum recursion depth exceeded while getting the str of an object + + The problem is that as part of `dbg`, uses of `self` trigger *another* `dbg` call and so on... + """ + + class X: + def bar(self): + pass + + def __repr__(self): + # The problem was that self is callable + self + return "X" + + @xray(dbg) + def inner(): + x = X() + print(X.__repr__(x)) + print(f"{X()}") + return 1 + + assert inner() == 1 + + +def test_class_method_default_arg_using_local(): + """ + Ran into this with `_PosixFlavour` in `pathlib.py` (`splitroot`, line 8) + """ + + class Flavour: + sep = " " + + def groot(self, sep=sep): + return sep + + @xray(dbg) + def buggy(): + f = Flavour() + z = f.groot() + a = 1 + 2 + + buggy() + + +def test_method_scope_overwrite(): + """ + So many edge cases... + + So the difference between method definitions and standard function defs is that for fn defs, the name of the function itself is made available in the calling scope (e.g. for recursion - seemingly stored as a closure within nested defs). + + For methods, we have to mangle the name so that in the exec-ed function, we can correctly reference any coexisting globally def-ed function of the same name. + """ + + def OoO(): + return 1 + + class X: + def OoO(self): + return OoO() + 1 + + @xray(dbg) + def bad(): + x = X() + return x.OoO() + + assert bad() == 2 + assert OoO() == 1 + + +def test_private_name_mangled(): + class X: + def __init__(self): + self.__next() + + def __next(self): + print("hi") + + def g(self): + self.__next() + + @xray(dbg) + def bad(): + x = X() + x.g() + + bad() + + +def test_super(): + class A: + def to(self): + return 1 + + class B(A): + def to(self): + return super().to() + 1 + + @maxray(dbg) + def oop(): + b = B() + return b.to() + + assert oop() == 2 + + +def test_closure_reuse(): + x = [] + + @maxray(dbg) + def foo(): + # nonlocal x + x.append(1) + + @maxray(dbg) + def bar(): + x.append(2) + + foo() + bar() + assert len(x) == 2 + + +def test_double_decorators_with_locals(): + x = [] + + @maxray(dbg) + @maxray(dbg) + def foo(): + # nonlocal x + x.append(1) + + foo() + + +def test_xray_immutable(): + @maxray(lambda x, ctx: x * 10 if isinstance(x, float) else x) + @xray(lambda x, ctx: x + 1 if isinstance(x, int) else x) + def foo(): + # Currently assumes that literals/constants are not wrapped (they're uninteresting anyways) + x = 1 + y = 2.0 + return x, y + + assert foo() == (1, 20.0)