-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: inator template for source overlays
- Loading branch information
Showing
2 changed files
with
179 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import inspect | ||
|
||
|
||
class Statefool: | ||
def __init__(self): | ||
self._existing_keys = {} | ||
|
||
def __getitem__(self, key): | ||
return self._existing_keys[key][1] | ||
|
||
def __setitem__(self, key, value): | ||
v, _old_value = self._existing_keys[key] | ||
self._existing_keys[key] = (v, value) | ||
|
||
def define_once(self, key, factory, /, v: int = 0): | ||
""" | ||
Args: | ||
key (Immutable + Hash + Eq): Identifies this object between steps | ||
""" | ||
# Note: can use runtime value for def key for auto grouping! | ||
ex_o = None | ||
if key in self._existing_keys: | ||
ex_v, ex_o = self._existing_keys[key] | ||
if ex_v >= v: | ||
return ex_o | ||
|
||
self._existing_keys[key] = (v, new_o := factory(ex_o)) | ||
return new_o | ||
|
||
|
||
class Matcher: | ||
""" | ||
Pattern-matching helper to provide a stable interface over source information stored in NodeContext.props (collected during AST parse) | ||
""" | ||
|
||
def __init__(self, x, ctx): | ||
self.x = x | ||
self.unpacked_x = x | ||
self.ctx = ctx | ||
|
||
match ctx.props: | ||
case {"assigned": {"targets": targets}}: | ||
if len(targets) > 1: | ||
if inspect.isgenerator(x) or isinstance(x, (map, filter)): | ||
# Greedily consume iterators before assignment | ||
self.unpacked_x = tuple(iter(x)) | ||
else: | ||
# Otherwise for chained equality like a = b, c = it, code relies on `a` being of the original type | ||
self.unpacked_x = x | ||
# TODO: doesn't work for starred assignments: x, *y, z = iterable | ||
self._assigned = { | ||
target: val for target, val in zip(targets, self.unpacked_x) | ||
} | ||
elif len(targets) == 1: | ||
self._assigned = {targets[0]: x} | ||
else: | ||
self._assigned = {} | ||
case _: | ||
self._assigned = {} | ||
|
||
def __iter__(self): | ||
yield self.x | ||
yield self.ctx | ||
|
||
def assigned(self): | ||
return self._assigned | ||
|
||
def unpacked(self): | ||
return self.unpacked_x | ||
|
||
|
||
S = Statefool() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from maxray.inators.core import S, Matcher | ||
from maxray.runner import ( | ||
MAIN_FN_NAME, | ||
RunCompleted, | ||
RunErrored, | ||
AbortRun, | ||
RestartRun, | ||
Break, | ||
) | ||
from maxray.runner import InteractiveContext | ||
|
||
import pandas as pd | ||
import numpy as np | ||
|
||
import io | ||
import time | ||
from uuid import uuid4 | ||
from pathlib import Path | ||
from functools import partial | ||
|
||
import rerun as rr | ||
|
||
|
||
class Inator: | ||
def __init__(self): | ||
self.session_name = f"maxray:{type(self).__name__}" | ||
self.match_assignments = True | ||
self.last_display_tick = time.perf_counter() | ||
|
||
def log(self, obj, level="INFO"): | ||
rr.log("log", rr.TextLog(str(obj), level=level)) | ||
return obj | ||
|
||
def print(self, *args, ctx, **kwargs): | ||
if "file" in kwargs: | ||
return print(*args, **kwargs) | ||
|
||
print(*args, **kwargs, file=(buf := io.StringIO())) | ||
|
||
source_location = ( | ||
Path(ctx.fn_context.source_file).name + ":" + str(ctx.location[0] + 1) | ||
) | ||
rr.log( | ||
f"print/{source_location}", | ||
rr.TextLog(buf.getvalue().strip(), level="TRACE"), | ||
) | ||
|
||
def __call__(self, x, ctx: InteractiveContext): | ||
S.define_once( | ||
"RERUN_INSTANCE", | ||
lambda _: rr.init(self.session_name, spawn=True, recording_id=str(uuid4())), | ||
v=1, | ||
) | ||
|
||
if x is print: | ||
return partial(self.print, ctx=ctx.copy()) | ||
|
||
# Randomly ticks as progress indicator | ||
if (tick := time.perf_counter()) - self.last_display_tick > 0.1: | ||
ctx.display() | ||
self.last_display_tick = tick | ||
|
||
match x: | ||
# Drop into debugger on unhandled error | ||
case RunErrored(): | ||
import ipdb | ||
import rich | ||
from rich.traceback import Traceback | ||
|
||
ctx.live.stop() | ||
rich.print(Traceback(x.exception_trace)) | ||
|
||
ipdb.post_mortem() | ||
|
||
ctx.live.start() | ||
|
||
# Manual control flow | ||
|
||
# exit() | ||
# raise Break() | ||
# raise AbortRun() | ||
# raise RestartRun() | ||
|
||
# ctx.clear() | ||
|
||
# Global source code overlays | ||
match ctx.source: | ||
case "...": | ||
... | ||
|
||
# Bind local variables in stack frames we're interested in | ||
match ctx.local_scope: | ||
case {} if ctx.fn_context.name == MAIN_FN_NAME: | ||
... | ||
case _: | ||
return x | ||
|
||
if self.match_assignments: | ||
# Parse variable assignments | ||
M = Matcher(x, ctx) | ||
match M.assigned(): | ||
case {"df": loss}: | ||
ctx.track(loss=loss) | ||
|
||
return M.unpacked() | ||
|
||
return x |