Skip to content

Commit

Permalink
feat: inator template for source overlays
Browse files Browse the repository at this point in the history
  • Loading branch information
blepabyte committed Sep 1, 2024
1 parent cbddfd5 commit 2af8805
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 0 deletions.
72 changes: 72 additions & 0 deletions maxray/inators/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import inspect


class Statefool:
def __init__(self):
self._existing_keys = {}

def __getitem__(self, key):
return self._existing_keys[key][1]

def __setitem__(self, key, value):
v, _old_value = self._existing_keys[key]
self._existing_keys[key] = (v, value)

def define_once(self, key, factory, /, v: int = 0):
"""
Args:
key (Immutable + Hash + Eq): Identifies this object between steps
"""
# Note: can use runtime value for def key for auto grouping!
ex_o = None
if key in self._existing_keys:
ex_v, ex_o = self._existing_keys[key]
if ex_v >= v:
return ex_o

self._existing_keys[key] = (v, new_o := factory(ex_o))
return new_o


class Matcher:
"""
Pattern-matching helper to provide a stable interface over source information stored in NodeContext.props (collected during AST parse)
"""

def __init__(self, x, ctx):
self.x = x
self.unpacked_x = x
self.ctx = ctx

match ctx.props:
case {"assigned": {"targets": targets}}:
if len(targets) > 1:
if inspect.isgenerator(x) or isinstance(x, (map, filter)):
# Greedily consume iterators before assignment
self.unpacked_x = tuple(iter(x))
else:
# Otherwise for chained equality like a = b, c = it, code relies on `a` being of the original type
self.unpacked_x = x
# TODO: doesn't work for starred assignments: x, *y, z = iterable
self._assigned = {
target: val for target, val in zip(targets, self.unpacked_x)
}
elif len(targets) == 1:
self._assigned = {targets[0]: x}
else:
self._assigned = {}
case _:
self._assigned = {}

def __iter__(self):
yield self.x
yield self.ctx

def assigned(self):
return self._assigned

def unpacked(self):
return self.unpacked_x


S = Statefool()
107 changes: 107 additions & 0 deletions maxray/inators/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from maxray.inators.core import S, Matcher
from maxray.runner import (
MAIN_FN_NAME,
RunCompleted,
RunErrored,
AbortRun,
RestartRun,
Break,
)
from maxray.runner import InteractiveContext

import pandas as pd
import numpy as np

import io
import time
from uuid import uuid4
from pathlib import Path
from functools import partial

import rerun as rr


class Inator:
def __init__(self):
self.session_name = f"maxray:{type(self).__name__}"
self.match_assignments = True
self.last_display_tick = time.perf_counter()

def log(self, obj, level="INFO"):
rr.log("log", rr.TextLog(str(obj), level=level))
return obj

def print(self, *args, ctx, **kwargs):
if "file" in kwargs:
return print(*args, **kwargs)

print(*args, **kwargs, file=(buf := io.StringIO()))

source_location = (
Path(ctx.fn_context.source_file).name + ":" + str(ctx.location[0] + 1)
)
rr.log(
f"print/{source_location}",
rr.TextLog(buf.getvalue().strip(), level="TRACE"),
)

def __call__(self, x, ctx: InteractiveContext):
S.define_once(
"RERUN_INSTANCE",
lambda _: rr.init(self.session_name, spawn=True, recording_id=str(uuid4())),
v=1,
)

if x is print:
return partial(self.print, ctx=ctx.copy())

# Randomly ticks as progress indicator
if (tick := time.perf_counter()) - self.last_display_tick > 0.1:
ctx.display()
self.last_display_tick = tick

match x:
# Drop into debugger on unhandled error
case RunErrored():
import ipdb
import rich
from rich.traceback import Traceback

ctx.live.stop()
rich.print(Traceback(x.exception_trace))

ipdb.post_mortem()

ctx.live.start()

# Manual control flow

# exit()
# raise Break()
# raise AbortRun()
# raise RestartRun()

# ctx.clear()

# Global source code overlays
match ctx.source:
case "...":
...

# Bind local variables in stack frames we're interested in
match ctx.local_scope:
case {} if ctx.fn_context.name == MAIN_FN_NAME:
...
case _:
return x

if self.match_assignments:
# Parse variable assignments
M = Matcher(x, ctx)
match M.assigned():
case {"df": loss}:
ctx.track(loss=loss)

return M.unpacked()

return x

0 comments on commit 2af8805

Please sign in to comment.