Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pm.Simulator (2nd attempt) #4877

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Distributions
distributions/discrete
distributions/multivariate
distributions/mixture
distributions/simulator
distributions/timeseries
distributions/transforms
distributions/utilities
12 changes: 12 additions & 0 deletions docs/source/api/distributions/simulator.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
**********
Simulator
**********

.. currentmodule:: pymc3.distributions.simulator
.. autosummary::

SimulatorRV
Simulator

.. automodule:: pymc3.distributions.simulator
:members:
21 changes: 16 additions & 5 deletions pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,15 +350,26 @@ def rvs_to_value_vars(

"""

# Avoid circular dependency
from pymc3.distributions.simulator import SimulatorRV

def transform_replacements(var, replacements):
rv_var, rv_value_var = extract_rv_and_value_vars(var)

if rv_value_var is None:
warnings.warn(
f"No value variable found for {rv_var}; "
"the random variable will not be replaced."
)
return []
# If RandomVariable does not have a value_var and corresponds to
# a SimulatorRV, we allow further replacements in upstream graph
if isinstance(rv_var.owner.op, SimulatorRV):
# First 3 inputs are just rng, dtype, and size, which don't
# need to be replaced.
return var.owner.inputs[3:]
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

else:
warnings.warn(
f"No value variable found for {rv_var}; "
"the random variable will not be replaced."
)
Comment on lines +368 to +371
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What conclusion should a user make from from this warning?
Is it serious? If so we should raise. Otherwise maybe just _log.warn()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably raise

return []

transform = getattr(rv_value_var.tag, "transform", None)

Expand Down
3 changes: 2 additions & 1 deletion pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
Wishart,
WishartBartlett,
)
from pymc3.distributions.simulator import Simulator
from pymc3.distributions.simulator import Simulator, SimulatorRV
from pymc3.distributions.timeseries import (
AR,
AR1,
Expand Down Expand Up @@ -188,6 +188,7 @@
"Rice",
"Moyal",
"Simulator",
"SimulatorRV",
"BART",
"CAR",
"PolyaGamma",
Expand Down
49 changes: 7 additions & 42 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Optional

import aesara
import aesara.tensor as at

from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable
Expand Down Expand Up @@ -325,47 +324,6 @@ def dist(
return rv_out


class NoDistribution(Distribution):
def __init__(
self,
shape,
dtype,
initval=None,
defaults=(),
parent_dist=None,
*args,
**kwargs,
):
super().__init__(
shape=shape, dtype=dtype, initval=initval, defaults=defaults, *args, **kwargs
)
self.parent_dist = parent_dist

def __getattr__(self, name):
# Do not use __getstate__ and __setstate__ from parent_dist
# to avoid infinite recursion during unpickling
if name.startswith("__"):
raise AttributeError("'NoDistribution' has no attribute '%s'" % name)
return getattr(self.parent_dist, name)

def logp(self, x):
"""Calculate log probability.

Parameters
----------
x: numeric
Value for which log-probability is calculated.

Returns
-------
TensorVariable
"""
return at.zeros_like(x)

def _distr_parameters_for_repr(self):
return []


class Discrete(Distribution):
"""Base class for discrete distributions"""

Expand All @@ -381,6 +339,13 @@ class Continuous(Distribution):
"""Base class for continuous distributions"""


class NoDistribution(Distribution):
"""Base class for artifical distributions

RandomVariables that share this type are allowed in logprob graphs
"""


class DensityDist(Distribution):
"""Distribution based on a given log density function.

Expand Down
Loading