Skip to content

Commit

Permalink
Update to limit support to univariate time series
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Oct 20, 2022
1 parent f8be99b commit 9b5fc40
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
8 changes: 4 additions & 4 deletions pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,8 +912,8 @@ class EulerMaruyama(Distribution):
sde_pars: tuple
parameters of the SDE, passed as ``*args`` to ``sde_fn``
init_dist : unnamed distribution, optional
Scalar or vector distribution for initial values. Unnamed refers to distributions
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
Scalar distribution for initial values. Unnamed refers to distributions created with
the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
Expand Down Expand Up @@ -953,9 +953,9 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
f"got {type(init_dist)}"
)
check_dist_not_registered(init_dist)
if init_dist.owner.op.ndim_supp > 1:
if init_dist.owner.op.ndim_supp > 0:
raise ValueError(
"Init distribution must have a scalar or vector support dimension, ",
"Init distribution must have a scalar support dimension, ",
f"got ndim_supp={init_dist.owner.op.ndim_supp}.",
)
else:
Expand Down
31 changes: 26 additions & 5 deletions pymc/tests/distributions/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,10 @@ def sde_fn(x, k, d, s):
sde_pars = [1.0, 2.0, 0.1]
sde_pars[batched_param] = sde_pars[batched_param] * param_val
with Model() as t0:
y = EulerMaruyama("y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs)
init_dist = pm.Normal.dist(0, 10, shape=(batch_size,))
y = EulerMaruyama(
"y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs
)

y_eval = draw(y, draws=2)
assert y_eval[0].shape == (batch_size, steps)
Expand All @@ -859,7 +862,15 @@ def sde_fn(x, k, d, s):
for i in range(batch_size):
sde_pars_slice = sde_pars.copy()
sde_pars_slice[batched_param] = sde_pars[batched_param][i]
EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars_slice, **kwargs)
init_dist = pm.Normal.dist(0, 10)
EulerMaruyama(
f"y_{i}",
dt=0.02,
sde_fn=sde_fn,
sde_pars=sde_pars_slice,
init_dist=init_dist,
**kwargs,
)

t0_init = t0.initial_point()
t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)}
Expand All @@ -872,7 +883,13 @@ def test_change_dist_size1(self):
def sde1(x, k, d, s):
return (k - d * x, s)

base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde1, sde_pars=(1, 2, 0.1), shape=(5, 10))
base_dist = EulerMaruyama.dist(
dt=0.01,
sde_fn=sde1,
sde_pars=(1, 2, 0.1),
init_dist=pm.Normal.dist(0, 10),
shape=(5, 10),
)

new_dist = change_dist_size(base_dist, (4,))
assert new_dist.eval().shape == (4, 10)
Expand All @@ -885,7 +902,9 @@ def sde2(p, s):
N = 500.0
return s * p * (1 - p) / (1 + s * p), pm.math.sqrt(p * (1 - p) / N)

base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde2, sde_pars=(0.1,), shape=(3, 10))
base_dist = EulerMaruyama.dist(
dt=0.01, sde_fn=sde2, sde_pars=(0.1,), init_dist=pm.Normal.dist(0, 10), shape=(3, 10)
)

new_dist = change_dist_size(base_dist, (4,))
assert new_dist.eval().shape == (4, 10)
Expand Down Expand Up @@ -913,7 +932,9 @@ def _gen_sde_path(sde, pars, dt, n, x0):
# build model
with Model() as model:
lamh = Flat("lamh")
xh = EulerMaruyama("xh", dt, sde, (lamh,), steps=N, initval=x)
xh = EulerMaruyama(
"xh", dt, sde, (lamh,), steps=N, initval=x, init_dist=pm.Normal.dist(0, 10)
)
Normal("zh", mu=xh, sigma=sig2, observed=z)
# invert
with model:
Expand Down

0 comments on commit 9b5fc40

Please sign in to comment.