Skip to content

Commit

Permalink
fix: forbid metaclasses in transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
blepabyte committed Sep 1, 2024
1 parent 869db4c commit c3cafaf
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 45 deletions.
92 changes: 59 additions & 33 deletions maxray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps, _lru_cache_wrapper
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, TypeVar
from result import Result, Ok, Err

import os
Expand Down Expand Up @@ -58,29 +58,32 @@ def xray(walker, **kwargs):


_GLOBAL_SKIP_MODULES = {
# Builtins/used internally
"builtins", # duh... object.__init__ is a wrapper_descriptor
# Don't patch ourself!
"maxray",
# Nor core parts of the language or modules we use internally
"builtins",
"ctypes",
"importlib",
"inspect",
"typing",
"ast",
"pathlib",
"uuid", # used for generating _MAXRAY_TRANSFORM_ID
# Uninteresting standard library modules
# TODO: probably just skip most of the entire Python standard library...
"abc", # excessive inheritance and super calls in scikit-learn
"inspect", # don't want to screw this module up
"pathlib", # internally used in transform for checking source file exists
"re", # internals of regexp have a lot of uninteresting step methods
"copy", # pytorch spends so much time in here
"functools", # partialmethod causes mis-bound `self` in TQDM
"uuid", # used for generating _MAXRAY_TRANSFORM_ID
# "asyncio",
"collections",
# Libraries that are too weird to correctly patch
"pytest",
# Libraries possibly not fully working yet
"logging", # global _loggerClass probably causes problems
"typing",
"ast",
"importlib",
"ctypes",
# Libraries not fully working yet
"functools", # partialmethod causes mis-bound `self` in TQDM
"loguru", # used internally in transform - accidental patching will cause inf recursion
"urllib3",
"aiohttp", # something icky about web.RouteTableDef() when route decorated from submodule
"maxray",
"git",
}
# TODO: probably just skip the entire Python standard library...


@dataclass
Expand Down Expand Up @@ -120,6 +123,13 @@ def descend_allowed(x, ctx: NodeContext):
method_wrapper = type(type.__call__.__get__(type))


def base_module(obj):
    """Return the top-level package name of the module ``obj`` was defined in.

    Reads ``obj.__module__`` and keeps only the text before the first dot
    (``"pkg.sub.mod"`` -> ``"pkg"``), so the result can be compared against
    top-level module names.

    Returns an empty string when ``__module__`` is missing, ``None``, or not a
    string. Arbitrary objects may expose a non-string ``__module__`` (for
    example as an instance attribute), and this helper must never raise while
    deciding whether something should be skipped.
    """
    qual_module = getattr(obj, "__module__", None)
    if not isinstance(qual_module, str):
        return ""
    # partition avoids building a list of every dotted component
    return qual_module.partition(".")[0]


def transform_precheck(x, ctx: NodeContext):
if not callable(x):
return False
Expand All @@ -145,9 +155,8 @@ def transform_precheck(x, ctx: NodeContext):
pass

# __module__: The name of the module the function was defined in, or None if unavailable.
if getattr(x, "__module__", None) in _GLOBAL_SKIP_MODULES:
if base_module(x) in _GLOBAL_SKIP_MODULES:
return False
# TODO: also check ctx.fn_context.module?

return True

Expand All @@ -156,17 +165,10 @@ def callable_allowed_for_transform(x, ctx: NodeContext):
if not transform_precheck(x, ctx):
return False

