diff --git a/chirho/dynamical/handlers/trace.py b/chirho/dynamical/handlers/trace.py index a1076762..c6c3d3cc 100644 --- a/chirho/dynamical/handlers/trace.py +++ b/chirho/dynamical/handlers/trace.py @@ -58,10 +58,12 @@ def _pyro_post_simulate(self, msg) -> None: initial_state, timespan, ) - self.trace: Trajectory[T] = append(self.trace, trajectory[..., 1:-1]) - if len(self.trace) > len(self.logging_times): + idx = (timespan > timespan[0]) & (timespan < timespan[-1]) + if idx.any(): + self.trace: Trajectory[T] = append(self.trace, trajectory[idx]) + if idx.sum() > len(self.logging_times): raise ValueError( "Multiple simulates were used with a single DynamicTrace handler." "This is currently not supported." ) - msg["value"] = trajectory[..., -1].to_state() + msg["value"] = trajectory[timespan == timespan[-1]].to_state() diff --git a/chirho/dynamical/internals/_utils.py b/chirho/dynamical/internals/_utils.py index 5929027f..f02398a4 100644 --- a/chirho/dynamical/internals/_utils.py +++ b/chirho/dynamical/internals/_utils.py @@ -21,18 +21,6 @@ def _indices_of_state(state: State, *, event_dim: int = 0, **kwargs) -> IndexSet ) -@indices_of.register(Trajectory) -def _indices_of_trajectory( - trj: Trajectory, *, event_dim: int = 0, **kwargs -) -> IndexSet: - return union( - *( - indices_of(getattr(trj, k), event_dim=event_dim + 1, **kwargs) - for k in trj.keys - ) - ) - - @gather.register(State) def _gather_state( state: State[T], indices: IndexSet, *, event_dim: int = 0, **kwargs @@ -45,43 +33,6 @@ def _gather_state( ) -@gather.register(Trajectory) -def _gather_trajectory( - trj: Trajectory[T], indices: IndexSet, *, event_dim: int = 0, **kwargs -) -> Trajectory[T]: - return type(trj)( - **{ - k: gather(getattr(trj, k), indices, event_dim=event_dim + 1, **kwargs) - for k in trj.keys - } - ) - - -def _index_last_dim_with_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - # Index into the last dimension of x with a boolean mask. - # TODO AZ — There must be an easier way to do this? - # NOTE AZ — this could be easily modified to support the last n dimensions, adapt if needed. - - if mask.dtype != torch.bool: - raise ValueError( - f"_index_last_dim_with_mask only supports boolean mask indexing, but got dtype {mask.dtype}." - ) - - # Require that the mask is 1d and aligns with the last dimension of x. - if mask.ndim != 1 or mask.shape[0] != x.shape[-1]: - raise ValueError( - "_index_last_dim_with_mask only supports 1d boolean mask indexing, and must align with the last " - f"dimension of x, but got mask shape {mask.shape} and x shape {x.shape}." - ) - - return torch.masked_select( - x, - # Get a shape that will broadcast to the shape of x. This will be [1, ..., len(mask)]. - mask.reshape((1,) * (x.ndim - 1) + mask.shape) - # masked_select flattens tensors, so we need to reshape back to the original shape w/ the mask applied. - ).reshape(x.shape[:-1] + (int(mask.sum()),)) - - @intervene.register(State) def _state_intervene(obs: State[T], act: State[T], **kwargs) -> State[T]: new_state: State[T] = State() diff --git a/chirho/dynamical/internals/solver/torchdiffeq.py b/chirho/dynamical/internals/solver/torchdiffeq.py index 3714a115..416cab7c 100644 --- a/chirho/dynamical/internals/solver/torchdiffeq.py +++ b/chirho/dynamical/internals/solver/torchdiffeq.py @@ -117,7 +117,7 @@ def torchdiffeq_ode_simulate( trajectory = _torchdiffeq_ode_simulate_inner( dynamics, initial_state, timespan, **solver.odeint_kwargs ) - return trajectory[..., -1].to_state() + return trajectory[timespan == timespan[-1]].to_state() @simulate_trajectory.register(TorchDiffEq) diff --git a/chirho/dynamical/ops/dynamical.py b/chirho/dynamical/ops/dynamical.py index 90a65460..c15b6660 100644 --- a/chirho/dynamical/ops/dynamical.py +++ b/chirho/dynamical/ops/dynamical.py @@ -1,4 +1,3 @@ -import functools import numbers from typing import ( TYPE_CHECKING, @@ -14,6 +13,8 @@ import pyro import torch +from chirho.indexed.ops import IndexSet, gather, get_index_plates, indices_of + if TYPE_CHECKING: from chirho.dynamical.internals.backend import Solver @@ -55,55 +56,41 @@ def __getattr__(self, __name: str) -> T: class _Sliceable(Protocol[T_co]): - def __getitem__(self, key) -> Union[T_co, "_Sliceable[T_co]"]: + def __getitem__(self, key: torch.Tensor) -> Union[T_co, "_Sliceable[T_co]"]: + ... + + def squeeze(self, dim: int) -> "_Sliceable[T_co]": ... class Trajectory(Generic[T], State[_Sliceable[T]]): def __len__(self) -> int: - # TODO this implementation is just for tensors, but we should support other types. - return getattr(self, next(iter(self.keys))).shape[-1] - - def _getitem(self, key): - from chirho.dynamical.internals._utils import _index_last_dim_with_mask - - if isinstance(key, str): - raise ValueError( - "Trajectory does not support string indexing, use getattr instead if you want to access a specific " - "state variable." - ) - - item = State() if isinstance(key, int) else Trajectory() - for k, v in self.__dict__["_values"].items(): - if isinstance(key, torch.Tensor): - keyd_v = _index_last_dim_with_mask(v, key) - else: - keyd_v = v[key] - setattr(item, k, keyd_v) - return item - - # This is needed so that mypy and other type checkers believe that Trajectory can be indexed into. - @functools.singledispatchmethod - def __getitem__(self, key): - return self._getitem(key) - - @__getitem__.register(int) - def _getitem_int(self, key: int) -> State[T]: - return self._getitem(key) - - @__getitem__.register(torch.Tensor) - def _getitem_torchmask(self, key: torch.Tensor) -> "Trajectory[T]": - if key.dtype != torch.bool: - raise ValueError( - f"__getitem__ with a torch.Tensor only supports boolean mask indexing, but got dtype {key.dtype}." - ) - - return self._getitem(key) + if not self.keys: + return 0 + + name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()} + name_to_dim["__time"] = -1 + # TODO support event_dim > 0 + return len(indices_of(self, event_dim=0, name_to_dim=name_to_dim)["__time"]) + + def __getitem__(self, key: torch.Tensor) -> "Trajectory[T]": + assert key.dtype == torch.bool + + assert len(key.shape) == 1 and key.shape[0] > 1 # DEBUG + + if not key.any(): # DEBUG + return Trajectory() + + name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()} + name_to_dim["__time"] = -1 + idx = IndexSet(__time={i for i in range(key.shape[0]) if key[i]}) + # TODO support event_dim > 0 + return gather(self, idx, event_dim=0, name_to_dim=name_to_dim) def to_state(self) -> State[T]: ret: State[T] = State( # TODO support event_dim > 0 - **{k: getattr(self, k) for k in self.keys} + **{k: getattr(self, k).squeeze(-1) for k in self.keys} ) return ret