Skip to content

Commit

Permalink
fix(transform): order scope layers to prevent namespace pollution
Browse files Browse the repository at this point in the history
Fixes bug where np.load errors with "numpy.AxisError: axis 6 is out of bounds for array of dimension 0" because `min` was dispatching to `np.min` due to prior `@set_module` decorator.
  • Loading branch information
blepabyte committed Jun 17, 2024
1 parent 911fadc commit 57ab881
Showing 1 changed file with 43 additions and 19 deletions.
62 changes: 43 additions & 19 deletions maxray/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def recompile_fn_with_transform(
return Err("Cannot safely recompile lambda functions")

# handle `functools.wraps`
# TODO: multiple layers of nesting?
if hasattr(source_fn, "__wrapped__"):
# SOUNDNESS: failure when decorators aren't applied at the definition site (will look for the original definition, ignoring any transformations that have been applied before the wrap but after definition)
source_fn = source_fn.__wrapped__
Expand Down Expand Up @@ -579,27 +580,32 @@ def patch_mro(super_type: super):

return super_type

scope = {
transform_fn.__name__: transform_fn,
NodeContext.__name__: NodeContext,
"_MAXRAY_FN_CONTEXT": fn_context,
"_MAXRAY_CALL_COUNTER": fn_call_counter,
"_MAXRAY_DECORATE_WITH_COUNTER": count_calls_with,
"_MAXRAY_BUILTINS_LOCALS": locals,
"_MAXRAY_PATCH_MRO": patch_mro,
scope_layers = {
"core": {
transform_fn.__name__: transform_fn,
NodeContext.__name__: NodeContext,
"_MAXRAY_FN_CONTEXT": fn_context,
"_MAXRAY_CALL_COUNTER": fn_call_counter,
"_MAXRAY_DECORATE_WITH_COUNTER": count_calls_with,
"_MAXRAY_BUILTINS_LOCALS": locals,
"_MAXRAY_PATCH_MRO": patch_mro,
},
"override": initial_scope,
"class_local": {},
"module": {},
"closure": {},
}
scope.update(initial_scope)

# BUG: this will NOT work with threading - could use ContextVar if no performance impact?
def set_temp(val, name: str):
scope[name] = val
return val

scope["_MAXRAY_SET_TEMP"] = set_temp
scope_layers["core"]["_MAXRAY_SET_TEMP"] = set_temp
# Add class-private names to scope (though only should be usable as a default argument)
# TODO: should apply to all definitions within a class scope - so @staticmethod descriptors as well...
if fn_is_method:
scope.update(
scope_layers["class_local"].update(
{
name: val
for name, val in parent_cls.__dict__.items()
Expand All @@ -618,10 +624,10 @@ def extract_cell(cell):
)
return None

scope.update(vars(module))
scope_layers["module"].update(vars(module))

if hasattr(source_fn, "__closure__") and source_fn.__closure__ is not None:
scope.update(
scope_layers["closure"].update(
{
name: extract_cell(cell)
for name, cell in zip(
Expand All @@ -630,12 +636,21 @@ def extract_cell(cell):
}
)

if not fn_is_method and source_fn.__name__ in scope:
logger.warning(
f"Name {source_fn.__name__} already exists in scope for non-method"
)

try:
# TODO: this might be slow
scope = {
**scope_layers["core"],
**scope_layers["class_local"],
**scope_layers["module"],
**scope_layers["closure"],
**scope_layers["override"],
}

if not fn_is_method and source_fn.__name__ in scope:
logger.warning(
f"Name {source_fn.__name__} already exists in scope for non-method"
)

exec(
compile(
transformed_fn_ast, filename=f"<{source_fn.__name__}>", mode="exec"
Expand All @@ -661,7 +676,15 @@ def extract_cell(cell):
getattr(mod, "__file__", None): mod for mod in sys.modules.values()
}
if sourcefile in file_to_modules:
scope.update(vars(file_to_modules[sourcefile]))
# Re-executing in a different module: re-declare scope without the previous module (otherwise we get incorrect behaviour like `min` being replaced with `np.amin` in `np.load`)
scope = {
**scope_layers["core"],
**scope_layers["class_local"],
**vars(file_to_modules[sourcefile]),
**scope_layers["closure"],
**scope_layers["override"],
}

try:
exec(
compile(
Expand Down Expand Up @@ -705,6 +728,7 @@ def extract_cell(cell):
return Ok(transformed_fn)


# TODO: probably better to modify the generated code directly instead of relying on a wrapper...
def count_calls_with(counter: ContextVar):
def inner(fn):
# TODO: synchronisation/context?
Expand Down

0 comments on commit 57ab881

Please sign in to comment.