Skip to content

Commit

Permalink
feat: composable inators
Browse files Browse the repository at this point in the history
- support construction from CLI arguments
- allow extending log capture mechanism
- remove `call_count` from FnContext
  • Loading branch information
blepabyte committed Oct 6, 2024
1 parent 8cc4c96 commit 6e8991d
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 666 deletions.
1 change: 0 additions & 1 deletion maxray/capture/__init__.py

This file was deleted.

196 changes: 0 additions & 196 deletions maxray/capture/logs.py

This file was deleted.

48 changes: 4 additions & 44 deletions maxray/inators/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,8 @@
from maxray.transforms import NodeContext
from maxray.nodes import NodeContext
from maxray.inators.core import S, Ray
from maxray.runner import RunCompleted, RunErrored, RunAborted, Break
from maxray.runner import ExecInfo, RunCompleted, RunErrored, RunAborted, Break

import ipdb
import rich
from rich.traceback import Traceback
from rich.live import Live
from rich.pretty import Pretty

import io
import sys
import builtins
from pathlib import Path
from functools import partial
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass

from typing import Any, Optional

import rerun as rr


class BaseInator:
Expand All @@ -31,12 +14,9 @@ def __init__(self):
self._name = type(self).__name__

def __repr__(self):
return self._name
return f"{self._name}()"

def __call__(self, x, ray: Ray):
if x is builtins.print:
x = partial(self.print, ctx=ray.ctx)

S.display.update_context(ray.ctx)

while True:
Expand Down Expand Up @@ -93,7 +73,7 @@ def _handle_reload(self):
S.display.render_traceback(e, e.__traceback__)

@contextmanager
def enter_session(self):
def enter_session(self, xi: ExecInfo):
try:
yield
finally:
Expand All @@ -110,23 +90,3 @@ def update_display_state(self, state: RunCompleted | RunAborted | RunErrored):
case RunErrored(exception=exception, traceback=traceback):
S.display.update_status("[red]Errored")
S.display.render_traceback(exception, traceback)

def print(self, *args, ctx, **kwargs):
if "file" in kwargs:
return print(*args, **kwargs)

print(*args, **kwargs, file=(buf := io.StringIO()))

source_location = (
Path(ctx.fn_context.source_file).name + ":" + str(ctx.location[0] + 1)
)
rr.log(
f"print/{source_location}",
rr.TextLog(buf.getvalue().strip(), level="TRACE"),
)

# Utility functions

def log(self, obj, level="INFO"):
rr.log("log", rr.TextLog(str(obj), level=level))
return obj
34 changes: 7 additions & 27 deletions maxray/inators/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,16 @@
from .display import Display

import ipdb
import attrs

import json
from contextvars import ContextVar
from dataclasses import dataclass
from pathlib import Path
import inspect
from functools import partial
from contextlib import contextmanager
from types import TracebackType
from typing import Any, Callable, Generator, Iterator

import rerun as rr
from loguru import logger


class Statefool:
Expand Down Expand Up @@ -246,36 +243,19 @@ def __getattr__(self, rewrite_cls_name: str) -> RewriteContext:
return self.by_class[rewrite_cls_name]


class LoggingEncoder(json.JSONEncoder):
def default(self, o):
if attrs.has(type(o)):
return attrs.asdict(o)
return super().default(o)


class Ray(RayContext):
"""
Captures the state of a point (syntax node) in the source code of the original program.
One instance is created for each point in the program, that is then passed to multiple handlers.
"""

def log(self, msg, *, level="INFO"):
"""
Logs to Rerun with the current context if active.
"""
match msg:
case dict():
try:
msg = json.dumps(msg, indent=2, cls=LoggingEncoder)
except Exception:
msg = str(msg)
case _ if attrs.has(type(msg)):
msg = str(msg)

location = Path(self.ctx.fn_context.source_file).name
line = self.ctx.location[0]
rr.log(f"log/{location}:{line}", rr.TextLog(msg, level=level))
def log(self, msg: str, *, level="INFO"):
logger.bind(
maxray_logged_from="ray",
source_file=self.ctx.fn_context.source_file,
source_line=self.ctx.location[0],
).log(level, str(msg))

def contextmanager(self, fn: Callable[[Ray], Iterator[Any]]):
return contextmanager(partial(fn, self))
Expand Down
Loading

0 comments on commit 6e8991d

Please sign in to comment.