Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added base class for variational methods #1600

Closed
wants to merge 87 commits into from
Closed
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
2d6fee8
Added mode argument to several step methods and advi to allow mode se…
fonnesbeck Nov 17, 2016
4e5b9c2
Fixed namespace bugs in mode attribute
fonnesbeck Nov 17, 2016
0ebaacd
Reverted function in delta_logp to not accept mode argument
fonnesbeck Nov 17, 2016
55c8ce6
ENH User model (#1525)
ferrine Nov 28, 2016
9ab04da
added new elbo implementation
ferrine Dec 4, 2016
9811220
Added mode argument to several step methods and advi to allow mode se…
fonnesbeck Nov 17, 2016
fbd1d5b
Fixed namespace bugs in mode attribute
fonnesbeck Nov 17, 2016
208aa79
Reverted function in delta_logp to not accept mode argument
fonnesbeck Nov 17, 2016
40d0146
ENH User model (#1525)
ferrine Nov 28, 2016
fc0673b
ENH User model (#1525)
ferrine Nov 28, 2016
ea82ebd
Refactor Hamiltonian methods into single class
Nov 8, 2016
140a80c
Reformat docs
Dec 3, 2016
168b113
added replacements class and mean field approximation
ferrine Dec 9, 2016
c1211a6
moved local to local constructor
ferrine Dec 9, 2016
9690562
property for deterministic replacements
ferrine Dec 9, 2016
34da7c8
refactored replacements to make them more unitary
ferrine Dec 9, 2016
07a248a
shape problem when sampling
ferrine Dec 8, 2016
889b50e
tests passed
ferrine Dec 9, 2016
0d486fb
deleted unused modules
ferrine Dec 13, 2016
125f6ad
added replacement names for global/local dict
ferrine Dec 13, 2016
1af91c0
Merge branch '3.1' into refactor_advi
ferrine Dec 13, 2016
69f07a1
refactored replacements
ferrine Dec 15, 2016
9614bf9
refactored replacements
ferrine Dec 15, 2016
32a2eb7
refactored GARCH and added Mv(Gaussian/StudentT)RandomWalk (#1603)
ferrine Dec 15, 2016
5e68b95
Merge branch '3.1' into refactor_advi
ferrine Dec 15, 2016
0f2c38f
added flatten_list
ferrine Dec 15, 2016
63e57d7
added tests
ferrine Dec 16, 2016
4d4cb82
refactored local/global dicts
ferrine Dec 16, 2016
82c7996
moved __type__ assignment to better place
ferrine Dec 16, 2016
2cd6bc5
Don't do replacements too early or else it will be not possible to tr…
ferrine Dec 16, 2016
4d810f2
refactored docs
ferrine Dec 16, 2016
16a226b
fixed memory consumption during test
ferrine Dec 18, 2016
d8e9886
set nmc samples to 1000 in test
ferrine Dec 18, 2016
1bb349e
optimized code a lot
ferrine Dec 19, 2016
87e7e2d
changed expectations to sampling, added docs
ferrine Dec 19, 2016
9eb79a0
code style
ferrine Dec 19, 2016
be1ca80
validate model
ferrine Dec 19, 2016
e8f6644
added tests for dynamic number of samples
ferrine Dec 19, 2016
4add3bc
added `set_params` method
ferrine Dec 19, 2016
43a8638
added `params` property
ferrine Dec 19, 2016
6a88fde
ENH KL-weighting
taku-y Nov 24, 2016
a3bad35
Fix bugs
taku-y Nov 26, 2016
7ed2cb5
Remove unnecessary comments
taku-y Nov 28, 2016
163b1be
Fix typo
taku-y Dec 5, 2016
fad9410
Minor fixes
taku-y Dec 7, 2016
e1a88e0
Check transformed RVs using hasattr
taku-y Dec 8, 2016
ae349e9
Update conv-vae notebook
taku-y Dec 15, 2016
9e237ef
Implementation of path derivative gradient estimator (NIPS 2016) #1615
ferrine Dec 20, 2016
63c1285
local vars nee this path trick too
ferrine Dec 20, 2016
02f5fa6
bug in local size calculation
ferrine Dec 21, 2016
8cc9558
bug in global subset view
ferrine Dec 21, 2016
4e302e6
improved performance
ferrine Dec 21, 2016
e5df6ee
changed the way for calling posterior
ferrine Dec 22, 2016
23ed175
deleted accidental added nuts file
ferrine Dec 22, 2016
7a7cdc3
Merge remote-tracking branch 'upstream/3.1' into refactor_advi
ferrine Dec 22, 2016
ac949d2
changed zero grad usage
ferrine Dec 22, 2016
26adf3b
refactor apply replacements
ferrine Dec 23, 2016
63000fb
added useful functions to replacements
ferrine Dec 24, 2016
5240260
added `approximate` function
ferrine Dec 24, 2016
7802a78
changed name MeanFieald to Advi
ferrine Dec 25, 2016
2407d78
added docs, renamed classes
ferrine Dec 25, 2016
fbf26d4
add deterministics to posterior to point function
ferrine Dec 25, 2016
c394d5e
trying to fix reweighting
ferrine Dec 26, 2016
2162d4c
weight log_p_W{local|global} correctly
ferrine Dec 26, 2016
7609f72
local and global weighting
ferrine Dec 27, 2016
d55d258
added docs
ferrine Dec 27, 2016
dca919c
preparing mnist vae, fixed bugs
ferrine Dec 27, 2016
cb2e219
Took in account suggestions for refactoring
ferrine Dec 30, 2016
d94e7e7
refactored dist math
ferrine Jan 3, 2017
8d1f088
Added mode argument to several step methods and advi to allow mode se…
fonnesbeck Nov 17, 2016
37843af
Created Generator Op with simple Test
ferrine Jan 11, 2017
3dc6f1b
added ndim test
ferrine Jan 11, 2017
a16512e
updated test
ferrine Jan 11, 2017
7127c23
updated test, added test value check
ferrine Jan 11, 2017
23b14ff
added test for replacing generator with shared variable
ferrine Jan 11, 2017
96cd5bb
added shortcut for generator op
ferrine Jan 11, 2017
633e4e9
refactored test
ferrine Jan 11, 2017
0629adc
added population kwarg (no tests yet)
ferrine Jan 11, 2017
06099a2
added population kwarg for free var(autoencoder case)
ferrine Jan 11, 2017
75a4849
Revert "Added mode argument to several step methods and advi to allow…
ferrine Jan 12, 2017
ff325d8
add docstring to generator Op
ferrine Jan 14, 2017
79ac934
rename population -> total_size
ferrine Jan 14, 2017
57dbe47
update docstrings in model
ferrine Jan 14, 2017
f8bce58
fix typo in `as_tensor` function
ferrine Jan 14, 2017
244bf21
Merge branch 'generator_op' into refactor_advi
ferrine Jan 14, 2017
8d91fee
add simple test for density scaling via `total_size`
ferrine Jan 17, 2017
1a9fa3d
raise an error when density scaling is done on scalar
ferrine Jan 17, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scipy import stats

from .dist_math import bound, factln, binomln, betaln, logpow
from .distribution import Discrete, draw_values, generate_samples
from .distribution import Discrete, draw_values, generate_samples, reshape_sampled

__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'Poisson',
'NegativeBinomial', 'ConstantDist', 'Constant', 'ZeroInflatedPoisson',
Expand Down Expand Up @@ -250,7 +250,7 @@ def random(self, point=None, size=None, repeat=None):
dist_shape=self.shape,
size=size)
g[g == 0] = np.finfo(float).eps # Just in case
return stats.poisson.rvs(g)
return reshape_sampled(stats.poisson.rvs(g), size, self.shape)

def logp(self, value):
mu = self.mu
Expand Down Expand Up @@ -441,9 +441,11 @@ def logp(self, value):
c = self.c
return bound(0, tt.eq(value, c))


def ConstantDist(*args, **kwargs):
import warnings
warnings.warn("ConstantDist has been deprecated. In future, use Constant instead.",
DeprecationWarning)
DeprecationWarning)
return Constant(*args, **kwargs)


Expand Down Expand Up @@ -489,7 +491,8 @@ def random(self, point=None, size=None, repeat=None):
g = generate_samples(stats.poisson.rvs, theta,
dist_shape=self.shape,
size=size)
return g * (np.random.random(np.squeeze(g.shape)) < psi)
sampled = g * (np.random.random(np.squeeze(g.shape)) < psi)
return reshape_sampled(sampled, size, self.shape)

def logp(self, value):
return tt.switch(value > 0,
Expand Down Expand Up @@ -543,7 +546,8 @@ def random(self, point=None, size=None, repeat=None):
dist_shape=self.shape,
size=size)
g[g == 0] = np.finfo(float).eps # Just in case
return stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
sampled = stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
return reshape_sampled(sampled, size, self.shape)

def logp(self, value):
return tt.switch(value > 0,
Expand Down
38 changes: 38 additions & 0 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from .special import gammaln, multigammaln

c = - 0.5 * np.log(2 * np.pi)


def bound(logp, *conditions):
"""
Expand Down Expand Up @@ -77,3 +79,39 @@ def i1(x):
x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600,
np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3)
+ 14175 / (98304 * x**4)))


def sd2rho(sd):
"""sd -> rho
theano converter
mu + sd*e = mu + log(1+exp(rho))*e"""
return tt.log(tt.exp(sd) - 1)


def rho2sd(rho):
"""rho -> sd
theano converter
mu + sd*e = mu + log(1+exp(rho))*e"""
return tt.log1p(tt.exp(rho))


def kl_divergence_normal_pair(mu1, mu2, sd1, sd2):
elemwise_kl = (tt.log(sd2/sd1) +
(sd2**2 + (mu1-mu2)**2)/(2.*sd2**2) -
0.5)
return tt.sum(elemwise_kl)


def kl_divergence_normal_pair3(mu1, mu2, rho1, rho2):
sd1, sd2 = rho2sd(rho1), rho2sd(rho2)
return kl_divergence_normal_pair(mu1, mu2, sd1, sd2)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this function used?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is not used, I can delete it


def log_normal(x, mean, std, eps=0.0):
std += eps
return c - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2 * std ** 2)


def log_normal3(x, mean, rho, eps=0.0):
std = rho2sd(rho)
return log_normal(x, mean, std, eps)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does "3" mean?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this notion was used in other library:

  1. using sd
  2. using log sd
  3. using rho

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a more informative name for these functions that appending "3" to the name? Perhaps use rho in the name instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think docstring with cross references will be better.

31 changes: 20 additions & 11 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
__all__ = ['DensityDist', 'Distribution', 'Continuous',
'Discrete', 'NoDistribution', 'TensorType', 'draw_values']


class _Unpickling(object):
pass


class Distribution(object):
"""Statistical distribution"""
def __new__(cls, name, *args, **kwargs):
Expand Down Expand Up @@ -129,12 +131,10 @@ def __init__(self, logp, shape=(), dtype='float64', testval=0, *args, **kwargs):


class MultivariateContinuous(Continuous):

pass


class MultivariateDiscrete(Discrete):

pass


Expand Down Expand Up @@ -265,6 +265,22 @@ def broadcast_shapes(*args):
return tuple(x)


def infer_shape(shape):
try:
shape = tuple(shape or ())
except TypeError: # If size is an int
shape = tuple((shape,))
except ValueError: # If size is np.array
shape = tuple(shape)
return shape


def reshape_sampled(sampled, size, dist_shape):
dist_shape = infer_shape(dist_shape)
repeat_shape = infer_shape(size)
return np.reshape(sampled, repeat_shape + dist_shape)


def replicate_samples(generator, size, repeats, *args, **kwargs):
n = int(np.prod(repeats))
if n == 1:
Expand Down Expand Up @@ -326,10 +342,7 @@ def generate_samples(generator, *args, **kwargs):
else:
prefix_shape = tuple(dist_shape)

try:
repeat_shape = tuple(size or ())
except TypeError: # If size is an int
repeat_shape = tuple((size,))
repeat_shape = infer_shape(size)

if broadcast_shape == (1,) and prefix_shape == ():
if size is not None:
Expand All @@ -342,13 +355,9 @@ def generate_samples(generator, *args, **kwargs):
broadcast_shape,
repeat_shape + prefix_shape,
*args, **kwargs)
if broadcast_shape == (1,) and not prefix_shape == ():
samples = np.reshape(samples, repeat_shape + prefix_shape)
else:
samples = replicate_samples(generator,
broadcast_shape,
prefix_shape,
*args, **kwargs)
if broadcast_shape == (1,):
samples = np.reshape(samples, prefix_shape)
return samples
return reshape_sampled(samples, size, dist_shape)
96 changes: 90 additions & 6 deletions pymc3/distributions/timeseries.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import theano.tensor as tt
from theano import scan

from .multivariate import get_tau_cov, MvNormal, MvStudentT
from .continuous import Normal, Flat
from .distribution import Continuous

__all__ = ['AR1', 'GaussianRandomWalk', 'GARCH11', 'EulerMaruyama']
__all__ = [
'AR1',
'GaussianRandomWalk',
'GARCH11',
'EulerMaruyama',
'MvGaussianRandomWalk',
'MvStudentTRandomWalk'
]


class AR1(Continuous):
Expand Down Expand Up @@ -108,7 +116,8 @@ def __init__(self, omega=None, alpha_1=None, beta_1=None,
self.initial_vol = initial_vol
self.mean = 0

def _get_volatility(self, x):
def get_volatility(self, x):
x = x[:-1]

def volatility_update(x, vol, w, a, b):
return tt.sqrt(w + a * tt.square(x) + b * tt.square(vol))
Expand All @@ -118,12 +127,11 @@ def volatility_update(x, vol, w, a, b):
outputs_info=[self.initial_vol],
non_sequences=[self.omega, self.alpha_1,
self.beta_1])
return vol
return tt.concatenate(self.initial_vol, vol)

def logp(self, x):
vol = self._get_volatility(x[:-1])
return (Normal.dist(0., sd=self.initial_vol).logp(x[0]) +
tt.sum(Normal.dist(0, sd=vol).logp(x[1:])))
vol = self.get_volatility(x)
return tt.sum(Normal.dist(0, sd=vol).logp(x))


class EulerMaruyama(Continuous):
Expand Down Expand Up @@ -151,3 +159,79 @@ def logp(self, x):
mu = xt + self.dt * f
sd = tt.sqrt(self.dt) * g
return tt.sum(Normal.dist(mu=mu, sd=sd).logp(x[1:]))


class MvGaussianRandomWalk(Continuous):
"""
Multivariate Random Walk with Normal innovations

Parameters
----------
mu : tensor
innovation drift, defaults to 0.0
cov : tensor
pos def matrix, innovation covariance matrix
tau : tensor
pos def matrix, innovation precision (alternative to specifying cov)
init : distribution
distribution for initial value (Defaults to Flat())
"""
def __init__(self, mu=0., cov=None, tau=None, init=Flat.dist(),
*args, **kwargs):
super(MvGaussianRandomWalk, self).__init__(*args, **kwargs)
tau, cov = get_tau_cov(mu, tau=tau, cov=cov)
self.tau = tau
self.cov = cov
self.mu = mu
self.init = init
self.mean = 0.

def logp(self, x):
tau = self.tau
mu = self.mu
init = self.init

x_im1 = x[:-1]
x_i = x[1:]

innov_like = MvNormal.dist(mu=x_im1 + mu, tau=tau).logp(x_i)
return init.logp(x[0]) + tt.sum(innov_like)


class MvStudentTRandomWalk(Continuous):
"""
Multivariate Random Walk with StudentT innovations

Parameters
----------
nu : degrees of freedom
mu : tensor
innovation drift, defaults to 0.0
cov : tensor
pos def matrix, innovation covariance matrix
tau : tensor
pos def matrix, innovation precision (alternative to specifying cov)
init : distribution
distribution for initial value (Defaults to Flat())
"""
def __init__(self, nu, mu=0., cov=None, tau=None, init=Flat.dist(),
*args, **kwargs):
super(MvStudentTRandomWalk, self).__init__(*args, **kwargs)
tau, cov = get_tau_cov(mu, tau=tau, cov=cov)
self.tau = tau
self.cov = cov
self.mu = mu
self.nu = nu
self.init = init
self.mean = 0.

def logp(self, x):
cov = self.cov
mu = self.mu
nu = self.nu
init = self.init

x_im1 = x[:-1]
x_i = x[1:]
innov_like = MvStudentT.dist(nu, cov, mu=x_im1 + mu).logp(x_i)
return init.logp(x[0]) + tt.sum(innov_like)
4 changes: 4 additions & 0 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ def invlogit(x, eps=sys.float_info.epsilon):

def logit(p):
return tt.log(p / (1 - p))


def flatten_list(tensors):
return tt.concatenate([var.ravel() for var in tensors])
Loading