Make Trajectory methods use indexed ops #317

Merged: 8 commits, Oct 12, 2023
chirho/dynamical/handlers/trace.py (5 additions, 3 deletions)
@@ -58,10 +58,12 @@ def _pyro_post_simulate(self, msg) -> None:
             initial_state,
             timespan,
         )
-        self.trace: Trajectory[T] = append(self.trace, trajectory[..., 1:-1])
-        if len(self.trace) > len(self.logging_times):
+        idx = (timespan > timespan[0]) & (timespan < timespan[-1])
+        if idx.any():
+            self.trace: Trajectory[T] = append(self.trace, trajectory[idx])
+        if idx.sum() > len(self.logging_times):
             raise ValueError(
                 "Multiple simulates were used with a single DynamicTrace handler."
                 "This is currently not supported."
             )
-        msg["value"] = trajectory[..., -1].to_state()
+        msg["value"] = trajectory[timespan == timespan[-1]].to_state()
chirho/dynamical/internals/_utils.py (0 additions, 49 deletions)
@@ -21,18 +21,6 @@ def _indices_of_state(state: State, *, event_dim: int = 0, **kwargs) -> IndexSet
     )
 
 
-@indices_of.register(Trajectory)
-def _indices_of_trajectory(
-    trj: Trajectory, *, event_dim: int = 0, **kwargs
-) -> IndexSet:
-    return union(
-        *(
-            indices_of(getattr(trj, k), event_dim=event_dim + 1, **kwargs)
-            for k in trj.keys
-        )
-    )
-
-
 @gather.register(State)
 def _gather_state(
     state: State[T], indices: IndexSet, *, event_dim: int = 0, **kwargs
@@ -45,43 +33,6 @@ def _gather_state(
     )
 
 
-@gather.register(Trajectory)
-def _gather_trajectory(
-    trj: Trajectory[T], indices: IndexSet, *, event_dim: int = 0, **kwargs
-) -> Trajectory[T]:
-    return type(trj)(
-        **{
-            k: gather(getattr(trj, k), indices, event_dim=event_dim + 1, **kwargs)
-            for k in trj.keys
-        }
-    )
-
-
-def _index_last_dim_with_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-    # Index into the last dimension of x with a boolean mask.
-    # TODO AZ — There must be an easier way to do this?
-    # NOTE AZ — this could be easily modified to support the last n dimensions, adapt if needed.
-
-    if mask.dtype != torch.bool:
-        raise ValueError(
-            f"_index_last_dim_with_mask only supports boolean mask indexing, but got dtype {mask.dtype}."
-        )
-
-    # Require that the mask is 1d and aligns with the last dimension of x.
-    if mask.ndim != 1 or mask.shape[0] != x.shape[-1]:
-        raise ValueError(
-            "_index_last_dim_with_mask only supports 1d boolean mask indexing, and must align with the last "
-            f"dimension of x, but got mask shape {mask.shape} and x shape {x.shape}."
-        )
-
-    return torch.masked_select(
-        x,
-        # Get a shape that will broadcast to the shape of x. This will be [1, ..., len(mask)].
-        mask.reshape((1,) * (x.ndim - 1) + mask.shape)
-        # masked_select flattens tensors, so we need to reshape back to the original shape w/ the mask applied.
-    ).reshape(x.shape[:-1] + (int(mask.sum()),))
-
-
 @intervene.register(State)
 def _state_intervene(obs: State[T], act: State[T], **kwargs) -> State[T]:
     new_state: State[T] = State()
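The deleted `_index_last_dim_with_mask` helper is no longer needed once `Trajectory.__getitem__` routes through the indexed ops. For reference, its `masked_select`-and-reshape trick computes the same result as ordinary boolean advanced indexing on the last dimension, as this standalone sketch (hypothetical shapes) shows:

```python
import torch

x = torch.arange(24.0).reshape(2, 3, 4)          # last dim plays the role of time
mask = torch.tensor([True, False, True, False])

# What the removed helper computed: broadcast the mask over the batch dims,
# flatten the selected entries with masked_select, then restore the batch shape.
masked = torch.masked_select(
    x, mask.reshape((1,) * (x.ndim - 1) + mask.shape)
).reshape(x.shape[:-1] + (int(mask.sum()),))

# Equivalent to plain boolean indexing along the last dimension.
assert torch.equal(masked, x[..., mask])
```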
chirho/dynamical/internals/solver/torchdiffeq.py (1 addition, 1 deletion)
@@ -117,7 +117,7 @@ def torchdiffeq_ode_simulate(
     trajectory = _torchdiffeq_ode_simulate_inner(
         dynamics, initial_state, timespan, **solver.odeint_kwargs
     )
-    return trajectory[..., -1].to_state()
+    return trajectory[timespan == timespan[-1]].to_state()
 
 
 @simulate_trajectory.register(TorchDiffEq)
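Here, `timespan == timespan[-1]` is a boolean mask that keeps only the final time point, so the indexed trajectory retains a trailing time dimension of size 1, which `to_state()` then removes with `squeeze(-1)`. A rough sketch of the shape bookkeeping (hypothetical tensors, not the chirho API):

```python
import torch

timespan = torch.tensor([0.0, 1.0, 2.0])
ys = torch.randn(5, 3)                     # hypothetical state variable: batch of 5, 3 time points

final = ys[..., timespan == timespan[-1]]  # shape (5, 1): a time dim of size 1 survives
state_value = final.squeeze(-1)            # shape (5,): what to_state() yields per variable

assert state_value.shape == (5,)
```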
chirho/dynamical/ops/dynamical.py (28 additions, 41 deletions)
@@ -1,4 +1,3 @@
-import functools
 import numbers
 from typing import (
     TYPE_CHECKING,
@@ -14,6 +13,8 @@
 import pyro
 import torch
 
+from chirho.indexed.ops import IndexSet, gather, get_index_plates, indices_of
+
 if TYPE_CHECKING:
     from chirho.dynamical.internals.backend import Solver
 
@@ -55,55 +56,41 @@ def __getattr__(self, __name: str) -> T:
 
 
 class _Sliceable(Protocol[T_co]):
-    def __getitem__(self, key) -> Union[T_co, "_Sliceable[T_co]"]:
+    def __getitem__(self, key: torch.Tensor) -> Union[T_co, "_Sliceable[T_co]"]:
         ...
 
+    def squeeze(self, dim: int) -> "_Sliceable[T_co]":
+        ...
+
 
 class Trajectory(Generic[T], State[_Sliceable[T]]):
     def __len__(self) -> int:
-        # TODO this implementation is just for tensors, but we should support other types.
-        return getattr(self, next(iter(self.keys))).shape[-1]
-
-    def _getitem(self, key):
-        from chirho.dynamical.internals._utils import _index_last_dim_with_mask
-
-        if isinstance(key, str):
-            raise ValueError(
-                "Trajectory does not support string indexing, use getattr instead if you want to access a specific "
-                "state variable."
-            )
-
-        item = State() if isinstance(key, int) else Trajectory()
-        for k, v in self.__dict__["_values"].items():
-            if isinstance(key, torch.Tensor):
-                keyd_v = _index_last_dim_with_mask(v, key)
-            else:
-                keyd_v = v[key]
-            setattr(item, k, keyd_v)
-        return item
-
-    # This is needed so that mypy and other type checkers believe that Trajectory can be indexed into.
-    @functools.singledispatchmethod
-    def __getitem__(self, key):
-        return self._getitem(key)
-
-    @__getitem__.register(int)
-    def _getitem_int(self, key: int) -> State[T]:
-        return self._getitem(key)
-
-    @__getitem__.register(torch.Tensor)
-    def _getitem_torchmask(self, key: torch.Tensor) -> "Trajectory[T]":
-        if key.dtype != torch.bool:
-            raise ValueError(
-                f"__getitem__ with a torch.Tensor only supports boolean mask indexing, but got dtype {key.dtype}."
-            )
-
-        return self._getitem(key)
+        if not self.keys:
+            return 0
+
+        name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()}
+        name_to_dim["__time"] = -1
+        # TODO support event_dim > 0
+        return len(indices_of(self, event_dim=0, name_to_dim=name_to_dim)["__time"])
+
+    def __getitem__(self, key: torch.Tensor) -> "Trajectory[T]":
+        assert key.dtype == torch.bool
+        assert len(key.shape) == 1 and key.shape[0] > 1  # DEBUG
+
+        if not key.any():  # DEBUG
+            return Trajectory()
+
+        name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()}
+        name_to_dim["__time"] = -1
+        idx = IndexSet(__time={i for i in range(key.shape[0]) if key[i]})
+        # TODO support event_dim > 0
+        return gather(self, idx, event_dim=0, name_to_dim=name_to_dim)
 
     def to_state(self) -> State[T]:
         ret: State[T] = State(
             # TODO support event_dim > 0
-            **{k: getattr(self, k) for k in self.keys}
+            **{k: getattr(self, k).squeeze(-1) for k in self.keys}
         )
         return ret
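In plain torch terms, the new `__getitem__` converts the boolean mask into an explicit set of time indices and gathers along the trailing `__time` dimension of every state variable. The sketch below mirrors those semantics with `index_select` on a single tensor; it is an illustration, not the chirho `gather` implementation:

```python
import torch

key = torch.tensor([False, True, True, False])   # hypothetical boolean time mask

# The IndexSet construction above amounts to collecting the True positions...
time_indices = torch.tensor([i for i in range(key.shape[0]) if key[i]])

value = torch.randn(3, 4)                        # hypothetical state variable, time last

# ...and gathering along "__time" (dim -1) then behaves like index_select.
selected = value.index_select(-1, time_indices)

assert torch.equal(selected, value[..., key])    # same as boolean mask indexing
```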
