Skip to content

Commit

Permalink
Add the transformation between the inverse gamma and the exponential
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Sep 26, 2022
1 parent 5b5ae2f commit 460266f
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 1 deletion.
51 changes: 51 additions & 0 deletions aemcmc/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,54 @@ def location_scale_transform(in_expr, out_expr):
eq(out_expr, noncentered_et),
location_scale_family(distribution_lv),
)


def invgamma_exponential(invgamma_expr, invexponential_expr):
r"""Produce a goal that represents the relation between the inverse gamma distribution
and the inverse of an exponential distribution.
.. math::
\begin{equation*}
\frac{
X \sim \operatorname{Gamma^{-1}}\left(1, c\right)
}{
Y = 1 / X, \quad
Y \sim \operatorname{Exp}\left(c\right)
}
\end{equation*}
TODO: This is a particular case of a more general relation between the inverse gamma
and the gamma distribution (of which the exponential distribution is a special case).
We should implement this more general relation, and the special case separately in the
future.
Parameters
----------
invgmamma_expr
An expression that represents a random variable with an inverse gamma
distribution with a shape parameter equal to 1.
inexponential_expr
An expression that represents the inverse of a random variable with an
exponential distribution.
"""
c_lv = var()
rng_lv, size_lv, dtype_lv = var(), var(), var()

invgamma_et = etuple(
etuplize(at.random.invgamma), rng_lv, size_lv, dtype_lv, at.as_tensor(1.0), c_lv
)

exponential_et = etuple(
etuplize(at.random.exponential),
c_lv,
rng=rng_lv,
size=size_lv,
dtype=dtype_lv,
)
invexponential_et = etuple(at.true_div, at.as_tensor(1.0), exponential_et)

return lall(
eq(invgamma_expr, invgamma_et), eq(invexponential_expr, invexponential_et)
)
34 changes: 33 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from aesara.graph.fg import FunctionGraph
from aesara.graph.kanren import KanrenRelationSub

from aemcmc.transforms import location_scale_transform
from aemcmc.transforms import invgamma_exponential, location_scale_transform


def test_normal_scale_loc_transform_lift():
Expand Down Expand Up @@ -45,3 +45,35 @@ def test_normal_scale_loc_transform_sink():
)[0]

assert isinstance(res.owner.op, type(at.random.normal))


def test_invgamma_to_exp():

srng = at.random.RandomStream(0)
c_at = at.scalar()
X_rv = srng.invgamma(1.0, c_at)

fgraph = FunctionGraph(outputs=[X_rv], clone=False)
res = KanrenRelationSub(invgamma_exponential).transform(
fgraph, fgraph.outputs[0].owner
)[0]

assert isinstance(res.owner.op, type(at.true_div))
assert isinstance(res.owner.inputs[1].owner.op, type(at.random.exponential))


@pytest.mark.xfail(
reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error"
)
def test_invgamma_from_exp():

srng = at.random.RandomStream(0)
c_at = at.scalar()
X_rv = 1.0 / srng.exponential(c_at)

fgraph = FunctionGraph(outputs=[X_rv], clone=False)
res = KanrenRelationSub(lambda x, y: invgamma_exponential(y, x)).transform(
fgraph, fgraph.outputs[0].owner
)[0]

assert isinstance(res.owner.op, type(at.random.inversegamma))

0 comments on commit 460266f

Please sign in to comment.