Add Interpolated distribution class #2163
Changes from all commits
e2c14c7
aa0fbea
c520245
4a381f1
@@ -10,19 +10,20 @@
 import numpy as np
 import theano.tensor as tt
 from scipy import stats
+from scipy.interpolate import InterpolatedUnivariateSpline
 import warnings

 from pymc3.theanof import floatX
 from . import transforms

-from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise
+from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise, DifferentiableSplineWrapper
 from .distribution import Continuous, draw_values, generate_samples, Bound

 __all__ = ['Uniform', 'Flat', 'Normal', 'Beta', 'Exponential', 'Laplace',
            'StudentT', 'Cauchy', 'HalfCauchy', 'Gamma', 'Weibull',
            'HalfStudentT', 'StudentTpos', 'Lognormal', 'ChiSquared',
            'HalfNormal', 'Wald', 'Pareto', 'InverseGamma', 'ExGaussian',
-           'VonMises', 'SkewNormal']
+           'VonMises', 'SkewNormal', 'Interpolated']


 class PositiveContinuous(Continuous):

@@ -1389,3 +1390,71 @@ def random(self, point=None, size=None, repeat=None):
     def logp(self, value):
         scaled = (value - self.mu) / self.beta
         return bound(-scaled - tt.exp(-scaled) - tt.log(self.beta), self.beta > 0)
+
+class Interpolated(Continuous):
+    R"""
+    Probability distribution defined as a linear interpolation of
+    of a set of points and values of probability density function
+    evaluated on them.
+
+    The points are not variables, but plain array-like objects, so
+    they are constant and cannot be sampled.
+
+    ========  =========================================
+    Support   :math:`x \in [x_points[0], x_points[-1]]`
+    ========  =========================================
+
+    Parameters
+    ----------
+    x_points : array-like
+        A monotonically growing list of values
+    pdf_points : array-like
+        Probability density function evaluated at points from `x`
+    """
+
+    def __init__(self, x_points, pdf_points, transform='interval',
+                 *args, **kwargs):
+        if transform == 'interval':
+            transform = transforms.interval(x_points[0], x_points[-1])
+        super(Interpolated, self).__init__(transform=transform,
+                                           *args, **kwargs)
+
+        interp = InterpolatedUnivariateSpline(x_points, pdf_points, k=1, ext='zeros')
+        Z = interp.integral(x_points[0], x_points[-1])
+
+        self.Z = tt.as_tensor_variable(Z)
+        self.interp_op = DifferentiableSplineWrapper(interp)
+        self.x_points = x_points
+        self.pdf_points = pdf_points / Z
+        self.cdf_points = interp.antiderivative()(x_points) / Z
+
+        self.median = self._argcdf(0.5)
+
+    def _argcdf(self, p):
+        pdf = self.pdf_points
+        cdf = self.cdf_points
+        x = self.x_points
+
+        index = np.searchsorted(cdf, p) - 1
+        slope = (pdf[index + 1] - pdf[index]) / (x[index + 1] - x[index])
+
+        return x[index] + np.where(
+            np.abs(slope) <= 1e-8,
+            np.where(
+                np.abs(pdf[index]) <= 1e-8,
+                np.zeros(index.shape),
+                (p - cdf[index]) / pdf[index]
+            ),
+            (-pdf[index] + np.sqrt(pdf[index] ** 2 + 2 * slope * (p - cdf[index]))) / slope
+        )
+
+    def _random(self, size=None):
+        return self._argcdf(np.random.uniform(size=size))
+
+    def random(self, point=None, size=None, repeat=None):
+        return generate_samples(self._random,
+                                dist_shape=self.shape,
+                                size=size)
+
+    def logp(self, value):
+        return tt.log(self.interp_op(value) / self.Z)
When I ran your example from #2146 locally, I actually found your first implementation faster.

That was because I initially defined the gradient incorrectly in #2146, so HMC didn't work well for the second version (see this comment). In this implementation the gradient calculation is fixed, so it is even faster to put the division by the normalization constant here.

You are right, it should be faster then. I will double-check on my side.
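For readers following this thread: with `k=1` the spline is piecewise linear, so the normalization constant `Z` equals the trapezoid-rule sum over the knots, and `logp` is the log of the interpolated value divided by `Z`. Below is a minimal pure-Python sketch of that computation. It is not the PR's Theano code, and `trapezoid_norm` and `linear_logp` are hypothetical names introduced only for illustration.

```python
import bisect
import math


def trapezoid_norm(x_points, pdf_points):
    """Integral of the piecewise-linear interpolant (trapezoid rule)."""
    return sum(
        0.5 * (pdf_points[i] + pdf_points[i + 1]) * (x_points[i + 1] - x_points[i])
        for i in range(len(x_points) - 1)
    )


def linear_logp(x_points, pdf_points, value):
    """Log-density of the normalized piecewise-linear interpolant at `value`."""
    if value < x_points[0] or value > x_points[-1]:
        return -math.inf  # the density is zero outside the support
    # Index of the segment [x[i], x[i + 1]] that contains `value`.
    i = min(bisect.bisect_right(x_points, value) - 1, len(x_points) - 2)
    t = (value - x_points[i]) / (x_points[i + 1] - x_points[i])
    density = (1 - t) * pdf_points[i] + t * pdf_points[i + 1]
    z = trapezoid_norm(x_points, pdf_points)
    return math.log(density / z) if density > 0 else -math.inf
```

Because `Z` depends only on the fixed points, it can be computed once up front, which is consistent with keeping the division inside `logp` cheap.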
Should we allow the user to define `k` here?

There are two major problems with non-linear interpolations:
1. `integral()` method of SciPy spline classes.
2. The `random()` method can be implemented efficiently only for linear interpolation, because the inverse CDF can be expressed in closed form only for a piecewise-quadratic CDF. (Well, it is possible to try the same for a piecewise-cubic CDF using Cardano's formula, but I'm not subscribing to it :) For higher-order polynomial interpolations it would be necessary to find inverses numerically, using an iterative process such as Newton's method.

Yep, the first point is a valid concern. I think we can merge this now and add higher-order polynomial support in the future.
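To illustrate the closed-form inverse mentioned in this thread: on a segment starting at knot `x[i]` with density `pdf[i]` and slope `s`, the CDF is `F(x[i] + d) = cdf[i] + pdf[i]*d + s*d**2/2`, so solving `F = p` is a quadratic in `d`. The sketch below uses only the standard library and mirrors the formula in `_argcdf`. The function names and the `eps` tolerance are assumptions made for this example, not part of the PR.

```python
import bisect
import math
import random


def build_cdf(x, pdf):
    """Normalize `pdf` and return (normalized pdf, cdf values at the knots)."""
    areas = [0.5 * (pdf[i] + pdf[i + 1]) * (x[i + 1] - x[i])
             for i in range(len(x) - 1)]
    z = sum(areas)
    cdf = [0.0]
    for area in areas:
        cdf.append(cdf[-1] + area / z)
    return [v / z for v in pdf], cdf


def inverse_cdf(x, pdf, cdf, p, eps=1e-12):
    """Solve F(t) = p for t; `pdf` and `cdf` must come from `build_cdf`."""
    # Last knot whose CDF value does not exceed p, clamped to a valid segment.
    i = min(max(bisect.bisect_right(cdf, p) - 1, 0), len(x) - 2)
    slope = (pdf[i + 1] - pdf[i]) / (x[i + 1] - x[i])
    dp = p - cdf[i]
    if abs(slope) <= eps:
        # Flat segment: the CDF is linear there (or constant if the pdf is 0).
        return x[i] if pdf[i] <= eps else x[i] + dp / pdf[i]
    # Sloped segment: root of (slope / 2) * d**2 + pdf[i] * d - dp = 0
    # that lies inside the segment; max() guards against rounding below zero.
    disc = max(0.0, pdf[i] ** 2 + 2 * slope * dp)
    return x[i] + (-pdf[i] + math.sqrt(disc)) / slope


def sample(x, raw_pdf, n, rng=random):
    """Draw `n` samples by inverse-transform sampling."""
    pdf, cdf = build_cdf(x, raw_pdf)
    return [inverse_cdf(x, pdf, cdf, rng.random()) for _ in range(n)]
```

For a higher-order spline the per-segment equation would no longer be quadratic, so this direct solve would have to be replaced by an iterative root finder, which is the cost the comment above refers to.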