diff --git a/maxray/__init__.py b/maxray/__init__.py index 1624151..718c194 100644 --- a/maxray/__init__.py +++ b/maxray/__init__.py @@ -1,6 +1,7 @@ from .transforms import recompile_fn_with_transform, NodeContext import inspect +from weakref import ref, WeakSet from contextvars import ContextVar from contextlib import contextmanager from dataclasses import dataclass @@ -101,7 +102,9 @@ def descend_allowed(x, ctx: NodeContext): _GLOBAL_WRITER_ACTIVE_FLAG = ContextVar("writer_active (global)", default=False) # We don't want to recompile the same function over and over - so our cache needs to be global +# TODO: add cache by source code location? _MAXRAY_FN_CACHE = dict() +_MAXRAY_FN_FAILED_CACHE = WeakSet() def callable_allowed_for_transform(x, ctx: NodeContext): @@ -179,6 +182,8 @@ def _maxray_walker_handler(x, ctx: NodeContext): # 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod elif callable_allowed_for_transform(x, ctx): # We can only cache functions - as caching invokes __hash__, which may fail badly on incompletely-initialised class instances w/ __call__ methods, like torch._ops.OpOverload + if x in _MAXRAY_FN_FAILED_CACHE: + pass if x in _MAXRAY_FN_CACHE: x = _MAXRAY_FN_CACHE[x] elif not descend_allowed(x, ctx): @@ -219,7 +224,7 @@ def _maxray_walker_handler(x, ctx: NodeContext): case Err(e): # Cache failures - _MAXRAY_FN_CACHE[x] = x + _MAXRAY_FN_FAILED_CACHE.add(x) # Errors in functions that have been recursively compiled are less important logger.warning(f"Failed to transform in walker handler: {e}") diff --git a/maxray/transforms.py b/maxray/transforms.py index a110506..b4220ac 100644 --- a/maxray/transforms.py +++ b/maxray/transforms.py @@ -509,6 +509,11 @@ def recompile_fn_with_transform( f"Non-existent source file ({sourcefile}) for function {get_fn_name(source_fn)}" ) + if module is None: + return Err( + f"Non-existent source module `{getattr(source_fn, '__module__', None)}` for function {get_fn_name(source_fn)}" + ) + fn_ast = ast.parse(source) except OSError: return Err(f"No source code for function {get_fn_name(source_fn)}")