Skip to content

Commit

Permalink
Added docstring for intervene following issue #113 (#182)
Browse files Browse the repository at this point in the history
* added docstring for intervene following issue #113

* removed extra spacing

* adds docstring for DoMessenger

* removes generic type hint on _intervene_atom so that intervene signatures properly show up in docs.

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Andy Zane <[email protected]>
  • Loading branch information
3 people authored Jul 6, 2023
1 parent 2ec8d7b commit 6bf6352
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
11 changes: 10 additions & 1 deletion causal_pyro/interventional/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
@intervene.register(torch.Tensor)
@pyro.poutine.runtime.effectful(type="intervene")
def _intervene_atom(
obs: T, act: Optional[AtomicIntervention[T]] = None, *, event_dim: int = 0, **kwargs
obs, act: Optional[AtomicIntervention[T]] = None, *, event_dim: int = 0, **kwargs
) -> T:
"""
Intervene on an atomic value in a probabilistic program.
Expand Down Expand Up @@ -76,9 +76,18 @@ def _intervene_callable_wrapper(*args, **kwargs):
class DoMessenger(Generic[T], pyro.poutine.messenger.Messenger):
"""
Intervene on values in a probabilistic program.
:class:`DoMessenger` is an effect handler that intervenes at specified sample sites
in a probabilistic program. This allows users to define programs without any
interventional or causal semantics, and then to add those features later in the
context of, for example, :class:`DoMessenger`. This handler uses :func:`intervene`
internally and supports the same types of interventions.
"""

def __init__(self, actions: Mapping[Hashable, AtomicIntervention[T]]):
"""
:param actions: A mapping from names of sample sites to interventions.
"""
self.actions = actions
super().__init__()

Expand Down
9 changes: 9 additions & 0 deletions causal_pyro/interventional/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,14 @@
def intervene(obs, act: Optional[Intervention[T]] = None, **kwargs):
"""
Intervene on a value in a probabilistic program.
:func:`intervene` is primarily used internally in :class:`DoMessenger`
for concisely and extensibly defining the semantics of interventions. This
function is generically typed and extensible to new types via
:func:`functools.singledispatch`. When its first argument is a function,
:func:`intervene` now behaves like the current `causal_pyro.query.do` effect handler.
:param obs: a value in a probabilistic program.
:param act: an optional intervention.
"""
raise NotImplementedError(f"intervene not implemented for type {type(obs)}")

0 comments on commit 6bf6352

Please sign in to comment.