Skip to content

Commit

Permalink
Add fallback to prior when moment is not implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 13, 2021
1 parent 216dcd5 commit d84c718
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
15 changes: 14 additions & 1 deletion pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings

from typing import Callable, Dict, List, Optional, Sequence, Set, Union

Expand Down Expand Up @@ -269,7 +270,19 @@ def make_initial_point_expression(

if isinstance(strategy, str):
if strategy == "moment":
value = get_moment(variable)
try:
value = get_moment(variable)
except NotImplementedError:
warnings.warn(
f"Moment not defined for variable {variable} of type "
f"{variable.owner.op.__class__.__name__}, defaulting to "
f"a draw from the prior. This can lead to difficulties "
f"during tuning. You can manually define an initval or "
f"implement a get_moment dispatched function for this "
f"distribution.",
UserWarning,
)
value = variable
elif strategy == "prior":
value = variable
else:
Expand Down
26 changes: 26 additions & 0 deletions pymc/tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import numpy as np
import pytest

from aesara.tensor.random.op import RandomVariable

import pymc as pm

from pymc.distributions.distribution import get_moment
Expand Down Expand Up @@ -255,6 +257,30 @@ def test_moment_from_dims(self, rv_cls):
assert tuple(get_moment(rv).shape.eval()) == (4, 3)
pass

def test_moment_not_implemented_fallback(self):
class MyNormalRV(RandomVariable):
name = "my_normal"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"

@classmethod
def rng_fn(cls, rng, mu, sigma, size):
return np.pi

class MyNormalDistribution(pm.Normal):
rv_op = MyNormalRV()

with pm.Model() as m:
x = MyNormalDistribution("x", 0, 1, initval="moment")

with pytest.warns(
UserWarning, match="Moment not defined for variable x of type MyNormalRV"
):
res = m.recompute_initial_point()

assert np.isclose(res["x"], np.pi)


def test_pickling_issue_5090():
with pm.Model() as model:
Expand Down

0 comments on commit d84c718

Please sign in to comment.