Skip to content

Commit

Permalink
Replace custom memoize module with cachetools
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard authored and michaelosthege committed Mar 10, 2021
1 parent 055d112 commit 18cc84e
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 211 deletions.
6 changes: 3 additions & 3 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@
import theano.graph.basic
import theano.tensor as tt

from cachetools import LRUCache, cached
from theano import function

from pymc3.distributions.shape_utils import (
broadcast_dist_samples_shape,
get_broadcastable_dist_samples,
to_tuple,
)
from pymc3.memoize import memoize
from pymc3.model import (
ContextMeta,
FreeRV,
Expand All @@ -48,7 +48,7 @@
ObservedRV,
build_named_node_tree,
)
from pymc3.util import get_repr_for_variable, get_var_name
from pymc3.util import get_repr_for_variable, get_var_name, hash_key
from pymc3.vartypes import string_types, theano_constant

__all__ = [
Expand Down Expand Up @@ -840,7 +840,7 @@ def draw_values(params, point=None, size=None):
return [evaluated[j] for j in params] # set the order back


@memoize
@cached(LRUCache(128), key=hash_key)
def _compile_theano_function(param, vars, givens=None):
"""Compile theano function for a given parameter and input variables.
Expand Down
113 changes: 0 additions & 113 deletions pymc3/memoize.py

This file was deleted.

9 changes: 5 additions & 4 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import theano.sparse as sparse
import theano.tensor as tt

from cachetools import LRUCache, cachedmethod
from pandas import Series
from theano.compile import SharedVariable
from theano.graph.basic import Apply
Expand All @@ -36,10 +37,8 @@

from pymc3.blocking import ArrayOrdering, DictToArrayBijection
from pymc3.exceptions import ImputationWarning
from pymc3.math import flatten_list
from pymc3.memoize import WithMemoization, memoize
from pymc3.theanof import floatX, generator, gradient, hessian, inputvars
from pymc3.util import get_transformed_name, get_var_name
from pymc3.util import WithMemoization, get_transformed_name, get_var_name, hash_key
from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter

__all__ = [
Expand Down Expand Up @@ -944,7 +943,9 @@ def isroot(self):
return self.parent is None

@property # type: ignore
@memoize(bound=True)
@cachedmethod(
lambda self: self.__dict__.setdefault("_bijection_cache", LRUCache(128)), key=hash_key
)
def bijection(self):
vars = inputvars(self.vars)

Expand Down
68 changes: 0 additions & 68 deletions pymc3/tests/test_memo.py

This file was deleted.

52 changes: 52 additions & 0 deletions pymc3/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
import numpy as np
import pytest

from cachetools import cached
from numpy.testing import assert_almost_equal

import pymc3 as pm

from pymc3.distributions.transforms import Transform
from pymc3.tests.helpers import SeededTest
from pymc3.util import hash_key, hashable, locally_cachedmethod


class TestTransformName:
Expand Down Expand Up @@ -167,3 +169,53 @@ def test_dtype_error(self):
raise pm.exceptions.DtypeError("With types.", actual=int, expected=str)
assert "int" in exinfo.value.args[0] and "str" in exinfo.value.args[0]
pass


def test_hashing_of_rv_tuples():
obs = np.random.normal(-1, 0.1, size=10)
with pm.Model() as pmodel:
mu = pm.Normal("mu", 0, 1)
sd = pm.Gamma("sd", 1, 2)
dd = pm.DensityDist(
"dd",
pm.Normal.dist(mu, sd).logp,
random=pm.Normal.dist(mu, sd).random,
observed=obs,
)
for freerv in [mu, sd, dd] + pmodel.free_RVs:
for structure in [
freerv,
{"alpha": freerv, "omega": None},
[freerv, []],
(freerv, []),
]:
assert isinstance(hashable(structure), int)


def test_hash_key():
class Bad1:
def __hash__(self):
return 329

class Bad2:
def __hash__(self):
return 329

b1 = Bad1()
b2 = Bad2()

assert b1 != b2

@cached({}, key=hash_key)
def some_func(x):
return x

assert some_func(b1) != some_func(b2)

class TestClass:
@locally_cachedmethod
def some_method(self, x):
return x

tc = TestClass()
assert tc.some_method(b1) != tc.some_method(b2)
7 changes: 2 additions & 5 deletions pymc3/tests/test_variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import theano.tensor as tt

import pymc3 as pm
import pymc3.memoize
import pymc3.util

from pymc3.tests import models
Expand Down Expand Up @@ -757,22 +756,20 @@ def test_remove_scan_op():
def test_clear_cache():
import pickle

pymc3.memoize.clear_cache()
assert all(len(c) == 0 for c in pymc3.memoize.CACHE_REGISTRY)
with pm.Model():
pm.Normal("n", 0, 1)
inference = ADVI()
inference.fit(n=10)
assert any(len(c) != 0 for c in inference.approx._cache.values())
pymc3.memoize.clear_cache(inference.approx)
inference.approx._cache.clear()
# should not be cleared at this call
assert all(len(c) == 0 for c in inference.approx._cache.values())
new_a = pickle.loads(pickle.dumps(inference.approx))
assert not hasattr(new_a, "_cache")
inference_new = pm.KLqp(new_a)
inference_new.fit(n=10)
assert any(len(c) != 0 for c in inference_new.approx._cache.values())
pymc3.memoize.clear_cache(inference_new.approx)
inference_new.approx._cache.clear()
assert all(len(c) == 0 for c in inference_new.approx._cache.values())


Expand Down
Loading

0 comments on commit 18cc84e

Please sign in to comment.