Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flag to disable bounds check for speed-up #4377

Merged
merged 6 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
- Removed `theanof.set_theano_config` because it illegally changed Theano's internal state (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).

### New Features
- Option to `disable_bounds_check=True` when instantiating `pymc3.Model()` for faster sampling for models that cannot violate boundary constraints (which are most of them; see [#4377](https://github.com/pymc-devs/pymc3/pull/4377)).
- `OrderedProbit` distribution added (see [#4232](https://github.com/pymc-devs/pymc3/pull/4232)).
- `plot_posterior_predictive_glm` now works with `arviz.InferenceData` as well (see [#4234](https://github.com/pymc-devs/pymc3/pull/4234))

Expand Down
10 changes: 10 additions & 0 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from pymc3.distributions.shape_utils import to_tuple
from pymc3.distributions.special import gammaln
from pymc3.model import modelcontext
from pymc3.theanof import floatX

f = floatX
Expand Down Expand Up @@ -67,6 +68,15 @@ def bound(logp, *conditions, **kwargs):
-------
logp with elements set to -inf where any condition is False
"""

# If called inside a model context, see if bounds check is disabled
try:
model = modelcontext(kwargs.get("model"))
if model.disable_bounds_check:
return logp
except TypeError: # No model found
pass

broadcast_conditions = kwargs.get("broadcast_conditions", True)

if broadcast_conditions:
Expand Down
10 changes: 9 additions & 1 deletion pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,11 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta):
temporarily in the model context. See the documentation
of theano for a complete list. Set config key
``compute_test_value`` to `raise` if it is None.
disable_bounds_check: bool
Disable checks that ensure that input parameters to distributions
are in a valid range. If your model is built in a way where you
know your parameters can only take on valid values you can disable
this for increased speed.

Examples
--------
Expand Down Expand Up @@ -895,11 +900,14 @@ def __new__(cls, *args, **kwargs):
instance._theano_config = theano_config
return instance

def __init__(self, name="", model=None, theano_config=None, coords=None):
def __init__(
self, name="", model=None, theano_config=None, coords=None, disable_bounds_check=False
):
self.name = name
self.coords = {}
self.RV_dims = {}
self.add_coords(coords)
self.disable_bounds_check = disable_bounds_check

if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
Expand Down
7 changes: 7 additions & 0 deletions pymc3/tests/test_dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def test_bound():
assert np.prod(bound(logp, cond).eval()) == -np.inf


def test_bound_disabled():
with pm.Model(disable_bounds_check=True):
logp = tt.ones(3)
cond = np.array([1, 0, 1])
assert np.all(bound(logp, cond).eval() == logp.eval())


def test_alltrue_scalar():
assert alltrue_scalar([]).eval()
assert alltrue_scalar([True]).eval()
Expand Down