Refactor pm.Simulator (WIP) #4802

Closed · wants to merge 6 commits
11 changes: 11 additions & 0 deletions docs/source/api/distributions/simulator.rst
@@ -0,0 +1,11 @@
+**********
+Simulator
+**********
+
+.. currentmodule:: pymc3.distributions.simulator
+.. autosummary::
+
+   Simulator
+
+.. automodule:: pymc3.distributions.simulator
+   :members:
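For orientation, here is a minimal usage sketch of the refactored distribution, based on the `__new__` signature added below. The model, `normal_sim`, and the assumption that `fn` receives the RNG and `size` positionally (as other `RandomVariable.rng_fn` implementations do) are all illustrative:

    import numpy as np
    import pymc3 as pm

    data = np.random.normal(loc=0.0, scale=1.0, size=1000)

    def normal_sim(rng, loc, scale, size):
        # Forward simulator: generates synthetic data for given parameters
        return rng.normal(loc, scale, size=size)

    with pm.Model():
        loc = pm.Normal("loc", 0, 1)
        scale = pm.HalfNormal("scale", 1)
        sim = pm.Simulator(
            "sim",
            normal_sim,
            params=[loc, scale],
            distance="gaussian",
            sum_stat="sort",
            epsilon=1.0,
            observed=data,
        )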
20 changes: 15 additions & 5 deletions pymc3/aesaraf.py
@@ -354,11 +354,21 @@ def transform_replacements(var, replacements):
     rv_var, rv_value_var = extract_rv_and_value_vars(var)

     if rv_value_var is None:
-        warnings.warn(
-            f"No value variable found for {rv_var}; "
-            "the random variable will not be replaced."
-        )
-        return []
+        # TODO: Importing at top is creating a circular dependency
+        from pymc3.distributions.simulator import SimulatorRV
+
+        # If orphan RandomVariable is a SimulatorRV, we allow for further
+        # replacements in upstream graph
+        if isinstance(rv_var.owner.op, SimulatorRV):
+            return var.owner.inputs[3:]
+
+        else:
+            warnings.warn(
+                f"No value variable found for {rv_var}; "
+                "the random variable will not be replaced."
+            )
+            return []

     transform = getattr(rv_value_var.tag, "transform", None)

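For context on the `inputs[3:]` slice: in Aesara, the inputs of a `RandomVariable` node are laid out as `(rng, size, dtype, *dist_params)`, so skipping the first three yields the distribution parameters. A quick illustration:

    import aesara.tensor as at

    x = at.random.normal(0, 1)
    # x.owner.inputs is [rng, size, dtype, loc, scale]
    loc, scale = x.owner.inputs[3:]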
46 changes: 4 additions & 42 deletions pymc3/distributions/distribution.py
@@ -22,7 +22,6 @@
 from typing import Optional

 import aesara
-import aesara.tensor as at
 import dill

 from aesara.tensor.random.op import RandomVariable
@@ -386,47 +385,6 @@ def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
     __latex__ = _repr_latex_


-class NoDistribution(Distribution):
-    def __init__(
-        self,
-        shape,
-        dtype,
-        initval=None,
-        defaults=(),
-        parent_dist=None,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(
-            shape=shape, dtype=dtype, initval=initval, defaults=defaults, *args, **kwargs
-        )
-        self.parent_dist = parent_dist
-
-    def __getattr__(self, name):
-        # Do not use __getstate__ and __setstate__ from parent_dist
-        # to avoid infinite recursion during unpickling
-        if name.startswith("__"):
-            raise AttributeError("'NoDistribution' has no attribute '%s'" % name)
-        return getattr(self.parent_dist, name)
-
-    def logp(self, x):
-        """Calculate log probability.
-
-        Parameters
-        ----------
-        x: numeric
-            Value for which log-probability is calculated.
-
-        Returns
-        -------
-        TensorVariable
-        """
-        return at.zeros_like(x)
-
-    def _distr_parameters_for_repr(self):
-        return []
-
-
 class Discrete(Distribution):
     """Base class for discrete distributions"""
@@ -442,6 +400,10 @@ class Continuous(Distribution):
"""Base class for continuous distributions"""


class NoDistribution(Distribution):
"""Base class for artifical distributions"""


class DensityDist(Distribution):
"""Distribution based on a given log density function.

235 changes: 160 additions & 75 deletions pymc3/distributions/simulator.py
@@ -14,17 +14,52 @@

 import logging

+import aesara
+import aesara.tensor as at
 import numpy as np

+from aesara.graph.op import Op
+from aesara.tensor.random.op import RandomVariable
 from scipy.spatial import cKDTree

+from pymc3.aesaraf import floatX
 from pymc3.distributions.distribution import NoDistribution
+from pymc3.distributions.logp import _logp

 __all__ = ["Simulator"]

 _log = logging.getLogger("pymc3")


+class SimulatorRV(RandomVariable):
+    """A placeholder for Simulator RVs"""
+
+    name = "SimulatorRV"
+    _print_name = ("Simulator", "\\operatorname{Simulator}")
+    fn = None
+    epsilon = None
+    distance = None
+    sum_stat = None
+
+    @classmethod
+    def rng_fn(cls, *args, **kwargs):
+        if cls.fn is None:
+            raise ValueError(f"fn was not defined for {cls}")
+        return cls.fn(*args, **kwargs)
+
+    @classmethod
+    def _distance(cls, epsilon, value, sim_value):
+        if cls.distance is None:
+            raise ValueError(f"distance function was not defined for {cls}")
+        return cls.distance(epsilon, value, sim_value)
+
+    @classmethod
+    def _sum_stat(cls, value):
+        if cls.sum_stat is None:
+            raise ValueError(f"sum_stat function was not defined for {cls}")
+        return cls.sum_stat(value)
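The `None` class attributes above act as slots that `Simulator.__new__` (below) fills in when it builds a concrete subclass with `type()`. A hand-written equivalent, with an illustrative `normal_sim`, would look roughly like:

    def normal_sim(rng, loc, scale, size):
        return rng.normal(loc, scale, size=size)

    NormalSimulatorRV = type(
        "Simulator_normal",
        (SimulatorRV,),
        dict(
            name="Simulator",
            ndim_supp=0,
            ndims_params=[0, 0],
            dtype="floatX",
            fn=normal_sim,
        ),
    )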


 class Simulator(NoDistribution):
     r"""
     Define a simulator, from a Python function, to be used in ABC methods.
@@ -54,86 +89,104 @@ class Simulator(NoDistribution):
         Arguments and keyword arguments that the function takes.
     """

-    def __init__(
-        self,
-        function,
-        *args,
+    def __new__(
+        cls,
+        name,
+        fn,
+        *,
         params=None,
         distance="gaussian",
         sum_stat="identity",
         epsilon=1,
+        observed=None,
+        ndim_supp=0,
+        ndims_params=None,
+        dtype="floatX",
         **kwargs,
     ):
-        self.function = function
-        self.params = params
-        observed = self.data
-        self.epsilon = epsilon
-
-        if distance == "gaussian":
-            self.distance = gaussian
-        elif distance == "laplace":
-            self.distance = laplace
-        elif distance == "kullback_leibler":
-            self.distance = KullbackLiebler(observed)
-            if sum_stat != "identity":
-                _log.info(f"Automatically setting sum_stat to identity as expected by {distance}")
-                sum_stat = "identity"
-        elif hasattr(distance, "__call__"):
-            self.distance = distance
-        else:
-            raise ValueError(f"The distance metric {distance} is not implemented")
-
-        if sum_stat == "identity":
-            self.sum_stat = identity
-        elif sum_stat == "sort":
-            self.sum_stat = np.sort
-        elif sum_stat == "mean":
-            self.sum_stat = np.mean
-        elif sum_stat == "median":
-            self.sum_stat = np.median
-        elif hasattr(sum_stat, "__call__"):
-            self.sum_stat = sum_stat
-        else:
-            raise ValueError(f"The summary statistics {sum_stat} is not implemented")
-
-        super().__init__(shape=np.prod(observed.shape), dtype=observed.dtype, *args, **kwargs)

-    def random(self, point=None, size=None):
-        """
-        Draw random values from Simulator.
-
-        Parameters
-        ----------
-        point: dict, optional
-            Dict of variable values on which random values are to be conditioned (uses default
-            point if not specified).
-        size: int, optional
-            Desired size of random sample (returns one sample if not specified).
-
-        Returns
-        -------
-        array
-        """
-        # size = to_tuple(size)
-        # params = draw_values([*self.params], point=point, size=size)
-        # if len(size) == 0:
-        #     return self.function(*params)
-        # else:
-        #     return np.array([self.function(*params) for _ in range(size[0])])

-    def _str_repr(self, name=None, dist=None, formatting="plain"):
-        if dist is None:
-            dist = self
-            name = name
-        function = dist.function.__name__
-        params = ", ".join([var.name for var in dist.params])
-        sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat
-        distance = getattr(self.distance, "__name__", self.distance.__class__.__name__)
-
-        if "latex" in formatting:
-            return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
-        else:
-            return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"

+        if not isinstance(distance, Op):
+            if distance == "gaussian":
+                distance = gaussian
+            elif distance == "laplace":
+                distance = laplace
+            elif distance == "kullback_leibler":
+                raise NotImplementedError("KL not refactored yet")
+                # TODO: Wrap KL in aesara Op
+                # distance = KullbackLiebler(observed)
+                # if sum_stat != "identity":
+                #     _log.info(f"Automatically setting sum_stat to identity as expected by {distance}")
+                #     sum_stat = "identity"
+            elif callable(distance):
+                distance = create_distance_op_from_fn(distance)
+            else:
+                raise ValueError(f"The distance metric {distance} is not implemented")
+
+        if not isinstance(sum_stat, Op):
+            if sum_stat == "identity":
+                sum_stat = identity
+            elif sum_stat == "sort":
+                sum_stat = at.sort
+            elif sum_stat == "mean":
+                sum_stat = at.mean
+            elif sum_stat == "median":
+                sum_stat = at.median
+            elif callable(sum_stat):
+                sum_stat = create_sum_stat_op_from_fn(sum_stat)
+            else:
+                raise ValueError(f"The summary statistics {sum_stat} is not implemented")
+
+        epsilon = at.as_tensor_variable(floatX(epsilon))
+
+        if params is None:
+            params = []
+
+        # Assume scalar ndims_params
+        if ndims_params is None:
+            ndims_params = [0] * len(params)
+
+        sim_op = type(
+            f"Simulator_{name}",
+            (SimulatorRV,),
+            dict(
+                name="Simulator",
+                ndim_supp=ndim_supp,
+                ndims_params=ndims_params,
+                dtype=dtype,
+                inplace=False,
+                # Specific to Simulator
+                fn=fn,
+                distance=distance,
+                sum_stat=sum_stat,
+                epsilon=epsilon,
+            ),
+        )()
+
+        # Register custom logp
+        rv_type = type(sim_op)
+
+        @_logp.register(rv_type)
+        def logp(op, sim_rv, rvs_to_values, *sim_params, **kwargs):
Review comment (Member):

Locally defined functions often cause problems with pickling.

Also, it looks like this logp uses variables from the __new__ scope. Doesn't this, in combination with the registration, lead to problems when having more than one Simulator?

Reply (ricardoV94, Member Author, Jun 27, 2021):

I am not sure about pickling. It depends on the point at which the function is used (after obtaining the logp graph, the function is not needed anymore).

The registration part should be fine because I am creating a new type which has a unique name. I also have a test for multiple Simulators with different methods and it seems to be fine.

Reply (ricardoV94, Member Author):

I can also try to add a layer of indirection when creating the simulator, so that I can attach what are now the local variables to the tag, in which case a single logp function / dispatcher would be enough.
+            value_var = rvs_to_values.get(sim_rv, sim_rv)
+            return Simulator.logp(
+                value_var,
+                sim_rv,
+            )
+
+        cls.rv_op = sim_op
+        return super().__new__(cls, name, params, observed=observed, **kwargs)
+
+    @classmethod
+    def logp(cls, value, sim_rv):
+        # Create a new SimulatorRV, identical to the original one
+        sim_op = sim_rv.owner.op
+        sim_data = at.as_tensor_variable(sim_op.make_node(*sim_rv.owner.inputs))
+        sim_data.name = "sim_data"
+        return sim_op._distance(
+            sim_op.epsilon,
+            sim_op._sum_stat(value),
+            sim_op._sum_stat(sim_data),
+        )
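As a side note on the review discussion above: the author's proposed "layer of indirection" would keep the Simulator-specific pieces on the op (or the variable's tag) so that a single dispatch entry on the SimulatorRV base class replaces the per-type registration. A rough sketch of that alternative, not what this PR does:

    @_logp.register(SimulatorRV)
    def simulator_logp(op, sim_rv, rvs_to_values, *sim_params, **kwargs):
        # One registration covers every Simulator subclass, since
        # singledispatch falls back along the MRO.
        value_var = rvs_to_values.get(sim_rv, sim_rv)
        return Simulator.logp(value_var, sim_rv)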


 def identity(x):
@@ -148,7 +201,7 @@ def gaussian(epsilon, obs_data, sim_data):

 def laplace(epsilon, obs_data, sim_data):
     """Laplace kernel."""
-    return -np.abs((obs_data - sim_data) / epsilon)
+    return -at.abs_((obs_data - sim_data) / epsilon)
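The body of `gaussian` is collapsed in this view; in pymc3's simulator module it is the log of an unnormalized Gaussian kernel, roughly the following (reconstructed from the pre-refactor source, so treat it as a sketch):

    def gaussian(epsilon, obs_data, sim_data):
        """Gaussian kernel."""
        return -0.5 * ((obs_data - sim_data) / epsilon) ** 2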


 class KullbackLiebler:
@@ -169,3 +222,35 @@ def __call__(self, epsilon, obs_data, sim_data):
         sim_data = sim_data[:, None]
         nu_d, _ = cKDTree(sim_data).query(self.obs_data, 1)
         return self.d_n * np.sum(-np.log(nu_d / self.rho_d) / epsilon) + self.log_r
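Here `__call__` computes a 1-nearest-neighbour estimate of the Kullback-Leibler divergence: `nu_d` holds the distance from each observed point to its nearest simulated point, compared against `rho_d`, which the collapsed `__init__` presumably precomputes from within-observed nearest-neighbour distances along with the constants `d_n` and `log_r`. A usage sketch with made-up data:

    import numpy as np

    observed = np.random.normal(size=500)
    simulated = np.random.normal(loc=0.1, size=500)

    kl = KullbackLiebler(observed)             # precomputes rho_d, d_n, log_r
    value = kl(1.0, observed, simulated)       # like gaussian/laplace, acts as a log-kernel term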


+def create_sum_stat_op_from_fn(fn):
+    class SumStat(Op):
+        if aesara.config.floatX == "float64":
+            itypes = [at.dvector]
+            otypes = [at.dvector]
+        else:
+            itypes = [at.fvector]
+            otypes = [at.fvector]
+
+        def perform(self, node, inputs, outputs):
+            (x,) = inputs
+            outputs[0][0] = np.atleast_1d(fn(x)).astype(aesara.config.floatX)
+
+    return SumStat()


+def create_distance_op_from_fn(fn):
+    class Distance(Op):
+        if aesara.config.floatX == "float64":
+            itypes = [at.dscalar, at.dvector, at.dvector]
+            otypes = [at.dvector]
+        else:
+            itypes = [at.fscalar, at.fvector, at.fvector]
+            otypes = [at.fvector]
+
+        def perform(self, node, inputs, outputs):
+            eps, obs_data, sim_data = inputs
+            outputs[0][0] = np.atleast_1d(fn(eps, obs_data, sim_data)).astype(aesara.config.floatX)
+
+    return Distance()
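A usage sketch for the two wrappers (the `mean_abs_distance` function and the `np.std` summary are illustrative); note that `Simulator.__new__` above also wraps plain callables automatically, so passing raw functions works too:

    import numpy as np

    def mean_abs_distance(eps, obs_data, sim_data):
        return -np.abs(obs_data - sim_data) / eps

    distance_op = create_distance_op_from_fn(mean_abs_distance)
    sum_stat_op = create_sum_stat_op_from_fn(np.std)  # np.atleast_1d in perform handles the scalar result

    # Both are aesara Ops now, so Simulator.__new__ skips the string lookup:
    # pm.Simulator("sim", fn, distance=distance_op, sum_stat=sum_stat_op, ...)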