Skip to content

Commit

Permalink
Use rich tables for statespace build reports (#411)
Browse files Browse the repository at this point in the history
* Use rich table in build report

* Justify left-most column to left

* Re-run example notebooks

* Re-run example notebooks

* Refactor table initialization

* set requirement_table to None during initialization
  • Loading branch information
jessegrabowski authored Jan 10, 2025
1 parent c9134fe commit dcc353c
Show file tree
Hide file tree
Showing 6 changed files with 3,470 additions and 2,265 deletions.
1,243 changes: 751 additions & 492 deletions notebooks/Exponential Trend Smoothing.ipynb

Large diffs are not rendered by default.

578 changes: 424 additions & 154 deletions notebooks/Making a Custom Statespace Model.ipynb

Large diffs are not rendered by default.

1,620 changes: 922 additions & 698 deletions notebooks/SARMA Example.ipynb

Large diffs are not rendered by default.

1,546 changes: 966 additions & 580 deletions notebooks/Structural Timeseries Modeling.ipynb

Large diffs are not rendered by default.

656 changes: 350 additions & 306 deletions notebooks/VARMAX Example.ipynb

Large diffs are not rendered by default.

92 changes: 57 additions & 35 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from pymc.util import RandomState
from pytensor import Variable, graph_replace
from pytensor.compile import get_mode
from rich.box import SIMPLE_HEAD
from rich.console import Console
from rich.table import Table

from pymc_extras.statespace.core.representation import PytensorRepresentation
from pymc_extras.statespace.filters import (
Expand Down Expand Up @@ -254,53 +257,72 @@ def __init__(
self.kalman_smoother = KalmanSmoother()
self.make_symbolic_graph()

if verbose:
# These are split into separate try-except blocks, because it will be quite rare of models to implement
# _print_data_requirements, but we still want to print the prior requirements.
try:
self._print_prior_requirements()
except NotImplementedError:
pass
try:
self._print_data_requirements()
except NotImplementedError:
pass

def _print_prior_requirements(self) -> None:
"""
Prints a short report to the terminal about the priors needed for the model, including their names,
self.requirement_table = None
self._populate_prior_requirements()
self._populate_data_requirements()

if verbose and self.requirement_table:
console = Console()
console.print(self.requirement_table)

def _populate_prior_requirements(self) -> None:
"""
Add requirements about priors needed for the model to a rich table, including their names,
shapes, named dimensions, and any parameter constraints.
"""
out = ""
for param, info in self.param_info.items():
out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n'
out = out.rstrip()
# Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
# is not true.
try:
if not isinstance(self.param_info, dict):
return
except NotImplementedError:
return

_log.info(
"The following parameters should be assigned priors inside a PyMC "
f"model block: \n"
f"{out}"
)
if self.requirement_table is None:
self._initialize_requirement_table()

def _print_data_requirements(self) -> None:
for param, info in self.param_info.items():
self.requirement_table.add_row(
param, str(info["shape"]), info["constraints"], str(info["dims"])
)

def _populate_data_requirements(self) -> None:
"""
Prints a short report to the terminal about the data needed for the model, including their names, shapes,
and named dimensions.
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
"""
if not self.data_info:
try:
if not isinstance(self.data_info, dict):
return
except NotImplementedError:
return

out = ""
if self.requirement_table is None:
self._initialize_requirement_table()
else:
self.requirement_table.add_section()

for data, info in self.data_info.items():
out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n'
out = out.rstrip()
self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))

def _initialize_requirement_table(self) -> None:
self.requirement_table = Table(
show_header=True,
show_edge=True,
box=SIMPLE_HEAD,
highlight=True,
)

_log.info(
"The following Data variables should be assigned to the model inside a PyMC "
f"model block: \n"
f"{out}"
self.requirement_table.title = "Model Requirements"
self.requirement_table.caption = (
"These parameters should be assigned priors inside a PyMC model block before "
"calling the build_statespace_graph method."
)

self.requirement_table.add_column("Variable", justify="left")
self.requirement_table.add_column("Shape", justify="left")
self.requirement_table.add_column("Constraints", justify="left")
self.requirement_table.add_column("Dimensions", justify="right")

def _unpack_statespace_with_placeholders(
self,
) -> tuple[
Expand Down

0 comments on commit dcc353c

Please sign in to comment.