Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flag to disable bounds check for speed-up #4377

Merged
merged 6 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from pymc3.distributions.shape_utils import to_tuple
from pymc3.distributions.special import gammaln
from pymc3.model import modelcontext
from pymc3.theanof import floatX

f = floatX
Expand Down Expand Up @@ -67,6 +68,15 @@ def bound(logp, *conditions, **kwargs):
-------
logp with elements set to -inf where any condition is False
"""

# If called inside a model context, see if bounds check is disabled
try:
model = modelcontext(kwargs.get("model"))
if model.disable_bounds_check:
return logp
except TypeError:
pass

broadcast_conditions = kwargs.get("broadcast_conditions", True)

if broadcast_conditions:
Expand Down
10 changes: 9 additions & 1 deletion pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,11 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta):
temporarily in the model context. See the documentation
of theano for a complete list. Set config key
``compute_test_value`` to `raise` if it is None.
disable_bounds_check: bool
Disable checks that ensure that input parameters to distributions
are in a valid range. If your model is built in a way where you
know your parameters can only take on valid values you can disable
this for increased speed.

Examples
--------
Expand Down Expand Up @@ -895,11 +900,14 @@ def __new__(cls, *args, **kwargs):
instance._theano_config = theano_config
return instance

def __init__(self, name="", model=None, theano_config=None, coords=None):
def __init__(
self, name="", model=None, theano_config=None, coords=None, disable_bounds_check=False
):
self.name = name
self.coords = {}
self.RV_dims = {}
self.add_coords(coords)
self.disable_bounds_check = disable_bounds_check

if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
Expand Down