Skip to content

Commit

Permalink
added scatter for dict and intervention function on dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
SamWitty committed Dec 5, 2024
1 parent f015a77 commit 4794011
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
29 changes: 27 additions & 2 deletions chirho/indexed/internals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numbers
from typing import Dict, Hashable, Optional, TypeVar, Union
from typing import Any, Dict, Hashable, Optional, TypeVar, Union

import pyro
import pyro.infer.reparam
Expand Down Expand Up @@ -67,7 +67,7 @@ def _gather_tensor(


@gather.register(dict)
def _gather_state(
def _gather_dict(
value: Dict[K, T], indices: IndexSet, *, event_dim: int = 0, **kwargs
) -> Dict[K, T]:
return {
Expand Down Expand Up @@ -143,6 +143,31 @@ def _scatter_tensor(
return result


@scatter.register(dict)
def _scatter_dict(
value: Dict[K, T],
indexset: IndexSet,
*,
result: Optional[Dict[K, Optional[T]]] = None,
event_dim: Optional[int] = None,
name_to_dim: Optional[Dict[Hashable, int]] = None,
) -> Dict[K, Any]:

if result is None:
result = {k: None for k in value.keys()}

for k in value.keys():
result[k] = scatter(
value[k],
indexset,
result=result[k],
event_dim=event_dim,
name_to_dim=name_to_dim,
)

return result


@indices_of.register
def _indices_of_number(value: numbers.Number, **kwargs) -> IndexSet:
return IndexSet()
Expand Down
17 changes: 15 additions & 2 deletions chirho/interventional/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import collections
import functools
from typing import Callable, Dict, Generic, Hashable, Mapping, Optional, TypeVar
from typing import Callable, Dict, Generic, Hashable, Mapping, Optional, TypeVar, Union

import pyro
import torch
Expand Down Expand Up @@ -60,14 +60,27 @@ def _intervene_atom_distribution(

@intervene.register(dict)
def _dict_intervene(
obs: Dict[K, T], act: Dict[K, AtomicIntervention[T]], **kwargs
obs: Dict[K, T],
act: Union[Dict[K, AtomicIntervention[T]], Callable[[Dict[K, T]], Dict[K, T]]],
**kwargs,
) -> Dict[K, T]:

if callable(act):
return _dict_intervene_callable(obs, act, **kwargs)

result: Dict[K, T] = {}
for k in obs.keys():
result[k] = intervene(obs[k], act[k] if k in act else None, **kwargs)
return result


@pyro.poutine.runtime.effectful(type="intervene")
def _dict_intervene_callable(
obs: Dict[K, T], act: Callable[[Dict[K, T]], Dict[K, T]], **kwargs
) -> Dict[K, T]:
return act(obs)


@intervene.register
def _intervene_callable(
obs: collections.abc.Callable,
Expand Down
1 change: 1 addition & 0 deletions tests/dynamical/test_static_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
dict(I=torch.tensor(50.0)),
dict(S=torch.tensor(50.0), R=torch.tensor(50.0)),
dict(S=torch.tensor(50.0), I=torch.tensor(50.0), R=torch.tensor(50.0)),
lambda X: {k: v / 2 for k, v in X.items()},
]

# Define intervention times before all tspan values.
Expand Down

0 comments on commit 4794011

Please sign in to comment.