Skip to content

Commit

Permalink
fix: memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
blepabyte committed Jun 17, 2024
1 parent 57ab881 commit eac440a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
7 changes: 6 additions & 1 deletion maxray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .transforms import recompile_fn_with_transform, NodeContext

import inspect
from weakref import ref, WeakSet
from contextvars import ContextVar
from contextlib import contextmanager
from dataclasses import dataclass
Expand Down Expand Up @@ -101,7 +102,9 @@ def descend_allowed(x, ctx: NodeContext):
_GLOBAL_WRITER_ACTIVE_FLAG = ContextVar("writer_active (global)", default=False)

# We don't want to recompile the same function over and over - so our cache needs to be global
# TODO: add cache by source code location?
_MAXRAY_FN_CACHE = dict()
_MAXRAY_FN_FAILED_CACHE = WeakSet()


def callable_allowed_for_transform(x, ctx: NodeContext):
Expand Down Expand Up @@ -179,6 +182,8 @@ def _maxray_walker_handler(x, ctx: NodeContext):
# 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod
elif callable_allowed_for_transform(x, ctx):
# We can only cache functions - as caching invokes __hash__, which may fail badly on incompletely-initialised class instances w/ __call__ methods, like torch._ops.OpOverload
if x in _MAXRAY_FN_FAILED_CACHE:
pass
if x in _MAXRAY_FN_CACHE:
x = _MAXRAY_FN_CACHE[x]
elif not descend_allowed(x, ctx):
Expand Down Expand Up @@ -219,7 +224,7 @@ def _maxray_walker_handler(x, ctx: NodeContext):

case Err(e):
# Cache failures
_MAXRAY_FN_CACHE[x] = x
_MAXRAY_FN_FAILED_CACHE.add(x)
# Errors in functions that have been recursively compiled are less important
logger.warning(f"Failed to transform in walker handler: {e}")

Expand Down
5 changes: 5 additions & 0 deletions maxray/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,11 @@ def recompile_fn_with_transform(
f"Non-existent source file ({sourcefile}) for function {get_fn_name(source_fn)}"
)

if module is None:
return Err(
f"Non-existent source module `{getattr(source_fn, '__module__', None)}` for function {get_fn_name(source_fn)}"
)

fn_ast = ast.parse(source)
except OSError:
return Err(f"No source code for function {get_fn_name(source_fn)}")
Expand Down

0 comments on commit eac440a

Please sign in to comment.