diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py
index 2ef30ff36ed..57028d90d57 100644
--- a/pymc/distributions/timeseries.py
+++ b/pymc/distributions/timeseries.py
@@ -15,7 +15,7 @@
 import warnings
 
 from abc import ABCMeta
-from typing import Optional
+from typing import Callable, Optional
 
 import aesara
 import aesara.tensor as at
@@ -27,7 +27,6 @@
 from aesara.tensor.random.op import RandomVariable
 
 from pymc.aesaraf import constant_fold, floatX, intX
-from pymc.distributions import distribution
 from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.distribution import (
     Distribution,
@@ -461,7 +460,7 @@ class AR(Distribution):
         process.
     init_dist : unnamed distribution, optional
         Scalar or vector distribution for initial values. Unnamed refers to distributions
-        created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
+        created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order). If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
 
         .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
 
@@ -881,7 +880,26 @@ def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps
     return at.zeros_like(rv)
 
 
-class EulerMaruyama(distribution.Continuous):
+class EulerMaruyamaRV(SymbolicRandomVariable):
+    """A placeholder used to specify a log-likelihood for an EulerMaruyama sub-graph."""
+
+    default_output = 1
+    dt: float
+    sde_fn: Callable
+    _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")
+
+    def __init__(self, *args, dt, sde_fn, **kwargs):
+        self.dt = dt
+        self.sde_fn = sde_fn
+        super().__init__(*args, **kwargs)
+
+    def update(self, node: Node):
+        """Return the update mapping for the noise RV."""
+        # Since the noise is a shared variable, it shows up as the last node input
+        return {node.inputs[-1]: node.outputs[0]}
+
+
+class EulerMaruyama(Distribution):
     r"""
     Stochastic differential equation discretized with the Euler-Maruyama method.
 
@@ -893,39 +911,149 @@ class EulerMaruyama(distribution.Continuous):
         function returning the drift and diffusion coefficients of SDE
     sde_pars: tuple
         parameters of the SDE, passed as ``*args`` to ``sde_fn``
+    init_dist : unnamed distribution, optional
+        Scalar distribution for initial values. Unnamed refers to distributions created with
+        the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
+        If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
+
+        .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
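+
+    Examples
+    --------
+    A minimal sketch, with purely illustrative parameter values:
+
+    .. code-block:: python
+
+        # Any callable returning (drift, diffusion) can be used as sde_fn
+        def sde_fn(x, lam):
+            return lam * x, 0.1
+
+        with pm.Model():
+            lam = pm.Normal("lam", 0, 1)
+            y = pm.EulerMaruyama(
+                "y",
+                dt=0.01,
+                sde_fn=sde_fn,
+                sde_pars=(lam,),
+                init_dist=pm.Normal.dist(0, 10),
+                steps=100,
+            )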
""" - def __new__(cls, *args, **kwargs): - raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.") + rv_type = EulerMaruyamaRV + + def __new__(cls, name, dt, sde_fn, *args, steps=None, **kwargs): + dt = at.as_tensor_variable(floatX(dt)) + steps = get_support_shape_1d( + support_shape=steps, + shape=None, # Shape will be checked in `cls.dist` + dims=kwargs.get("dims", None), + observed=kwargs.get("observed", None), + support_shape_offset=1, + ) + return super().__new__(cls, name, dt, sde_fn, *args, steps=steps, **kwargs) @classmethod - def dist(cls, *args, **kwargs): - raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.") + def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs): + steps = get_support_shape_1d( + support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=1 + ) + if steps is None: + raise ValueError("Must specify steps or shape parameter") + steps = at.as_tensor_variable(intX(steps), ndim=0) - def __init__(self, dt, sde_fn, sde_pars, *args, **kwds): - super().__init__(*args, **kwds) - self.dt = dt = at.as_tensor_variable(dt) - self.sde_fn = sde_fn - self.sde_pars = sde_pars + dt = at.as_tensor_variable(floatX(dt)) + sde_pars = [at.as_tensor_variable(x) for x in sde_pars] - def logp(self, x): - """ - Calculate log-probability of EulerMaruyama distribution at specified value. + if init_dist is not None: + if not isinstance(init_dist, TensorVariable) or not isinstance( + init_dist.owner.op, (RandomVariable, SymbolicRandomVariable) + ): + raise ValueError( + f"Init dist must be a distribution created via the `.dist()` API, " + f"got {type(init_dist)}" + ) + check_dist_not_registered(init_dist) + if init_dist.owner.op.ndim_supp > 0: + raise ValueError( + "Init distribution must have a scalar support dimension, ", + f"got ndim_supp={init_dist.owner.op.ndim_supp}.", + ) + else: + warnings.warn( + "Initial distribution not specified, defaulting to " + "`Normal.dist(0, 100, shape=...)`. You can specify an init_dist " + "manually to suppress this warning.", + UserWarning, + ) + init_dist = Normal.dist(0, 100, shape=sde_pars[0].shape) + # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term + init_dist = ignore_logprob(init_dist) - Parameters - ---------- - x: numeric - Value for which log-probability is calculated. 
+        return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs)
 
-        Returns
-        -------
-        TensorVariable
-        """
-        xt = x[:-1]
-        f, g = self.sde_fn(x[:-1], *self.sde_pars)
-        mu = xt + self.dt * f
-        sigma = at.sqrt(self.dt) * g
-        return at.sum(Normal.dist(mu=mu, sigma=sigma).logp(x[1:]))
-
-    def _distr_parameters_for_repr(self):
-        return ["dt"]
+    @classmethod
+    def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
+        # Init dist should have shape (*size,)
+        if size is not None:
+            batch_size = size
+        else:
+            batch_size = at.broadcast_shape(*sde_pars, init_dist)
+        init_dist = change_dist_size(init_dist, batch_size)
+
+        # Create OpFromGraph representing random draws from the SDE process
+        # Variables with underscore suffix are dummy inputs into the OpFromGraph
+        init_ = init_dist.type()
+        sde_pars_ = [x.type() for x in sde_pars]
+        steps_ = steps.type()
+
+        noise_rng = aesara.shared(np.random.default_rng())
+
+        def step(*prev_args):
+            prev_y, *prev_sde_pars, rng = prev_args
+            f, g = sde_fn(prev_y, *prev_sde_pars)
+            mu = prev_y + dt * f
+            sigma = at.sqrt(dt) * g
+            next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs
+            return next_y, {rng: next_rng}
+
+        y_t, innov_updates_ = aesara.scan(
+            fn=step,
+            outputs_info=[init_],
+            non_sequences=sde_pars_ + [noise_rng],
+            n_steps=steps_,
+            strict=True,
+        )
+        (noise_next_rng,) = tuple(innov_updates_.values())
+
+        sde_out_ = at.concatenate([init_[None, ...], y_t], axis=0).dimshuffle(
+            tuple(range(1, y_t.ndim)) + (0,)
+        )
+
+        eulermaruyama_op = EulerMaruyamaRV(
+            inputs=[init_, steps_] + sde_pars_,
+            outputs=[noise_next_rng, sde_out_],
+            dt=dt,
+            sde_fn=sde_fn,
+            ndim_supp=1,
+        )
+
+        eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
+        return eulermaruyama
+
+
+@_change_dist_size.register(EulerMaruyamaRV)
+def change_eulermaruyama_size(op, dist, new_size, expand=False):
+    if expand:
+        old_size = dist.shape[:-1]
+        new_size = tuple(new_size) + tuple(old_size)
+
+    init_dist, steps, *sde_pars, _ = dist.owner.inputs
+    return EulerMaruyama.rv_op(
+        init_dist,
+        steps,
+        sde_pars,
+        dt=op.dt,
+        sde_fn=op.sde_fn,
+        size=new_size,
+    )
+
+
+@_logprob.register(EulerMaruyamaRV)
+def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwargs):
+    (x,) = values
+    # The noise arg is unused, but is needed so the logp signature matches the rv_op signature
+    *sde_pars, _ = sde_pars_noise_arg
+    # sde_fn is user-provided and likely not broadcastable to an additional time dimension;
+    # since the input x is now [..., t], we broadcast each parameter to [..., None] below
+    # as a best-effort attempt to make it work
+    sde_pars_broadcast = [p[..., None] for p in sde_pars]
+    xtm1 = x[..., :-1]
+    xt = x[..., 1:]
+    f, g = op.sde_fn(xtm1, *sde_pars_broadcast)
+    mu = xtm1 + op.dt * f
+    sigma = at.sqrt(op.dt) * g
+    # Compute and collapse logp across the time dimension
+    sde_logp = at.sum(logp(Normal.dist(mu, sigma), xt), axis=-1)
+    init_logp = logp(init_dist, x[..., 0])
+    return init_logp + sde_logp
diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py
index 2cc8e2e091f..e4dd45900c9 100644
--- a/pymc/tests/distributions/test_timeseries.py
+++ b/pymc/tests/distributions/test_timeseries.py
@@ -830,37 +830,120 @@ def test_change_dist_size(self):
         assert new_dist.eval().shape == (4, 3, 10)
 
 
-def _gen_sde_path(sde, pars, dt, n, x0):
-    xs = [x0]
-    wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
-    for i in range(n):
-        f, g = sde(xs[-1], *pars)
-        xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
-    return np.array(xs)
-
-
-@pytest.mark.xfail(reason="Euleryama not refactored", raises=NotImplementedError)
-def test_linear():
-    lam = -0.78
-    sig2 = 5e-3
-    N = 300
-    dt = 1e-1
-    sde = lambda x, lam: (lam * x, sig2)
-    x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0))
-    z = x + np.random.randn(x.size) * sig2
-    # build model
-    with Model() as model:
-        lamh = Flat("lamh")
-        xh = EulerMaruyama("xh", dt, sde, (lamh,), shape=N + 1, initval=x)
-        Normal("zh", mu=xh, sigma=sig2, observed=z)
-    # invert
-    with model:
-        trace = sample(init="advi+adapt_diag", chains=1)
-
-    ppc = sample_posterior_predictive(trace, model=model)
-
-    p95 = [2.5, 97.5]
-    lo, hi = np.percentile(trace[lamh], p95, axis=0)
-    assert (lo < lam) and (lam < hi)
-    lo, hi = np.percentile(ppc["zh"], p95, axis=0)
-    assert ((lo < z) * (z < hi)).mean() > 0.95
+class TestEulerMaruyama:
+    @pytest.mark.parametrize("batched_param", [1, 2])
+    @pytest.mark.parametrize("explicit_shape", (True, False))
+    def test_batched_size(self, explicit_shape, batched_param):
+        steps, batch_size = 100, 5
+        param_val = np.square(np.random.randn(batch_size))
+        if explicit_shape:
+            kwargs = {"shape": (batch_size, steps)}
+        else:
+            kwargs = {"steps": steps - 1}
+
+        def sde_fn(x, k, d, s):
+            return (k - d * x, s)
+
+        sde_pars = [1.0, 2.0, 0.1]
+        sde_pars[batched_param] = sde_pars[batched_param] * param_val
+        with Model() as t0:
+            init_dist = pm.Normal.dist(0, 10, shape=(batch_size,))
+            y = EulerMaruyama(
+                "y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs
+            )
+
+        y_eval = draw(y, draws=2)
+        assert y_eval[0].shape == (batch_size, steps)
+        assert not np.any(np.isclose(y_eval[0], y_eval[1]))
+
+        if explicit_shape:
+            kwargs["shape"] = steps
+        with Model() as t1:
+            for i in range(batch_size):
+                sde_pars_slice = sde_pars.copy()
+                sde_pars_slice[batched_param] = sde_pars[batched_param][i]
+                init_dist = pm.Normal.dist(0, 10)
+                EulerMaruyama(
+                    f"y_{i}",
+                    dt=0.02,
+                    sde_fn=sde_fn,
+                    sde_pars=sde_pars_slice,
+                    init_dist=init_dist,
+                    **kwargs,
+                )
+
+        t0_init = t0.initial_point()
+        t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)}
+        np.testing.assert_allclose(
+            t0.compile_logp()(t0_init),
+            t1.compile_logp()(t1_init),
+        )
+
+    def test_change_dist_size1(self):
+        def sde1(x, k, d, s):
+            return (k - d * x, s)
+
+        base_dist = EulerMaruyama.dist(
+            dt=0.01,
+            sde_fn=sde1,
+            sde_pars=(1, 2, 0.1),
+            init_dist=pm.Normal.dist(0, 10),
+            shape=(5, 10),
+        )
+
+        new_dist = change_dist_size(base_dist, (4,))
+        assert new_dist.eval().shape == (4, 10)
+
+        new_dist = change_dist_size(base_dist, (4,), expand=True)
+        assert new_dist.eval().shape == (4, 5, 10)
+
+    def test_change_dist_size2(self):
+        def sde2(p, s):
+            N = 500.0
+            return s * p * (1 - p) / (1 + s * p), pm.math.sqrt(p * (1 - p) / N)
+
+        base_dist = EulerMaruyama.dist(
+            dt=0.01, sde_fn=sde2, sde_pars=(0.1,), init_dist=pm.Normal.dist(0, 10), shape=(3, 10)
+        )
+
+        new_dist = change_dist_size(base_dist, (4,))
+        assert new_dist.eval().shape == (4, 10)
+
+        new_dist = change_dist_size(base_dist, (4,), expand=True)
+        assert new_dist.eval().shape == (4, 3, 10)
+
+    def test_linear_model(self):
+        lam = -0.78
+        sig2 = 5e-3
+        N = 300
+        dt = 1e-1
+
+        def _gen_sde_path(sde, pars, dt, n, x0):
+            xs = [x0]
+            wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
+            for i in range(n):
+                f, g = sde(xs[-1], *pars)
+                xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
+            return np.array(xs)
+
+        sde = lambda x, lam: (lam * x, sig2)
+        x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0))
+        z = x + np.random.randn(x.size) * sig2
+        # build model
+        with Model() as model:
+            lamh = Flat("lamh")
+            xh = EulerMaruyama(
+                "xh", dt, sde, (lamh,), steps=N, initval=x, init_dist=pm.Normal.dist(0, 10)
+            )
+            Normal("zh", mu=xh, sigma=sig2, observed=z)
+        # invert
+        with model:
+            trace = sample(chains=1)
+
+        ppc = sample_posterior_predictive(trace, model=model)
+
+        p95 = [2.5, 97.5]
+        lo, hi = np.percentile(trace.posterior["lamh"], p95, axis=[0, 1])
+        assert (lo < lam) and (lam < hi)
+        lo, hi = np.percentile(ppc.posterior_predictive["zh"], p95, axis=[0, 1])
+        assert ((lo < z) * (z < hi)).mean() > 0.95
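
Reviewer note: for reference, the recursion implemented by the `scan` in `rv_op` (and mirrored by the tests' `_gen_sde_path` helper) is plain Euler-Maruyama. The sketch below is a self-contained NumPy illustration, not part of the diff; `simulate_em` is an illustrative name.

```python
import numpy as np


def simulate_em(sde_fn, sde_pars, dt, steps, x0, rng=None):
    """Draw one path of x[t+1] = x[t] + f(x[t])*dt + sqrt(dt)*g(x[t])*w[t]."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.empty(steps + 1)
    x[0] = x0
    for t in range(steps):
        f, g = sde_fn(x[t], *sde_pars)
        x[t + 1] = x[t] + dt * f + np.sqrt(dt) * g * rng.normal()
    return x


# Mean-reverting example matching test_linear_model (lam < 0 pulls x toward 0)
path = simulate_em(lambda x, lam: (lam * x, 5e-3), (-0.78,), dt=1e-1, steps=300, x0=5.0)
assert path.shape == (301,)
```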