From 2af880557cb4ea652c1e34c524251d860bdd2bf6 Mon Sep 17 00:00:00 2001 From: blepabyte <255@blepabyte.me> Date: Mon, 2 Sep 2024 01:19:01 +1200 Subject: [PATCH] feat: inator template for source overlays --- maxray/inators/core.py | 72 +++++++++++++++++++++++++ maxray/inators/template.py | 107 +++++++++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 maxray/inators/core.py create mode 100644 maxray/inators/template.py diff --git a/maxray/inators/core.py b/maxray/inators/core.py new file mode 100644 index 0000000..f634a99 --- /dev/null +++ b/maxray/inators/core.py @@ -0,0 +1,72 @@ +import inspect + + +class Statefool: + def __init__(self): + self._existing_keys = {} + + def __getitem__(self, key): + return self._existing_keys[key][1] + + def __setitem__(self, key, value): + v, _old_value = self._existing_keys[key] + self._existing_keys[key] = (v, value) + + def define_once(self, key, factory, /, v: int = 0): + """ + Args: + key (Immutable + Hash + Eq): Identifies this object between steps + """ + # Note: can use runtime value for def key for auto grouping! + ex_o = None + if key in self._existing_keys: + ex_v, ex_o = self._existing_keys[key] + if ex_v >= v: + return ex_o + + self._existing_keys[key] = (v, new_o := factory(ex_o)) + return new_o + + +class Matcher: + """ + Pattern-matching helper to provide a stable interface over source information stored in NodeContext.props (collected during AST parse) + """ + + def __init__(self, x, ctx): + self.x = x + self.unpacked_x = x + self.ctx = ctx + + match ctx.props: + case {"assigned": {"targets": targets}}: + if len(targets) > 1: + if inspect.isgenerator(x) or isinstance(x, (map, filter)): + # Greedily consume iterators before assignment + self.unpacked_x = tuple(iter(x)) + else: + # Otherwise for chained equality like a = b, c = it, code relies on `a` being of the original type + self.unpacked_x = x + # TODO: doesn't work for starred assignments: x, *y, z = iterable + self._assigned = { + target: val for target, val in zip(targets, self.unpacked_x) + } + elif len(targets) == 1: + self._assigned = {targets[0]: x} + else: + self._assigned = {} + case _: + self._assigned = {} + + def __iter__(self): + yield self.x + yield self.ctx + + def assigned(self): + return self._assigned + + def unpacked(self): + return self.unpacked_x + + +S = Statefool() diff --git a/maxray/inators/template.py b/maxray/inators/template.py new file mode 100644 index 0000000..97b9d0d --- /dev/null +++ b/maxray/inators/template.py @@ -0,0 +1,107 @@ +from maxray.inators.core import S, Matcher +from maxray.runner import ( + MAIN_FN_NAME, + RunCompleted, + RunErrored, + AbortRun, + RestartRun, + Break, +) +from maxray.runner import InteractiveContext + +import pandas as pd +import numpy as np + +import io +import time +from uuid import uuid4 +from pathlib import Path +from functools import partial + +import rerun as rr + + +class Inator: + def __init__(self): + self.session_name = f"maxray:{type(self).__name__}" + self.match_assignments = True + self.last_display_tick = time.perf_counter() + + def log(self, obj, level="INFO"): + rr.log("log", rr.TextLog(str(obj), level=level)) + return obj + + def print(self, *args, ctx, **kwargs): + if "file" in kwargs: + return print(*args, **kwargs) + + print(*args, **kwargs, file=(buf := io.StringIO())) + + source_location = ( + Path(ctx.fn_context.source_file).name + ":" + str(ctx.location[0] + 1) + ) + rr.log( + f"print/{source_location}", + rr.TextLog(buf.getvalue().strip(), level="TRACE"), + ) + + def __call__(self, x, ctx: InteractiveContext): + S.define_once( + "RERUN_INSTANCE", + lambda _: rr.init(self.session_name, spawn=True, recording_id=str(uuid4())), + v=1, + ) + + if x is print: + return partial(self.print, ctx=ctx.copy()) + + # Randomly ticks as progress indicator + if (tick := time.perf_counter()) - self.last_display_tick > 0.1: + ctx.display() + self.last_display_tick = tick + + match x: + # Drop into debugger on unhandled error + case RunErrored(): + import ipdb + import rich + from rich.traceback import Traceback + + ctx.live.stop() + rich.print(Traceback(x.exception_trace)) + + ipdb.post_mortem() + + ctx.live.start() + + # Manual control flow + + # exit() + # raise Break() + # raise AbortRun() + # raise RestartRun() + + # ctx.clear() + + # Global source code overlays + match ctx.source: + case "...": + ... + + # Bind local variables in stack frames we're interested in + match ctx.local_scope: + case {} if ctx.fn_context.name == MAIN_FN_NAME: + ... + case _: + return x + + if self.match_assignments: + # Parse variable assignments + M = Matcher(x, ctx) + match M.assigned(): + case {"df": loss}: + ctx.track(loss=loss) + + return M.unpacked() + + return x