Skip to content

Commit

Permalink
refactor: separate out rich/display from core runner impl
Browse files Browse the repository at this point in the history
  • Loading branch information
blepabyte committed Sep 11, 2024
1 parent 3433759 commit cc795cf
Show file tree
Hide file tree
Showing 8 changed files with 558 additions and 332 deletions.
42 changes: 42 additions & 0 deletions maxray/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import re
from importlib.util import find_spec
from pathlib import Path

import click


@click.group()
def cli():
pass


@cli.command()
@click.argument("file", type=str)
@click.option("-f", "--force", is_flag=True)
@click.option("--runner", is_flag=True)
def template(file: str, force: bool, runner: bool):
path = Path(file).resolve(True)
assert path.suffix == ".py"

spec = find_spec("maxray.inators.template")
assert spec is not None
assert spec.origin is not None

template_path = path.with_name(f"over_{path.name}")
if not force:
assert not template_path.exists(), f"{template_path} exists!"

source = Path(spec.origin).read_text()
if not runner:
source = re.sub(r"\s+def runner(?:\n|.*)+", "", source, flags=re.MULTILINE)
source += "\n"
template_path.write_text(source)
print(f"Wrote template to {template_path}")


def main():
cli.main()


if __name__ == "__main__":
main()
189 changes: 189 additions & 0 deletions maxray/inators/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
from maxray.inators.core import S, Matcher
from maxray.inators.display import Display
from maxray.runner import RunCompleted, RunErrored, RunAborted, Break

import ipdb
import rich
from rich.traceback import Traceback
from rich.live import Live
from rich.pretty import Pretty

import io
import sys
import builtins
from pathlib import Path
from functools import partial
from contextlib import contextmanager
from dataclasses import dataclass

from typing import Any, Optional

import rerun as rr


@dataclass
class LocalContext:
x: Any
ctx: Any
matcher: Optional[Matcher] = None


class BaseInator:
def __init__(
self, name: str, rerun: bool, auto_debug: bool, match_assignments: bool
):
"""
Args:
- name (str): Descriptive name for the program. Only used for visualization and logging.
- rerun (bool): Whether to auto-init the Rerun visualization library.
- match_assignments (bool): Enables proper matching on `self.assigned()` by unpacking multiple assignments like `a, b = x`, handling cases where `x` is a stateful iterator or generator by consuming it (converting it to a tuple, throwing away any type information).
"""
self._name = name
self._debugger = auto_debug
self._match_assignments = match_assignments
self._display = S.define_once("RICH_LIVE_DISPLAY", lambda _: Display())

if rerun:
rr.init(self._name, spawn=True)

def __call__(self, x, ctx):
if x is builtins.print:
x = partial(self.print, ctx=ctx)

if self._match_assignments:
M = Matcher(x, ctx)
x = M.unpacked()
lctx = LocalContext(x, ctx, M)
else:
lctx = LocalContext(x, ctx)

self._last_ctx = lctx
self._display.update_context(ctx)

while True:
try:
self.xray(x, ctx)
x = self.maxray(x, ctx)
break
except Break:
self._display.update_status("[violet]PAUSED")
self.wait_and_reload()
except Exception as e:
# Capture and show traceback
self._display.update_status("[violet]PAUSED")
self._display.render_traceback(e, e.__traceback__)
self.wait_and_reload()

match x:
case RunCompleted() | RunAborted() | RunErrored():
self.update_display_state(x)
self._display.render()
case _:
self._display.update_status("[yellow]Running...")
self._display.render_maybe()
return x

def xray(self, x, ctx):
"""
Override to implement equivalent of @xray
"""
pass

def maxray(self, x, ctx):
"""
Override to implement equivalent of @maxray
"""
return x

def runner(self):
raise NotImplementedError()

def wait_and_reload(self):
# Patched in at runtime
raise NotImplementedError()

@contextmanager
def _handle_reload(self):
"""
Provides control over what happens if an error is encountered while reloading itself.
"""
try:
yield
except Exception as e:
# Capture and show traceback
self._display.update_status("[violet]PAUSED")
self._display.render_traceback(e, e.__traceback__)

@property
def display(self) -> Display:
return self._display

@contextmanager
def enter_session(self):
try:
yield
finally:
self.display.live.stop()

def update_display_state(self, state: RunCompleted | RunAborted | RunErrored):
match state:
case RunCompleted():
self._display.update_status("[green]Completed")
case RunAborted(exception=exception):
self._display.update_status(
f"[cyan]Aborted ({type(exception).__name__})"
)
case RunErrored(exception=exception, traceback=traceback):
self._display.update_status("[red]Errored")
self._display.render_traceback(exception, traceback)

def print(self, *args, ctx, **kwargs):
if "file" in kwargs:
return print(*args, **kwargs)

print(*args, **kwargs, file=(buf := io.StringIO()))

source_location = (
Path(ctx.fn_context.source_file).name + ":" + str(ctx.location[0] + 1)
)
rr.log(
f"print/{source_location}",
rr.TextLog(buf.getvalue().strip(), level="TRACE"),
)

# Utility functions

def log(self, obj, level="INFO"):
rr.log("log", rr.TextLog(str(obj), level=level))
return obj

def enter_debugger(self, post_mortem: RunErrored | bool = False):
with self.display.hidden():
if post_mortem is True:
# Needs to be an active exception
ipdb.post_mortem()
elif isinstance(post_mortem, RunErrored):
exc_trace = Traceback.extract(
type(post_mortem.exception),
post_mortem.exception,
post_mortem.traceback,
show_locals=True,
)
traceback = Traceback(
exc_trace,
suppress=[sys.modules["maxray"]],
show_locals=True,
max_frames=5,
)
rich.print(traceback)
ipdb.post_mortem(post_mortem.traceback)
else:
ipdb.set_trace()

def assigned(self):
if self._last_ctx is None:
raise RuntimeError("Outside of any node context")
if self._last_ctx.matcher is None:
raise ValueError("Must enable match_assignments to use `assigned`")

return self._last_ctx.matcher.assigned()
1 change: 1 addition & 0 deletions maxray/inators/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def define_once(self, key, factory, /, v: int = 0):
"""
# Note: can use runtime value for def key for auto grouping!
ex_o = None
# Not thread-safe....
if key in self._existing_keys:
ex_v, ex_o = self._existing_keys[key]
if ex_v >= v:
Expand Down
Loading

0 comments on commit cc795cf

Please sign in to comment.