Refactor EulerMaruyama to work in v4 #6227

Merged
merged 4 commits on Oct 21, 2022
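For orientation, here is a minimal usage sketch of the API this PR introduces (not part of the diff; the drift/diffusion function, prior, and parameter values are illustrative assumptions):

import pymc as pm
from pymc.distributions.timeseries import EulerMaruyama

# Drift/diffusion pair for dX = lam * X dt + 0.1 dW (illustrative)
def sde_fn(x, lam):
    return lam * x, 0.1

with pm.Model():
    lam = pm.Normal("lam", 0, 1)
    # steps=N yields a series of length N + 1 (initial value plus N transitions)
    y = EulerMaruyama(
        "y",
        dt=0.1,
        sde_fn=sde_fn,
        sde_pars=(lam,),
        init_dist=pm.Normal.dist(0, 10),
        steps=99,
    )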
192 changes: 160 additions & 32 deletions pymc/distributions/timeseries.py
@@ -15,7 +15,7 @@
import warnings

from abc import ABCMeta
from typing import Optional
from typing import Callable, Optional

import aesara
import aesara.tensor as at
@@ -27,7 +27,6 @@
from aesara.tensor.random.op import RandomVariable

from pymc.aesaraf import constant_fold, floatX, intX
from pymc.distributions import distribution
from pymc.distributions.continuous import Normal, get_tau_sigma
from pymc.distributions.distribution import (
Distribution,
@@ -461,7 +460,7 @@ class AR(Distribution):
process.
init_dist : unnamed distribution, optional
Scalar or vector distribution for initial values. Unnamed refers to distributions
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).

.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
@@ -881,7 +880,26 @@ def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps
return at.zeros_like(rv)


class EulerMaruyama(distribution.Continuous):
class EulerMaruyamaRV(SymbolicRandomVariable):
"""A placeholder used to specify a log-likelihood for a EulerMaruyama sub-graph."""

default_output = 1
dt: float
sde_fn: Callable
_print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")

def __init__(self, *args, dt, sde_fn, **kwargs):
self.dt = dt
self.sde_fn = sde_fn
super().__init__(*args, **kwargs)

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since noise is a shared variable it shows up as the last node input
return {node.inputs[-1]: node.outputs[0]}


class EulerMaruyama(Distribution):
r"""
Stochastic differential equation discretized with the Euler-Maruyama method.

@@ -893,39 +911,149 @@ class EulerMaruyama(distribution.Continuous):
function returning the drift and diffusion coefficients of SDE
sde_pars: tuple
parameters of the SDE, passed as ``*args`` to ``sde_fn``
init_dist : unnamed distribution, optional
Scalar distribution for initial values. Unnamed refers to distributions created with
the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).

.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
"""

def __new__(cls, *args, **kwargs):
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
rv_type = EulerMaruyamaRV

def __new__(cls, name, dt, sde_fn, *args, steps=None, **kwargs):
dt = at.as_tensor_variable(floatX(dt))
steps = get_support_shape_1d(
support_shape=steps,
shape=None, # Shape will be checked in `cls.dist`
dims=kwargs.get("dims", None),
observed=kwargs.get("observed", None),
support_shape_offset=1,
)
return super().__new__(cls, name, dt, sde_fn, *args, steps=steps, **kwargs)

@classmethod
def dist(cls, *args, **kwargs):
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
steps = get_support_shape_1d(
support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=1
)
if steps is None:
raise ValueError("Must specify steps or shape parameter")
steps = at.as_tensor_variable(intX(steps), ndim=0)

def __init__(self, dt, sde_fn, sde_pars, *args, **kwds):
super().__init__(*args, **kwds)
self.dt = dt = at.as_tensor_variable(dt)
self.sde_fn = sde_fn
self.sde_pars = sde_pars
dt = at.as_tensor_variable(floatX(dt))
sde_pars = [at.as_tensor_variable(x) for x in sde_pars]

def logp(self, x):
"""
Calculate log-probability of EulerMaruyama distribution at specified value.
if init_dist is not None:
if not isinstance(init_dist, TensorVariable) or not isinstance(
init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
):
raise ValueError(
f"Init dist must be a distribution created via the `.dist()` API, "
f"got {type(init_dist)}"
)
check_dist_not_registered(init_dist)
if init_dist.owner.op.ndim_supp > 0:
raise ValueError(
"Init distribution must have a scalar support dimension, "
f"got ndim_supp={init_dist.owner.op.ndim_supp}."
)
else:
warnings.warn(
"Initial distribution not specified, defaulting to "
"`Normal.dist(0, 100, shape=...)`. You can specify an init_dist "
"manually to suppress this warning.",
UserWarning,
)
init_dist = Normal.dist(0, 100, shape=sde_pars[0].shape)
# Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
init_dist = ignore_logprob(init_dist)

Parameters
----------
x: numeric
Value for which log-probability is calculated.
return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs)

Returns
-------
TensorVariable
"""
xt = x[:-1]
f, g = self.sde_fn(x[:-1], *self.sde_pars)
mu = xt + self.dt * f
sigma = at.sqrt(self.dt) * g
return at.sum(Normal.dist(mu=mu, sigma=sigma).logp(x[1:]))

def _distr_parameters_for_repr(self):
return ["dt"]
@classmethod
def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
# Init dist should have shape (*size,)
if size is not None:
batch_size = size
else:
batch_size = at.broadcast_shape(*sde_pars, init_dist)
init_dist = change_dist_size(init_dist, batch_size)

# Create OpFromGraph representing random draws from SDE process
# Variables with underscore suffix are dummy inputs into the OpFromGraph
init_ = init_dist.type()
sde_pars_ = [x.type() for x in sde_pars]
steps_ = steps.type()

noise_rng = aesara.shared(np.random.default_rng())

def step(*prev_args):
prev_y, *prev_sde_pars, rng = prev_args
f, g = sde_fn(prev_y, *prev_sde_pars)
mu = prev_y + dt * f
sigma = at.sqrt(dt) * g
next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs
return next_y, {rng: next_rng}

y_t, innov_updates_ = aesara.scan(
fn=step,
outputs_info=[init_],
non_sequences=sde_pars_ + [noise_rng],
n_steps=steps_,
strict=True,
)
(noise_next_rng,) = tuple(innov_updates_.values())

sde_out_ = at.concatenate([init_[None, ...], y_t], axis=0).dimshuffle(
tuple(range(1, y_t.ndim)) + (0,)
)

eulermaruyama_op = EulerMaruyamaRV(
inputs=[init_, steps_] + sde_pars_,
outputs=[noise_next_rng, sde_out_],
dt=dt,
sde_fn=sde_fn,
ndim_supp=1,
)

eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
return eulermaruyama


@_change_dist_size.register(EulerMaruyamaRV)
def change_eulermaruyama_size(op, dist, new_size, expand=False):

if expand:
old_size = dist.shape[:-1]
new_size = tuple(new_size) + tuple(old_size)

init_dist, steps, *sde_pars, _ = dist.owner.inputs
return EulerMaruyama.rv_op(
init_dist,
steps,
sde_pars,
dt=op.dt,
sde_fn=op.sde_fn,
size=new_size,
)


@_logprob.register(EulerMaruyamaRV)
def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwargs):
(x,) = values
# noise arg is unused, but is needed to make the logp signature match the rv_op signature
*sde_pars, _ = sde_pars_noise_arg
# sde_fn is user provided and likely not broadcastable to an additional time dimension;
# since the input x is now [..., t], we broadcast each parameter to [..., None]
# below as a best-effort attempt to make it work
sde_pars_broadcast = [p[..., None] for p in sde_pars]
xtm1 = x[..., :-1]
xt = x[..., 1:]
f, g = op.sde_fn(xtm1, *sde_pars_broadcast)
mu = xtm1 + op.dt * f
sigma = at.sqrt(op.dt) * g
# Compute and collapse logp across time dimension
sde_logp = at.sum(logp(Normal.dist(mu, sigma), xt), axis=-1)
init_logp = logp(init_dist, x[..., 0])
return init_logp + sde_logp
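For reference, the transition density that both the scan step in rv_op and this logp encode is x[t+1] ~ Normal(mu = x[t] + f(x[t]) * dt, sigma = sqrt(dt) * g(x[t])). A standalone NumPy/SciPy sketch of the same log-density computation (illustrative only, not part of the diff; the example path and parameter values are made up):

import numpy as np
from scipy import stats

def em_logp(x, dt, sde_fn, *pars):
    # One-step-ahead Gaussian transitions implied by the Euler-Maruyama scheme
    f, g = sde_fn(x[:-1], *pars)
    mu = x[:-1] + dt * f
    sigma = np.sqrt(dt) * g
    return stats.norm.logpdf(x[1:], loc=mu, scale=sigma).sum()

# Linear drift, constant diffusion (illustrative values)
path = 5.0 + np.cumsum(0.1 * np.random.randn(100))
print(em_logp(path, 0.1, lambda x, lam: (lam * x, 0.5), -0.5))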
151 changes: 117 additions & 34 deletions pymc/tests/distributions/test_timeseries.py
@@ -830,37 +830,120 @@ def test_change_dist_size(self):
assert new_dist.eval().shape == (4, 3, 10)


def _gen_sde_path(sde, pars, dt, n, x0):
xs = [x0]
wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
for i in range(n):
f, g = sde(xs[-1], *pars)
xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
return np.array(xs)


@pytest.mark.xfail(reason="Euleryama not refactored", raises=NotImplementedError)
def test_linear():
lam = -0.78
sig2 = 5e-3
N = 300
dt = 1e-1
sde = lambda x, lam: (lam * x, sig2)
x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0))
z = x + np.random.randn(x.size) * sig2
# build model
with Model() as model:
lamh = Flat("lamh")
xh = EulerMaruyama("xh", dt, sde, (lamh,), shape=N + 1, initval=x)
Normal("zh", mu=xh, sigma=sig2, observed=z)
# invert
with model:
trace = sample(init="advi+adapt_diag", chains=1)

ppc = sample_posterior_predictive(trace, model=model)

p95 = [2.5, 97.5]
lo, hi = np.percentile(trace[lamh], p95, axis=0)
assert (lo < lam) and (lam < hi)
lo, hi = np.percentile(ppc["zh"], p95, axis=0)
assert ((lo < z) * (z < hi)).mean() > 0.95
class TestEulerMaruyama:
@pytest.mark.parametrize("batched_param", [1, 2])
@pytest.mark.parametrize("explicit_shape", (True, False))
def test_batched_size(self, explicit_shape, batched_param):
steps, batch_size = 100, 5
param_val = np.square(np.random.randn(batch_size))
if explicit_shape:
kwargs = {"shape": (batch_size, steps)}
else:
kwargs = {"steps": steps - 1}

def sde_fn(x, k, d, s):
return (k - d * x, s)

sde_pars = [1.0, 2.0, 0.1]
sde_pars[batched_param] = sde_pars[batched_param] * param_val
with Model() as t0:
init_dist = pm.Normal.dist(0, 10, shape=(batch_size,))
y = EulerMaruyama(
"y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs
)

y_eval = draw(y, draws=2)
assert y_eval[0].shape == (batch_size, steps)
assert not np.any(np.isclose(y_eval[0], y_eval[1]))

if explicit_shape:
kwargs["shape"] = steps
with Model() as t1:
for i in range(batch_size):
sde_pars_slice = sde_pars.copy()
sde_pars_slice[batched_param] = sde_pars[batched_param][i]
init_dist = pm.Normal.dist(0, 10)
EulerMaruyama(
f"y_{i}",
dt=0.02,
sde_fn=sde_fn,
sde_pars=sde_pars_slice,
init_dist=init_dist,
**kwargs,
)

t0_init = t0.initial_point()
t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)}
np.testing.assert_allclose(
t0.compile_logp()(t0_init),
t1.compile_logp()(t1_init),
)

def test_change_dist_size1(self):
def sde1(x, k, d, s):
return (k - d * x, s)

base_dist = EulerMaruyama.dist(
dt=0.01,
sde_fn=sde1,
sde_pars=(1, 2, 0.1),
init_dist=pm.Normal.dist(0, 10),
shape=(5, 10),
)

new_dist = change_dist_size(base_dist, (4,))
assert new_dist.eval().shape == (4, 10)

new_dist = change_dist_size(base_dist, (4,), expand=True)
assert new_dist.eval().shape == (4, 5, 10)

def test_change_dist_size2(self):
def sde2(p, s):
N = 500.0
return s * p * (1 - p) / (1 + s * p), pm.math.sqrt(p * (1 - p) / N)

base_dist = EulerMaruyama.dist(
dt=0.01, sde_fn=sde2, sde_pars=(0.1,), init_dist=pm.Normal.dist(0, 10), shape=(3, 10)
)

new_dist = change_dist_size(base_dist, (4,))
assert new_dist.eval().shape == (4, 10)

new_dist = change_dist_size(base_dist, (4,), expand=True)
assert new_dist.eval().shape == (4, 3, 10)

def test_linear_model(self):
lam = -0.78
sig2 = 5e-3
N = 300
dt = 1e-1

def _gen_sde_path(sde, pars, dt, n, x0):
xs = [x0]
wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
for i in range(n):
f, g = sde(xs[-1], *pars)
xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
return np.array(xs)

sde = lambda x, lam: (lam * x, sig2)
x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0))
z = x + np.random.randn(x.size) * sig2
# build model
with Model() as model:
lamh = Flat("lamh")
xh = EulerMaruyama(
"xh", dt, sde, (lamh,), steps=N, initval=x, init_dist=pm.Normal.dist(0, 10)
)
Normal("zh", mu=xh, sigma=sig2, observed=z)
# invert
with model:
trace = sample(chains=1)

ppc = sample_posterior_predictive(trace, model=model)

p95 = [2.5, 97.5]
lo, hi = np.percentile(trace.posterior["lamh"], p95, axis=[0, 1])
assert (lo < lam) and (lam < hi)
lo, hi = np.percentile(ppc.posterior_predictive["zh"], p95, axis=[0, 1])
assert ((lo < z) * (z < hi)).mean() > 0.95
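As a possible follow-up check (a sketch under assumed values, not part of this PR's test suite), prior draws from the refactored distribution can be validated for shape and qualitative behaviour:

import pymc as pm
from pymc.distributions.timeseries import EulerMaruyama

def sde_fn(x, lam):
    return lam * x, 5e-3

dist = EulerMaruyama.dist(
    dt=0.1,
    sde_fn=sde_fn,
    sde_pars=(-0.78,),
    init_dist=pm.Normal.dist(5.0, 0.1),
    steps=300,
)
draws = pm.draw(dist, draws=500)
assert draws.shape == (500, 301)  # steps + 1 points per path
# With negative drift the paths should decay toward zero on average
assert abs(draws[:, -1].mean()) < abs(draws[:, 0].mean())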