Skip to content

Commit

Permalink
fix: transform control flow and relative package import
Browse files — browse the repository at this point in the history
  • Loading branch information
blepabyte committed Aug 15, 2024
1 parent b70b3ee commit aaee9f0
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 147 deletions.
289 changes: 159 additions & 130 deletions maxray/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,7 +7,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps, _lru_cache_wrapper
from typing import Any, Callable
from typing import Any, Callable, Optional
from result import Result, Ok, Err

import os
Expand Down Expand Up @@ -114,8 +114,6 @@ def descend_allowed(x, ctx: NodeContext):
# We don't want to recompile the same function over and over - so our cache needs to be global
# TODO: add cache by source code location?
_MAXRAY_FN_CACHE = dict()
_MAXRAY_FN_FAILED_CACHE = WeakSet()


wrapper_descriptor = type(object.__init__)
method_wrapper = type(type.__call__.__get__(type))
Expand All @@ -135,7 +133,12 @@ def transform_precheck(x, ctx: NodeContext):
try:
# Avoid calling getattr since things like DataFrames override it (and cause recursion errors)
x.__getattribute__("_MAXRAY_NOTRANSFORM")
# object.__getattribute__(x, "_MAXRAY_NOTRANSFORM")
return False
except AttributeError:
pass

try:
x.__getattribute__("_MAXRAY_TRANSFORM_ID")
return False
except AttributeError:
pass
Expand Down Expand Up @@ -189,11 +192,7 @@ def instance_init_allowed_for_transform(x, ctx: NodeContext):
except AttributeError:
return False

return (
getattr(x, "__module__", None) not in {"ctypes"}
# and hasattr(x, "__init__")
# and not hasattr(x, "_MAXRAY_TRANSFORMED")
)
return getattr(x, "__module__", None) not in {"ctypes"}


def instance_call_allowed_for_transform(x, ctx: NodeContext):
Expand All @@ -213,144 +212,174 @@ def instance_call_allowed_for_transform(x, ctx: NodeContext):
except AttributeError:
return False

return (
getattr(x, "__module__", None) not in {"ctypes"}
# and hasattr(x, "__call__")
# and not hasattr(x, "_MAXRAY_TRANSFORMED")
)
return getattr(x, "__module__", None) not in {"ctypes"}


def _maxray_walker_handler(x, ctx: NodeContext):
# 1. logic to recursively patch callables
# 1a. special-case callables: __init__ and __call__
if instance_init_allowed_for_transform(x, ctx):
match recompile_fn_with_transform(
x.__init__,
_maxray_walker_handler,
special_use_instance_type=x,
triggered_by_node=ctx,
):
case Ok(init_patch):
logger.debug(f"Patching __init__ for class {x}")
setattr(x, "__init__", init_patch)
case Err(_err):
set_property_on_functionlike(x.__init__, "_MAXRAY_NOTRANSFORM", True)

elif instance_call_allowed_for_transform(x, ctx):
match recompile_fn_with_transform(
x.__call__,
_maxray_walker_handler,
special_use_instance_type=x,
triggered_by_node=ctx,
):
case Ok(call_patch):
logger.debug(f"Patching __call__ for class {x}")
setattr(x, "__call__", call_patch)
case Err(_err):
set_property_on_functionlike(x.__call__, "_MAXRAY_NOTRANSFORM", True)

# 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod
elif callable_allowed_for_transform(x, ctx):
# We can only cache functions - as caching invokes __hash__, which may fail badly on incompletely-initialised class instances w/ __call__ methods, like torch._ops.OpOverload
if x in _MAXRAY_FN_FAILED_CACHE:
def _inator_inator(restrict_modules: Optional[list[str]] = None):
"""
Control configuration to determine which, and how, callables get recursively transformed.
"""

def callable_precheck(x, ctx: NodeContext):
"""
Is x a "normal" pure-Python function that we can safely substitute with its transformed version?
"""
raise NotImplementedError()

def instance_precheck(x, ctx: NodeContext):
raise NotImplementedError()

def _maxray_walker_handler(x, ctx: NodeContext, autotransform=True):
# 1. logic to recursively patch callables
# 1a. special-case callables: __init__ and __call__
if not autotransform:
# Disables recursively attaching handler to all nested calls
pass
if x in _MAXRAY_FN_CACHE:
x = _MAXRAY_FN_CACHE[x]
elif not descend_allowed(x, ctx):
# user-defined filters for which nodes (not) to descend into

