Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix for timeout in graph_model #3460

Merged
merged 4 commits into from
May 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- Used `numpy.vectorize` in `distributions.distribution._compile_theano_function`. This enables `sample_prior_predictive` and `sample_posterior_predictive` to ask for tuples of samples instead of just integers. This fixes issue #3422.

### Maintenance
- Fixed an issue in `model_graph` that caused construction of the graph of the model for rendering to hang: replaced a search over the powerset of the nodes with a breadth-first search over the nodes. Fix for #3458.
- All occurances of `sd` as a parameter name have been renamed to `sigma`. `sd` will continue to function for backwards compatibility.
- Made `BrokenPipeError` for parallel sampling more verbose on Windows.
- Added the `broadcast_distribution_samples` function that helps broadcasting arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This sometimes is needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (Fix issue #3310).
Expand Down
57 changes: 34 additions & 23 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import itertools
from collections import deque
from typing import Iterator, Optional, MutableSet

from theano.gof.graph import ancestors
from theano.gof.graph import stack_search
from theano.compile import SharedVariable
from theano.tensor import Tensor

from .util import get_default_varnames
import pymc3 as pm

# this is a placeholder for a better characterization of the type
# of variables in a model.
RV = Tensor


def powerset(iterable):
"""All *nonempty* subsets of an iterable.
Expand All @@ -27,37 +34,41 @@ def __init__(self, model):
self._deterministics = None

def get_deterministics(self, var):
"""Compute the deterministic nodes of the graph"""
"""Compute the deterministic nodes of the graph, **not** including var itself."""
deterministics = []
attrs = ('transformed', 'logpt')
for v in self.var_list:
if v != var and all(not hasattr(v, attr) for attr in attrs):
deterministics.append(v)
return deterministics

def _ancestors(self, var, func, blockers=None):
"""Get ancestors of a function that are also named PyMC3 variables"""
return set([j for j in ancestors([func], blockers=blockers) if j in self.var_list and j != var])
def _get_ancestors(self, var, func) -> MutableSet[RV]:
"""Get all ancestors of a function, doing some accounting for deterministics.
"""

def _get_ancestors(self, var, func):
"""Get all ancestors of a function, doing some accounting for deterministics
# this contains all of the variables in the model EXCEPT var...
vars: MutableSet[RV] = set(self.var_list)
vars.remove(var)

blockers: MutableSet[RV] = set()
retval = set()
def _expand(node) -> Optional[Iterator[Tensor]]:
if node in blockers:
return None
elif node in vars:
blockers.add(node)
retval.add(node)
return None
elif node.owner:
blockers.add(node)
return reversed(node.owner.inputs)
else:
return None

Specifically, if a deterministic is an input, theano.gof.graph.ancestors will
return only the inputs *to the deterministic*. However, if we pass in the
deterministic as a blocker, it will skip those nodes.
"""
deterministics = self.get_deterministics(var)
upstream = self._ancestors(var, func)

# Usual case
if upstream == self._ancestors(var, func, blockers=upstream):
return upstream
else: # deterministic accounting
for d in powerset(upstream):
blocked = self._ancestors(var, func, blockers=d)
if set(d) == blocked:
return d
raise RuntimeError('Could not traverse graph. Consider raising an issue with developers.')
stack_search(start = deque([func]),
expand=_expand,
mode='bfs')
return retval

def _filter_parents(self, var, parents):
"""Get direct parents of a var, as strings"""
Expand Down