Skip to content

Commit

Permalink
Add flag to disable bounds check for speed-up (#4377)
Browse files Browse the repository at this point in the history
* Add flag to disable bounds check.

* Move check for model and flag into single line as per @colcarrolls's suggestion.

* modelcontext raises a TypeError if no model is found. Catch that.

* Add comment.

* Add mention in release-notes.

* Add test for when boundaries are disabled.
  • Loading branch information
twiecki authored Dec 23, 2020
1 parent 0402aab commit 37239fc
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 1 deletion.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
- Removed `theanof.set_theano_config` because it illegally changed Theano's internal state (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).

### New Features
- Option to `disable_bounds_check=True` when instantiating `pymc3.Model()` for faster sampling for models that cannot violate boundary constraints (which are most of them; see [#4377](https://github.com/pymc-devs/pymc3/pull/4377)).
- `OrderedProbit` distribution added (see [#4232](https://github.com/pymc-devs/pymc3/pull/4232)).
- `plot_posterior_predictive_glm` now works with `arviz.InferenceData` as well (see [#4234](https://github.com/pymc-devs/pymc3/pull/4234))

Expand Down
10 changes: 10 additions & 0 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from pymc3.distributions.shape_utils import to_tuple
from pymc3.distributions.special import gammaln
from pymc3.model import modelcontext
from pymc3.theanof import floatX

f = floatX
Expand Down Expand Up @@ -67,6 +68,15 @@ def bound(logp, *conditions, **kwargs):
-------
logp with elements set to -inf where any condition is False
"""

# If called inside a model context, see if bounds check is disabled
try:
model = modelcontext(kwargs.get("model"))
if model.disable_bounds_check:
return logp
except TypeError: # No model found
pass

broadcast_conditions = kwargs.get("broadcast_conditions", True)

if broadcast_conditions:
Expand Down
10 changes: 9 additions & 1 deletion pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,11 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta):
temporarily in the model context. See the documentation
of theano for a complete list. Set config key
``compute_test_value`` to `raise` if it is None.
disable_bounds_check: bool
Disable checks that ensure that input parameters to distributions
are in a valid range. If your model is built in a way where you
know your parameters can only take on valid values you can disable
this for increased speed.
Examples
--------
Expand Down Expand Up @@ -895,11 +900,14 @@ def __new__(cls, *args, **kwargs):
instance._theano_config = theano_config
return instance

def __init__(self, name="", model=None, theano_config=None, coords=None):
def __init__(
self, name="", model=None, theano_config=None, coords=None, disable_bounds_check=False
):
self.name = name
self.coords = {}
self.RV_dims = {}
self.add_coords(coords)
self.disable_bounds_check = disable_bounds_check

if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
Expand Down
7 changes: 7 additions & 0 deletions pymc3/tests/test_dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def test_bound():
assert np.prod(bound(logp, cond).eval()) == -np.inf


def test_bound_disabled():
with pm.Model(disable_bounds_check=True):
logp = tt.ones(3)
cond = np.array([1, 0, 1])
assert np.all(bound(logp, cond).eval() == logp.eval())


def test_alltrue_scalar():
assert alltrue_scalar([]).eval()
assert alltrue_scalar([True]).eval()
Expand Down

0 comments on commit 37239fc

Please sign in to comment.