x_def_module = getattr(x, "__module__", "")
if x_def_module is None:
x_def_module = ""
if restrict_modules is not None and not any(
x_def_module.startswith(mod) for mod in restrict_modules
):
pass
else:
elif instance_init_allowed_for_transform(x, ctx):
match recompile_fn_with_transform(
x, _maxray_walker_handler, triggered_by_node=ctx
x.__init__,
_maxray_walker_handler,
special_use_instance_type=x,
triggered_by_node=ctx,
):
case Ok(x_trans):
# NOTE: x_trans now has _MAXRAY_TRANSFORMED field to True
with_fn = FunctionStore.get(x_trans._MAXRAY_TRANSFORM_ID)

# This does not apply when accessing X.method - only X().method
if inspect.ismethod(x):
# if with_fn.method is not None:
# Two cases: descriptor vs bound method
match x.__self__:
case type():
# Descriptor
logger.debug(
f"monkey-patching descriptor method {x.__name__} on type {x.__self__}"
case Ok(init_patch):
logger.debug(f"Patching __init__ for class {x}")
setattr(x, "__init__", init_patch)
case Err(_err):
set_property_on_functionlike(
x.__init__, "_MAXRAY_NOTRANSFORM", True
)

elif instance_call_allowed_for_transform(x, ctx):
match recompile_fn_with_transform(
x.__call__,
_maxray_walker_handler,
special_use_instance_type=x,
triggered_by_node=ctx,
):
case Ok(call_patch):
logger.debug(f"Patching __call__ for class {x}")
setattr(x, "__call__", call_patch)
case Err(_err):
set_property_on_functionlike(
x.__call__, "_MAXRAY_NOTRANSFORM", True
)

# 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod
elif callable_allowed_for_transform(x, ctx):
# We can only cache functions - as caching invokes __hash__, which may fail badly on incompletely-initialised class instances w/ __call__ methods, like torch._ops.OpOverload
if x in _MAXRAY_FN_CACHE:
x = _MAXRAY_FN_CACHE[x]
elif not descend_allowed(x, ctx):
# TODO: deprecate
# user-defined filters for which nodes (not) to descend into
pass
else:
# TODO: fixup control flow
match recompile_fn_with_transform(
x, _maxray_walker_handler, triggered_by_node=ctx
):
case Ok(x_trans):
# NOTE: x_trans now has _MAXRAY_TRANSFORMED field to True
with_fn = FunctionStore.get(x_trans._MAXRAY_TRANSFORM_ID)

# This does not apply when accessing X.method - only X().method
if inspect.ismethod(x):
# if with_fn.method is not None:
# Two cases: descriptor vs bound method
match x.__self__:
case type():
# Descriptor
logger.debug(
f"monkey-patching descriptor method {x.__name__} on type {x.__self__}"
)
parent_cls = x.__self__
case _:
# Bound method
logger.debug(
f"monkey-patching bound method {x.__name__} on type {type(x.__self__)}"
)
parent_cls = type(x.__self__)

self_cls = parent_cls
if with_fn.data.method_info.defined_on_cls is not None:
parent_cls = with_fn.data.method_info.defined_on_cls

# Monkey-patching the methods. Probably unsafe and unsound
# Descriptor guide: https://docs.python.org/3/howto/descriptor.html

# Sanity check: check that our patch target is identical to the unbound version of the method (to prevent patching on the wrong class)
supposed_x = getattr(parent_cls, x.__name__, None)
if hasattr(supposed_x, "__func__"):
supposed_x = supposed_x.__func__
if supposed_x is x.__func__ and supposed_x is not None:
setattr(parent_cls, x.__name__, x_trans)
x_patched = x_trans.__get__(x.__self__, self_cls)
else:
# Because any function can be assigned as a member of the class with an arbitrary name...
logger.warning(
"Could not monkey-patch because instance is incorrect"
)
parent_cls = x.__self__
case _:
# Bound method
logger.debug(
f"monkey-patching bound method {x.__name__} on type {type(x.__self__)}"
set_property_on_functionlike(
x, "_MAXRAY_NOTRANSFORM", True
)
parent_cls = type(x.__self__)

self_cls = parent_cls
if with_fn.data.method_info.defined_on_cls is not None:
parent_cls = with_fn.data.method_info.defined_on_cls

# Monkey-patching the methods. Probably unsafe and unsound
# Descriptor guide: https://docs.python.org/3/howto/descriptor.html

# Sanity check: check that our patch target is identical to the unbound version of the method (to prevent patching on the wrong class)
supposed_x = getattr(parent_cls, x.__name__, None)
if hasattr(supposed_x, "__func__"):
supposed_x = supposed_x.__func__
if supposed_x is x.__func__ and supposed_x is not None:
setattr(parent_cls, x.__name__, x_trans)
x_patched = x_trans.__get__(x.__self__, self_cls)
x_patched = x

# We don't bother caching methods as they're monkey-patched
# SOUNDNESS: a package might manually keep references to __init__ around to later call them - but we'd just end up recompiling those as well
else:
# Because any function can be assigned as a member of the class with an arbitrary name...
logger.warning(
"Could not monkey-patch because instance is incorrect"
)
set_property_on_functionlike(x, "_MAXRAY_NOTRANSFORM", True)
x_patched = x

# We don't bother caching methods as they're monkey-patched
# SOUNDNESS: a package might manually keep references to __init__ around to later call them - but we'd just end up recompiling those as well
else:
x_patched = x_trans
_MAXRAY_FN_CACHE[x] = x_patched
x = x_patched

case Err(e):
# Speedup by not trying to recompile (getsource involves filesystem lookup) the same bad function over and over
set_property_on_functionlike(x, "_MAXRAY_NOTRANSFORM", True)
logger.warning(
f"Failed to transform in walker handler: {e} {x.__qualname__}"
)
x_patched = x_trans
_MAXRAY_FN_CACHE[x] = x_patched
x = x_patched

case Err(e):
# Speedup by not trying to recompile (getsource involves filesystem lookup) the same bad function over and over
set_property_on_functionlike(x, "_MAXRAY_NOTRANSFORM", True)
logger.warning(
f"Failed to transform in walker handler: {e} {x.__qualname__}"
)

# We ignore writer calls triggered by code execution in other writers to prevent easily getting stuck in recursive hell
# This happens *after* checking and patching callables to still allow for explicitly patching a callable/method by calling this handler
if _GLOBAL_WRITER_ACTIVE_FLAG.get():
return x

# 2. run the active hooks
global_write_active_token = _GLOBAL_WRITER_ACTIVE_FLAG.set(True)
try:
for walk_hook in _MAXRAY_REGISTERED_HOOKS:
# Our recompiled fn sets and unsets a contextvar whenever it is active
if not walk_hook.active_call_state.get():
continue

# Set the writer active flag
write_active_token = walk_hook.writer_active_call_state.set(True)
if walk_hook.mutable:
x = walk_hook.impl_fn(x, ctx)
else:
walk_hook.impl_fn(x, ctx)
walk_hook.writer_active_call_state.reset(write_active_token)
finally:
_GLOBAL_WRITER_ACTIVE_FLAG.reset(global_write_active_token)

# We ignore writer calls triggered by code execution in other writers to prevent easily getting stuck in recursive hell
# This happens *after* checking and patching callables to still allow for explicitly patching a callable/method by calling this handler
if _GLOBAL_WRITER_ACTIVE_FLAG.get():
return x

# 2. run the active hooks
global_write_active_token = _GLOBAL_WRITER_ACTIVE_FLAG.set(True)
try:
for walk_hook in _MAXRAY_REGISTERED_HOOKS:
# Our recompiled fn sets and unsets a contextvar whenever it is active
if not walk_hook.active_call_state.get():
continue

# Set the writer active flag
write_active_token = walk_hook.writer_active_call_state.set(True)
if walk_hook.mutable:
x = walk_hook.impl_fn(x, ctx)
else:
walk_hook.impl_fn(x, ctx)
walk_hook.writer_active_call_state.reset(write_active_token)
finally:
_GLOBAL_WRITER_ACTIVE_FLAG.reset(global_write_active_token)

return x
return _maxray_walker_handler


def maxray(
writer: Callable[[Any, NodeContext], Any],
skip_modules=frozenset(),
*,
root_inator=_inator_inator(),
mutable=True,
pass_scope=False,
initial_scope={},
Expand Down Expand Up @@ -398,7 +427,7 @@ def recursive_transform(fn):
else:
match recompile_fn_with_transform(
fn,
_maxray_walker_handler,
root_inator,
override_scope=caller_locals,
pass_scope=pass_scope,
is_maxray_root=True,
Expand Down
Loading

0 comments on commit aaee9f0

Please sign in to comment.