Skip to content

Commit

Permalink
fix: re-ordering, and guard against exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
blepabyte committed Jun 23, 2024
1 parent 7e6b9d8 commit c9bc661
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 44 deletions.
13 changes: 9 additions & 4 deletions maxray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
from typing import Any, Callable
from result import Result, Ok, Err

import os

from loguru import logger

# Avoid logspam for users of the library
logger.disable("maxray")

if not os.environ.get("MAXRAY_LOG_LEVEL"):
# Avoid logspam for users of the library
logger.disable("maxray")


def _set_logging(enabled: bool):
Expand Down Expand Up @@ -116,8 +120,9 @@ def callable_allowed_for_transform(x, ctx: NodeContext):
return False
# TODO: deal with nonhashable objects and callables and other exotic types properly
return (
not hasattr(x, "_MAXRAY_TRANSFORMED")
and callable(x)
callable(x) # super() has getset_descriptor instead of proper __dict__
and hasattr(x, "__dict__")
and "_MAXRAY_TRANSFORMED" not in x.__dict__
and callable(getattr(x, "__hash__", None))
and getattr(type(x), "__module__", None) not in {"ctypes"}
and (
Expand Down
82 changes: 44 additions & 38 deletions maxray/capture/reprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections.abc import Mapping
import sys

from loguru import logger


def type_name(x):
try:
Expand All @@ -12,41 +14,45 @@ def type_name(x):


def structured_repr(x):
match x:
case np.ndarray():
size = "x".join(str(dim) for dim in x.shape)
return f"[{size}]"

case list():
if x:
return f"list[{structured_repr(x[0])} x {len(x)}]"

else:
return "list[]"

case tuple():
if len(x) <= 3:
el_reprs = ", ".join(structured_repr(el) for el in x)
else:
el_reprs = ", ".join(structured_repr(el) for el in x[:3]) + ", ..."
return f"({el_reprs})"

case bool():
return str(x)

case Mapping() if (keys := list(x)) and isinstance(keys[0], str):
inner_repr = ", ".join(f"{k}: {structured_repr(x[k])}" for k in keys)
return f"{type_name(x)} {{{inner_repr}}}"

if "torch" in sys.modules:
torch = sys.modules["torch"]
if isinstance(x, torch.Tensor):
size = "x".join(str(dim) for dim in x.shape)
return f"[{size}]"

if "awkward" in sys.modules:
ak = sys.modules["awkward"]
if isinstance(x, ak.Array):
return f"[{str(x.type)}]"

return type_name(x)
try:
match x:
case np.ndarray():
size = "x".join(str(dim) for dim in x.shape)
return f"[{size}]"

case list():
if x:
return f"list[{structured_repr(x[0])} x {len(x)}]"

else:
return "list[]"

case tuple():
if len(x) <= 3:
el_reprs = ", ".join(structured_repr(el) for el in x)
else:
el_reprs = ", ".join(structured_repr(el) for el in x[:3]) + ", ..."
return f"({el_reprs})"

case bool():
return str(x)

case Mapping() if (keys := list(x)) and isinstance(keys[0], str):
inner_repr = ", ".join(f"{k}: {structured_repr(x[k])}" for k in keys)
return f"{type_name(x)} {{{inner_repr}}}"

if "torch" in sys.modules:
torch = sys.modules["torch"]
if isinstance(x, torch.Tensor):
size = "x".join(str(dim) for dim in x.shape)
return f"[{size}]"

if "awkward" in sys.modules:
ak = sys.modules["awkward"]
if isinstance(x, ak.Array):
return f"[{str(x.type)}]"

return type_name(x)
except Exception as e:
logger.exception(e)
return "<error_cannot_repr>"
15 changes: 13 additions & 2 deletions maxray/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import inspect
import sys
import uuid
import re
import builtins

from textwrap import dedent
from pathlib import Path
Expand Down Expand Up @@ -524,9 +526,16 @@ def recompile_fn_with_transform(
return Err(
f"No source code for probable built-in function {get_fn_name(source_fn)}"
)

try:
fn_ast = ast.parse(source)
except SyntaxError:
return Err(f"Syntax error in function {get_fn_name(source_fn)}")

# TODO: be smarter and look for global/nonlocal statements in the parsed repr instead
if "global " in source:
return Err("Cannot safely transform functions containing `global` declarations")

match fn_ast:
case ast.Module(body=[ast.FunctionDef() | ast.AsyncFunctionDef()]):
# Good
Expand Down Expand Up @@ -647,8 +656,9 @@ def extract_cell(cell):
try:
# TODO: this might be slow
scope = {
**scope_layers["core"],
**scope_layers["class_local"],
**vars(builtins),
**scope_layers["core"],
**scope_layers["module"],
**scope_layers["closure"],
**scope_layers["override"],
Expand Down Expand Up @@ -686,8 +696,9 @@ def extract_cell(cell):
if sourcefile in file_to_modules:
# Re-executing in a different module: re-declare scope without the previous module (otherwise we get incorrect behaviour like `min` being replaced with `np.amin` in `np.load`)
scope = {
**scope_layers["core"],
**scope_layers["class_local"],
**vars(builtins),
**scope_layers["core"],
**vars(file_to_modules[sourcefile]),
**scope_layers["closure"],
**scope_layers["override"],
Expand Down

0 comments on commit c9bc661

Please sign in to comment.