module_path = ctx.fn_context.module.split(".")
if module_path[0] in _GLOBAL_SKIP_MODULES:
return False
# TODO: deal with nonhashable objects and callables and other exotic types properly
return (
callable(x) # super() has getset_descriptor instead of proper __dict__
and hasattr(x, "__dict__")
# and "_MAXRAY_TRANSFORMED" not in x.__dict__
# since we no longer rely on caching, we don't need to check for __hash__
# and callable(getattr(x, "__hash__", None))
and getattr(type(x), "__module__", None) not in {"ctypes"}
and base_module(type(x)) not in {"ctypes"}
and (
inspect.isfunction(x)
or inspect.ismethod(x)
Expand All @@ -179,6 +181,8 @@ def instance_allowed_for_transform(x, ctx: NodeContext):
"""
Checks if x is a type whose dunder methods can be correctly transformed.
"""
# Forbid metaclasses as they can arbitrarily modify and replace functions
# related SOUNDNESS BUG: function wrapper not applied via decorator - track by patching functools.wraps or tracing?
if type(x) is not type:
# Filter out weird stuff that would get through with an isinstance check
# e.g. loguru: _GeneratorContextManager object is not an iterator (might be a separate bug)
Expand All @@ -190,6 +194,7 @@ def instance_allowed_for_transform(x, ctx: NodeContext):
return True


# TODO: this doesn't really make sense: only a single set of traversal settings can be sanely "applied" at a time
def _inator_inator(restrict_modules: Optional[list[str]] = None):
"""
Control configuration to determine which, and how, callables get recursively transformed.
Expand Down Expand Up @@ -261,7 +266,15 @@ def _maxray_walker_handler(x, ctx: NodeContext, autotransform=True):
with_fn = FunctionStore.get(x_trans._MAXRAY_TRANSFORM_ID)

# This does not apply when accessing X.method - only X().method
if inspect.ismethod(x):
if (
inspect.ismethod(x)
and with_fn.data.method_info.is_inspect_method is not True
):
logger.warning(
"Inconsistent method status - probable result of wrapping or metaclass shenanigans"
)
x_patched = x
elif inspect.ismethod(x):
# if with_fn.method is not None:
# Two cases: descriptor vs bound method
match x.__self__:
Expand Down Expand Up @@ -344,6 +357,9 @@ def _maxray_walker_handler(x, ctx: NodeContext, autotransform=True):
return _maxray_walker_handler


# Type variable for the decorated callable; used in the `Callable[[T], T]`
# return annotation of `maxray` so the annotation maps a callable type to itself.
T = TypeVar("T", bound=Callable)


def maxray(
writer: Callable[[Any, NodeContext], Any],
skip_modules=frozenset(),
Expand All @@ -353,7 +369,7 @@ def maxray(
pass_scope=False,
initial_scope={},
assume_transformed=False,
):
) -> Callable[[T], T]:
"""
A transform that recursively hooks into all further calls made within the function, so that `writer` will (in theory) observe every single expression evaluated by the Python interpreter occurring as part of the decorated function call.
Expand Down Expand Up @@ -391,7 +407,7 @@ def recursive_transform(fn):
)

# Fixes `test_double_decorators_with_locals`: repeated transforms are broken because stuffing closures into locals doesn't work the second time around
if hasattr(fn, "_MAXRAY_TRANSFORMED") or assume_transformed:
if hasattr(fn, "_MAXRAY_TRANSFORM_ID") or assume_transformed:
fn_transform = fn
else:
match recompile_fn_with_transform(
Expand Down Expand Up @@ -427,9 +443,19 @@ def fn_with_context_update(*args, **kwargs):
finally:
ACTIVE_FLAG.reset(prev_token)

fn_with_context_update._MAXRAY_TRANSFORMED = True
# TODO: set correctly everywhere
# fn_with_context_update._MAXRAY_TRANSFORM_ID = ...
fn_with_context_update._MAXRAY_TRANSFORM_ID = fn_transform._MAXRAY_TRANSFORM_ID

# If we're given a bound method we need to return a bound method on the same instance
# Can only happen via xray(...)(some_method), not when applied via decorator
if inspect.ismethod(fn):
parent_cls = fn.__self__
if not isinstance(parent_cls, type):
parent_cls = type(parent_cls)

fn_with_context_update = fn_with_context_update.__get__(
fn.__self__, parent_cls
)

return fn_with_context_update

return recursive_transform
return recursive_transform # type: ignore
6 changes: 6 additions & 0 deletions maxray/function_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def validate(
reason = "Unable to find source code of function"
elif (not fd.source_file) or (not Path(fd.source_file).exists()):
reason = "Unable to find source file of function"
elif fd.method_info.is_inspect_method and (
type(fd.method_info.instance_cls) is not type
or type(fd.method_info.defined_on_cls) is not type
):
reason = "Method is on a metaclass"
elif hasattr(fn, "_MAXRAY_TRANSFORM_ID"):
reason = (
f"Function has already been transformed ({fn._MAXRAY_TRANSFORM_ID})"
Expand Down Expand Up @@ -281,6 +286,7 @@ def push(fd: CompiledFunction | ErroredFunction):
def collect():
# TODO: Disable locking by default? (design a better context interface?)
with FunctionStore.lock:
# TODO: use a schema to handle case of zero rows
return pa.table(
pa.array(
[
Expand Down
4 changes: 2 additions & 2 deletions maxray/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class NodeContext:
end_col is the exclusive endpoint of the range
"""

local_scope: Any = None
local_scope: Optional[dict] = None
"When `pass_scope` is True, contains the output of `builtins.locals()` evaluated in the scope of the source expression"

caller_id: Any = None

Expand Down Expand Up @@ -859,7 +860,6 @@ def extract_cell(cell):
transformed_fn.__qualname__ = source_fn.__qualname__

# way to keep track of which functions we've already transformed
transformed_fn._MAXRAY_TRANSFORMED = True
transformed_fn._MAXRAY_TRANSFORM_ID = with_source_fn.compile_id
with_source_fn.mark_compiled(transformed_fn)

Expand Down
42 changes: 32 additions & 10 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ def increment_ints_by_one(x, ctx):
return x


def convert_ints_to_str(x, ctx):
match x:
case int():
return str(x)
case _:
return x


def test_basic():
@transform(increment_ints_by_one)
def f(x):
Expand All @@ -41,11 +49,11 @@ def f(x: Any):
def test_closure_capture():
z = 1

@transform(increment_ints_by_one)
@transform(convert_ints_to_str)
def f(x):
return x + z

assert f(3) == 7
assert f(3) == "31"


def test_closure_capture_mutate():
Expand All @@ -64,11 +72,11 @@ def f(x):


def test_global_capture():
@transform(increment_ints_by_one)
@transform(convert_ints_to_str)
def g(x):
return x + GLOB_CONST

assert g(3) == 11
assert g(3) == "35"


def test_nested_def():
Expand Down Expand Up @@ -462,14 +470,14 @@ def uh():


def test_junk_annotations():
@maxray(increment_ints_by_one)
@maxray(convert_ints_to_str)
def outer():
def inner(x: ASDF = 0, *, y: SDFSDF = 100) -> AAAAAAAAAAA:
return x + y

return inner(2)

assert outer() == 107
assert outer() == "2100"


def test_call_counts():
Expand Down Expand Up @@ -519,13 +527,11 @@ def empty_returns():


def test_scope_passed():
found_scope = None
found_scope = {}

def get_scope(x, ctx):
nonlocal found_scope
if ctx.local_scope is not None:
assert found_scope is None
found_scope = ctx.local_scope
found_scope.update(ctx.local_scope)
return x

@xray(get_scope, pass_scope=True)
Expand Down Expand Up @@ -722,3 +728,19 @@ def check_isna():
return Framed.isna(f)

assert check_isna() == 2


def test_qualified_init():
    """A base-class ``__init__`` invoked through the class (``A.__init__(self)``)
    must still initialise the instance when the caller is wrapped with ``xray``."""

    class A:
        def __init__(self):
            self.a_prop = 101

    class B(A):
        def __init__(self):
            # Explicit, class-qualified call rather than super().__init__()
            A.__init__(self)

    @xray(dbg)
    def get_prop():
        return B().a_prop

    # The attribute set by A.__init__ must be visible on the B instance
    assert get_prop() == 101

0 comments on commit c3cafaf

Please sign in to comment.