diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index a8a272acaaa..75f5177cccd 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -6,6 +6,7 @@ - Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)). - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)). - Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)). +- Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution (see [#4129](https://github.com/pymc-devs/pymc3/pull/4129)). ### Documentation diff --git a/docs/source/notebooks/lda-advi-aevb.ipynb b/docs/source/notebooks/lda-advi-aevb.ipynb index 9593fe32c4e..45a8f6ef0b2 100644 --- a/docs/source/notebooks/lda-advi-aevb.ipynb +++ b/docs/source/notebooks/lda-advi-aevb.ipynb @@ -24,14 +24,6 @@ "text": [ "env: THEANO_FLAGS=device=cpu,floatX=float64\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. 
In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", - " from ._conv import register_converters as _register_converters\n" - ] } ], "source": [ @@ -52,7 +44,6 @@ "\n", "from pymc3 import Dirichlet\n", "from pymc3 import math as pmmath\n", - "from pymc3.distributions.transforms import t_stick_breaking\n", "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n", "from theano import shared\n", @@ -77,22 +68,14 @@ "execution_count": 2, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Downloading 20news dataset. This may take a few minutes.\n", - "Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ "Loading dataset...\n", - "done in 30.579s.\n", + "done in 1.724s.\n", "Extracting tf features for LDA...\n", - "done in 3.830s.\n" + "done in 2.177s.\n" ] } ], @@ -132,14 +115,12 @@ "outputs": [ { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -242,7 +223,7 @@ "## LDA model\n", "With the log-likelihood function, we can construct the probabilistic model for LDA. `doc_t` works as a placeholder to which documents in a mini-batch are set. \n", "\n", - "For ADVI, each of random variables $\\theta$ and $\\beta$, drawn from Dirichlet distributions, is transformed into unconstrained real coordinate space. To do this, by default, PyMC3 uses a centered stick-breaking transformation. Since these random variables are on a simplex, the dimension of the unconstrained coordinate space is the original dimension minus 1. For example, the dimension of $\\theta_{d}$ is the number of topics (`n_topics`) in the LDA model, thus the transformed space has dimension `(n_topics - 1)`. It shuold be noted that, in this example, we use `t_stick_breaking`, which is a numerically stable version of `stick_breaking` used by default. This is required to work ADVI for the LDA model. \n", + "For ADVI, each of random variables $\\theta$ and $\\beta$, drawn from Dirichlet distributions, is transformed into unconstrained real coordinate space. To do this, by default, PyMC3 uses an isometric logratio transformation. Since these random variables are on a simplex, the dimension of the unconstrained coordinate space is the original dimension minus 1. For example, the dimension of $\\theta_{d}$ is the number of topics (`n_topics`) in the LDA model, thus the transformed space has dimension `(n_topics - 1)`. \n", "\n", "The variational posterior on these transformed parameters is represented by a spherical Gaussian distributions (meanfield approximation). Thus, the number of variational parameters of $\\theta_{d}$, the latent variable for each document, is `2 * (n_topics - 1)` for means and standard deviations. 
\n", "\n", @@ -253,7 +234,16 @@ "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ".../lib/python3.7/site-packages/pymc3/data.py:305: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n", + " self.shared = theano.shared(data[in_memory_slc])\n" + ] + } + ], "source": [ "n_topics = 10\n", "# we have sparse dataset. It's better to have dence batch so that all words accure there\n", @@ -264,11 +254,11 @@ "doc_t = shared(docs_tr.toarray()[:minibatch_size])\n", "with pm.Model() as model:\n", " theta = Dirichlet('theta', a=pm.floatX((1.0 / n_topics) * np.ones((minibatch_size, n_topics))), \n", - " shape=(minibatch_size, n_topics), transform=t_stick_breaking(1e-9),\n", + " shape=(minibatch_size, n_topics),\n", " # do not forget scaling\n", " total_size=n_samples_tr)\n", " beta = Dirichlet('beta', a=pm.floatX((1.0 / n_topics) * np.ones((n_topics, n_words))), \n", - " shape=(n_topics, n_words), transform=t_stick_breaking(1e-9))\n", + " shape=(n_topics, n_words))\n", " # Note, that we defined likelihood with scaling, so here we need no additional `total_size` kwarg\n", " doc = pm.DensityDist('doc', logp_lda_doc(beta, theta), observed=doc_t)" ] @@ -342,7 +332,7 @@ { "data": { "text/plain": [ - "OrderedDict([(theta,\n", + "OrderedDict([(theta ~ Dirichlet(a=array),\n", " {'mu': Subtensor{::, :int64:}.0,\n", " 'rho': Subtensor{::, int64::}.0})])" ] @@ -400,18 +390,46 @@ "execution_count": 10, "metadata": {}, "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " 100.00% [10000/10000 08:11<00:00 Average Loss = 3.0171e+06]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stderr", "output_type": "stream", "text": [ - "Average Loss = 2.9855e+06: 100%|██████████| 10000/10000 [06:20<00:00, 25.59it/s]\n", - "Finished [100%]: Average Loss = 2.9886e+06\n" + "Finished [100%]: Average Loss = 3.0204e+06\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 10, @@ -458,14 +476,12 @@ "outputs": [ { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -497,16 +513,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Topic #0: people think god don just said know say time like\n", - "Topic #1: file use windows drive program using scsi does like software\n", - "Topic #2: ax max g9v b8f 75u a86 34u bhj pl 1d9\n", - "Topic #3: space key use information chip new encryption data public government\n", - "Topic #4: just like don good year time car think team better\n", - "Topic #5: know don like just does thanks think good ve need\n", - "Topic #6: 00 10 25 11 15 20 17 12 16 14\n", - "Topic #7: know like thanks new mail good just does don price\n", - "Topic #8: edu com mail cs like send just know don list\n", - "Topic #9: know like just don does thanks good edu mail new\n" + "Topic #0: car good just like use new used power don time\n", + "Topic #1: space key information use encryption new chip data public edu\n", + "Topic #2: don like just know people think ve time good want\n", + "Topic #3: year people said just team don like time think game\n", + "Topic #4: file use windows edu drive program scsi does using like\n", + "Topic #5: 00 10 25 15 11 17 20 16 12 14\n", + "Topic #6: ax max g9v b8f a86 75u pl bhj giz 1t\n", + "Topic #7: people god think don does just know believe say said\n", + "Topic #8: young just people know don like think does work good\n", + "Topic #9: like just know don think does people use time good\n" ] } ], @@ -536,20 +552,12 @@ "execution_count": 14, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/anaconda3/lib/python3.6/site-packages/sklearn/decomposition/online_lda.py:314: DeprecationWarning: n_topics has been renamed to n_components in version 0.19 and will be removed in 0.21\n", - " DeprecationWarning)\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 49.7 s, sys: 538 ms, total: 50.3 s\n", - 
"Wall time: 37.9 s\n", + "CPU times: user 39.7 s, sys: 50.8 ms, total: 39.8 s\n", + "Wall time: 39.8 s\n", "Topic #0: people gun armenian war armenians turkish states said state 000\n", "Topic #1: government people law mr president use don think right public\n", "Topic #2: space science nasa program data research center output earth launch\n", @@ -566,7 +574,7 @@ "source": [ "from sklearn.decomposition import LatentDirichletAllocation\n", "\n", - "lda = LatentDirichletAllocation(n_topics=n_topics, max_iter=5,\n", + "lda = LatentDirichletAllocation(n_components=n_topics, max_iter=5,\n", " learning_method='online', learning_offset=50.,\n", " random_state=0)\n", "%time lda.fit(docs_tr)\n", @@ -685,9 +693,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 50.6 s, sys: 740 ms, total: 51.3 s\n", - "Wall time: 40.8 s\n", - "Predictive log prob (pm3) = -6.184892268516339\n" + "CPU times: user 7min 46s, sys: 7min 17s, total: 15min 3s\n", + "Wall time: 36 s\n", + "Predictive log prob (pm3) = -6.213144040486087\n" ] } ], @@ -712,9 +720,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 1min 40s, sys: 1.26 s, total: 1min 41s\n", - "Wall time: 1min 17s\n", - "Predictive log prob (sklearn) = -6.014771065227896\n" + "CPU times: user 3min 44s, sys: 4min 4s, total: 7min 49s\n", + "Wall time: 1min 23s\n", + "Predictive log prob (sklearn) = -6.014771065227894\n" ] } ], @@ -752,20 +760,21 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "pymc3 3.8\n", - "arviz 0.8.3\n", - "numpy 1.17.5\n", - "last updated: Thu Jun 11 2020 \n", + "theano 1.0.5\n", + "seaborn 0.11.0\n", + "numpy 1.17.3\n", + "pymc3 3.9.3\n", + "last updated: Sat Sep 26 2020 \n", "\n", - "CPython 3.8.2\n", - "IPython 7.11.0\n", + "CPython 3.7.8\n", + "IPython 7.17.0\n", "watermark 2.0.2\n" ] } @@ -792,7 +801,7 @@ "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", - "version": "3.8.2" + "version": "3.7.8" }, "latex_envs": { "bibliofile": "biblio.bib", diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index d9381a48a82..efe321411cb 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import theano +import warnings import theano.tensor as tt from ..model import FreeRV from ..theanof import gradient, floatX from . import distribution -from ..math import logit, invlogit +from ..math import logit, invlogit, logsumexp from .distribution import draw_values import numpy as np from scipy.special import logit as nplogit @@ -36,7 +36,6 @@ "ordered", "log", "sum_to_1", - "t_stick_breaking", "circular", "CholeskyCovPacked", "Chain", @@ -106,7 +105,8 @@ def backward(self, z): raise NotImplementedError def jacobian_det(self, x): - """Calculates logarithm of the absolute value of the Jacobian determinant for input `x`. + """Calculates logarithm of the absolute value of the Jacobian determinant + of the backward transformation for input `x`. Parameters ---------- @@ -430,75 +430,56 @@ def jacobian_det(self, x): class StickBreaking(Transform): """ Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of real values. - Primarily borrowed from the Stan implementation. - - Parameters - ---------- - eps: float, positive value - A small value for numerical stability in invlogit. + This is a variant of the isometric logratio transformation: + Egozcue, J.J., Pawlowsky-Glahn, V., Mateu-Figueras, G. et al. + Isometric Logratio Transformations for Compositional Data Analysis. + Mathematical Geology 35, 279–300 (2003). 
https://doi.org/10.1023/A:1023818214614 """ name = "stickbreaking" - def __init__(self, eps=floatX(np.finfo(theano.config.floatX).eps)): - self.eps = eps + def __init__(self, eps=None): + if eps is not None: + warnings.warn("The argument `eps` is deprecated and will not be used.", + DeprecationWarning) def forward(self, x_): x = x_.T - # reverse cumsum - x0 = x[:-1] - s = tt.extra_ops.cumsum(x0[::-1], 0)[::-1] + x[-1] - z = x0 / s - Km1 = x.shape[0] - 1 - k = tt.arange(Km1)[(slice(None),) + (None,) * (x.ndim - 1)] - eq_share = logit(1.0 / (Km1 + 1 - k).astype(str(x_.dtype))) - y = logit(z) - eq_share + n = x.shape[0] + lx = tt.log(x) + shift = tt.sum(lx, 0, keepdims=True) / n + y = lx[:-1] - shift return floatX(y.T) def forward_val(self, x_, point=None): x = x_.T - # reverse cumsum - x0 = x[:-1] - s = np.cumsum(x0[::-1], 0)[::-1] + x[-1] - z = x0 / s - Km1 = x.shape[0] - 1 - k = np.arange(Km1)[(slice(None),) + (None,) * (x.ndim - 1)] - eq_share = nplogit(1.0 / (Km1 + 1 - k).astype(str(x_.dtype))) - y = nplogit(z) - eq_share + n = x.shape[0] + lx = np.log(x) + shift = np.sum(lx, 0, keepdims=True) / n + y = lx[:-1] - shift return floatX(y.T) def backward(self, y_): y = y_.T - Km1 = y.shape[0] - k = tt.arange(Km1)[(slice(None),) + (None,) * (y.ndim - 1)] - eq_share = logit(1.0 / (Km1 + 1 - k).astype(str(y_.dtype))) - z = invlogit(y + eq_share, self.eps) - yl = tt.concatenate([z, tt.ones(y[:1].shape)]) - yu = tt.concatenate([tt.ones(y[:1].shape), 1 - z]) - S = tt.extra_ops.cumprod(yu, 0) - x = S * yl + y = tt.concatenate([y, -tt.sum(y, 0, keepdims=True)]) + # "softmax" with vector support and no deprecation warning: + e_y = tt.exp(y - tt.max(y, 0, keepdims=True)) + x = e_y / tt.sum(e_y, 0, keepdims=True) return floatX(x.T) def jacobian_det(self, y_): y = y_.T - Km1 = y.shape[0] - k = tt.arange(Km1)[(slice(None),) + (None,) * (y.ndim - 1)] - eq_share = logit(1.0 / (Km1 + 1 - k).astype(str(y_.dtype))) - yl = y + eq_share - yu = tt.concatenate([tt.ones(y[:1].shape), 1 - 
invlogit(yl, self.eps)]) - S = tt.extra_ops.cumprod(yu, 0) - return tt.sum(tt.log(S[:-1]) - tt.log1p(tt.exp(yl)) - tt.log1p(tt.exp(-yl)), 0).T + Km1 = y.shape[0] + 1 + sy = tt.sum(y, 0, keepdims=True) + r = tt.concatenate([y+sy, tt.zeros(sy.shape)]) + sr = logsumexp(r, 0, keepdims=True) + d = tt.log(Km1) + (Km1*sy) - (Km1*sr) + return tt.sum(d, 0).T stick_breaking = StickBreaking() -def t_stick_breaking(eps: float) -> StickBreaking: - """Return a new :class:`StickBreaking` transform with specified eps(ilon), - instead of the default.""" - return StickBreaking(eps) - - class Circular(ElemwiseTransform): """Transforms a linear space into a circular one. """ diff --git a/pymc3/tests/test_transforms.py b/pymc3/tests/test_transforms.py index 4984eccf7b0..fafb160406c 100644 --- a/pymc3/tests/test_transforms.py +++ b/pymc3/tests/test_transforms.py @@ -84,7 +84,12 @@ def check_jacobian_det(transform, domain, computed_ljd(yval), tol) -def test_simplex(): +def test_stickbreaking(): + with pytest.warns( + DeprecationWarning, + match="The argument `eps` is deprecated and will not be used." + ): + tr.StickBreaking(eps=1e-9) check_vector_transform(tr.stick_breaking, Simplex(2)) check_vector_transform(tr.stick_breaking, Simplex(4)) @@ -92,7 +97,7 @@ def test_simplex(): 3, 2), constructor=tt.dmatrix, test=np.zeros((2, 2))) -def test_simplex_bounds(): +def test_stickbreaking_bounds(): vals = get_values(tr.stick_breaking, Vector(R, 2), tt.dvector, np.array([0, 0])) @@ -103,6 +108,16 @@ def test_simplex_bounds(): check_jacobian_det(tr.stick_breaking, Vector( R, 2), tt.dvector, np.array([0, 0]), lambda x: x[:-1]) +def test_stickbreaking_accuracy(): + val = np.array([-30]) + x = tt.dvector('x') + x.tag.test_value = val + identity_f = theano.function( + [x], + tr.stick_breaking.forward(tr.stick_breaking.backward(x)) + ) + close_to(val, identity_f(val), tol) + def test_sum_to_1(): check_vector_transform(tr.sum_to_1, Simplex(2))