diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index c24a9d9df6e..3ba2d0a0405 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -32,6 +32,7 @@ import theano.graph.basic import theano.tensor as tt +from cachetools import LRUCache, cached from theano import function from pymc3.distributions.shape_utils import ( @@ -39,7 +40,6 @@ get_broadcastable_dist_samples, to_tuple, ) -from pymc3.memoize import memoize from pymc3.model import ( ContextMeta, FreeRV, @@ -48,7 +48,7 @@ ObservedRV, build_named_node_tree, ) -from pymc3.util import get_repr_for_variable, get_var_name +from pymc3.util import get_repr_for_variable, get_var_name, hash_key from pymc3.vartypes import string_types, theano_constant __all__ = [ @@ -840,7 +840,7 @@ def draw_values(params, point=None, size=None): return [evaluated[j] for j in params] # set the order back -@memoize +@cached(LRUCache(128), key=hash_key) def _compile_theano_function(param, vars, givens=None): """Compile theano function for a given parameter and input variables. diff --git a/pymc3/memoize.py b/pymc3/memoize.py deleted file mode 100644 index cbe791f10ce..00000000000 --- a/pymc3/memoize.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2020 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import functools - -import dill - -from pymc3.util import biwrap - -CACHE_REGISTRY = [] - - -@biwrap -def memoize(obj, bound=False): - """ - Decorator to apply memoization to expensive functions. - It uses a custom `hashable` helper function to hash typically unhashable Python objects. - - Parameters - ---------- - obj : callable - the function to apply the caching to - bound : bool - indicates if the [obj] is a bound method (self as first argument) - For bound methods, the cache is kept in a `_cache` attribute on [self]. - """ - # this is declared not to be a bound method, so just attach new attr to obj - if not bound: - obj.cache = {} - CACHE_REGISTRY.append(obj.cache) - - @functools.wraps(obj) - def memoizer(*args, **kwargs): - if not bound: - key = (hashable(args), hashable(kwargs)) - cache = obj.cache - else: - # bound methods have self as first argument, remove it to compute key - key = (hashable(args[1:]), hashable(kwargs)) - if not hasattr(args[0], "_cache"): - setattr(args[0], "_cache", collections.defaultdict(dict)) - # do not add to cache registry - cache = getattr(args[0], "_cache")[obj.__name__] - if key not in cache: - cache[key] = obj(*args, **kwargs) - - return cache[key] - - return memoizer - - -def clear_cache(obj=None): - if obj is None: - for c in CACHE_REGISTRY: - c.clear() - else: - if isinstance(obj, WithMemoization): - for v in getattr(obj, "_cache", {}).values(): - v.clear() - else: - obj.cache.clear() - - -class WithMemoization: - def __hash__(self): - return hash(id(self)) - - def __getstate__(self): - state = self.__dict__.copy() - state.pop("_cache", None) - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - -def hashable(a) -> int: - """ - Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function. - Lists and tuples are hashed based on their elements. - """ - if isinstance(a, dict): - # first hash the keys and values with hashable - # then hash the tuple of int-tuples with the builtin - return hash(tuple((hashable(k), hashable(v)) for k, v in a.items())) - if isinstance(a, (tuple, list)): - # lists are mutable and not hashable by default - # for memoization, we need the hash to depend on the items - return hash(tuple(hashable(i) for i in a)) - try: - return hash(a) - except TypeError: - pass - # Not hashable >>> - try: - return hash(dill.dumps(a)) - except Exception: - if hasattr(a, "__dict__"): - return hashable(a.__dict__) - else: - return id(a) diff --git a/pymc3/model.py b/pymc3/model.py index 349affcfa01..dff1e3b78bf 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -27,6 +27,7 @@ import theano.sparse as sparse import theano.tensor as tt +from cachetools import LRUCache, cachedmethod from pandas import Series from theano.compile import SharedVariable from theano.graph.basic import Apply @@ -36,10 +37,8 @@ from pymc3.blocking import ArrayOrdering, DictToArrayBijection from pymc3.exceptions import ImputationWarning -from pymc3.math import flatten_list -from pymc3.memoize import WithMemoization, memoize from pymc3.theanof import floatX, generator, gradient, hessian, inputvars -from pymc3.util import get_transformed_name, get_var_name +from pymc3.util import WithMemoization, get_transformed_name, get_var_name, hash_key from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter __all__ = [ @@ -944,7 +943,9 @@ def isroot(self): return self.parent is None @property # type: ignore - @memoize(bound=True) + @cachedmethod( + lambda self: self.__dict__.setdefault("_bijection_cache", LRUCache(128)), key=hash_key + ) def bijection(self): vars = inputvars(self.vars) diff --git a/pymc3/tests/test_memo.py b/pymc3/tests/test_memo.py deleted file mode 100644 index 6653662e32e..00000000000 --- a/pymc3/tests/test_memo.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np - -import pymc3 as pm - -from pymc3 import memoize - - -def test_memo(): - def fun(inputs, suffix="_a"): - return str(inputs) + str(suffix) - - inputs = ["i1", "i2"] - assert fun(inputs) == "['i1', 'i2']_a" - assert fun(inputs, "_b") == "['i1', 'i2']_b" - - funmem = memoize.memoize(fun) - assert hasattr(fun, "cache") - assert isinstance(fun.cache, dict) - assert len(fun.cache) == 0 - - # call the memoized function with a list input - # and check the size of the cache! - assert funmem(inputs) == "['i1', 'i2']_a" - assert funmem(inputs) == "['i1', 'i2']_a" - assert len(fun.cache) == 1 - assert funmem(inputs, "_b") == "['i1', 'i2']_b" - assert funmem(inputs, "_b") == "['i1', 'i2']_b" - assert len(fun.cache) == 2 - - # add items to the inputs list (the list instance remains identical !!) - inputs.append("i3") - assert funmem(inputs) == "['i1', 'i2', 'i3']_a" - assert funmem(inputs) == "['i1', 'i2', 'i3']_a" - assert len(fun.cache) == 3 - - -def test_hashing_of_rv_tuples(): - obs = np.random.normal(-1, 0.1, size=10) - with pm.Model() as pmodel: - mu = pm.Normal("mu", 0, 1) - sd = pm.Gamma("sd", 1, 2) - dd = pm.DensityDist( - "dd", - pm.Normal.dist(mu, sd).logp, - random=pm.Normal.dist(mu, sd).random, - observed=obs, - ) - for freerv in [mu, sd, dd] + pmodel.free_RVs: - for structure in [ - freerv, - {"alpha": freerv, "omega": None}, - [freerv, []], - (freerv, []), - ]: - assert isinstance(memoize.hashable(structure), int) diff --git a/pymc3/tests/test_util.py b/pymc3/tests/test_util.py index adb334fb8af..05b6bdf52da 100644 --- a/pymc3/tests/test_util.py +++ b/pymc3/tests/test_util.py @@ -15,12 +15,14 @@ import numpy as np import pytest +from cachetools import cached from numpy.testing import assert_almost_equal import pymc3 as pm from pymc3.distributions.transforms import Transform from pymc3.tests.helpers import SeededTest +from pymc3.util import hash_key, hashable, locally_cachedmethod class TestTransformName: @@ -167,3 +169,53 @@ def test_dtype_error(self): raise pm.exceptions.DtypeError("With types.", actual=int, expected=str) assert "int" in exinfo.value.args[0] and "str" in exinfo.value.args[0] pass + + +def test_hashing_of_rv_tuples(): + obs = np.random.normal(-1, 0.1, size=10) + with pm.Model() as pmodel: + mu = pm.Normal("mu", 0, 1) + sd = pm.Gamma("sd", 1, 2) + dd = pm.DensityDist( + "dd", + pm.Normal.dist(mu, sd).logp, + random=pm.Normal.dist(mu, sd).random, + observed=obs, + ) + for freerv in [mu, sd, dd] + pmodel.free_RVs: + for structure in [ + freerv, + {"alpha": freerv, "omega": None}, + [freerv, []], + (freerv, []), + ]: + assert isinstance(hashable(structure), int) + + +def test_hash_key(): + class Bad1: + def __hash__(self): + return 329 + + class Bad2: + def __hash__(self): + return 329 + + b1 = Bad1() + b2 = Bad2() + + assert b1 != b2 + + @cached({}, key=hash_key) + def some_func(x): + return x + + assert some_func(b1) != some_func(b2) + + class TestClass: + @locally_cachedmethod + def some_method(self, x): + return x + + tc = TestClass() + assert tc.some_method(b1) != tc.some_method(b2) diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py index 1ef9b616290..1f1e0352777 100644 --- a/pymc3/tests/test_variational_inference.py +++ b/pymc3/tests/test_variational_inference.py @@ -22,7 +22,6 @@ import theano.tensor as tt import pymc3 as pm -import pymc3.memoize import pymc3.util from pymc3.tests import models @@ -757,14 +756,12 @@ def test_remove_scan_op(): def test_clear_cache(): import pickle - pymc3.memoize.clear_cache() - assert all(len(c) == 0 for c in pymc3.memoize.CACHE_REGISTRY) with pm.Model(): pm.Normal("n", 0, 1) inference = ADVI() inference.fit(n=10) assert any(len(c) != 0 for c in inference.approx._cache.values()) - pymc3.memoize.clear_cache(inference.approx) + inference.approx._cache.clear() # should not be cleared at this call assert all(len(c) == 0 for c in inference.approx._cache.values()) new_a = pickle.loads(pickle.dumps(inference.approx)) @@ -772,7 +769,7 @@ def test_clear_cache(): inference_new = pm.KLqp(new_a) inference_new.fit(n=10) assert any(len(c) != 0 for c in inference_new.approx._cache.values()) - pymc3.memoize.clear_cache(inference_new.approx) + inference_new.approx._cache.clear() assert all(len(c) == 0 for c in inference_new.approx._cache.values()) diff --git a/pymc3/util.py b/pymc3/util.py index 84b4f6c3e5f..dbd6219a1e3 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -19,9 +19,11 @@ from typing import Dict, List, Tuple, Union import arviz +import dill import numpy as np import xarray +from cachetools import LRUCache, cachedmethod from theano.tensor import TensorVariable from pymc3.exceptions import SamplingError @@ -304,3 +306,76 @@ def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tupl nchains = coords["chain"].sizes["chain"] nsamples = coords["draw"].sizes["draw"] return nchains, nsamples + + +def hashable(a=None) -> int: + """ + Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function. + Lists and tuples are hashed based on their elements. + """ + if isinstance(a, dict): + # first hash the keys and values with hashable + # then hash the tuple of int-tuples with the builtin + return hash(tuple((hashable(k), hashable(v)) for k, v in a.items())) + if isinstance(a, (tuple, list)): + # lists are mutable and not hashable by default + # for memoization, we need the hash to depend on the items + return hash(tuple(hashable(i) for i in a)) + try: + return hash(a) + except TypeError: + pass + # Not hashable >>> + try: + return hash(dill.dumps(a)) + except Exception: + if hasattr(a, "__dict__"): + return hashable(a.__dict__) + else: + return id(a) + + +def hash_key(*args, **kwargs): + return tuple(HashableWrapper(a) for a in args + tuple(kwargs.items())) + + +class HashableWrapper: + __slots__ = ("obj",) + + def __init__(self, obj): + self.obj = obj + + def __hash__(self): + return hashable(self.obj) + + def __eq__(self, other): + return self.obj == other + + def __repr__(self): + return f"{type(self).__name__}({self.obj})" + + +class WithMemoization: + def __hash__(self): + return hash(id(self)) + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_cache", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + +def locally_cachedmethod(f): + + from collections import defaultdict + + def self_cache_fn(f_name): + def cf(self): + return self.__dict__.setdefault("_cache", defaultdict(lambda: LRUCache(128)))[f_name] + + return cf + + return cachedmethod(self_cache_fn(f.__name__), key=hash_key)(f) diff --git a/pymc3/variational/flows.py b/pymc3/variational/flows.py index 601c7351fa7..bc87af5eba2 100644 --- a/pymc3/variational/flows.py +++ b/pymc3/variational/flows.py @@ -18,7 +18,7 @@ from theano import tensor as tt from pymc3.distributions.dist_math import rho2sigma -from pymc3.memoize import WithMemoization +from pymc3.util import WithMemoization from pymc3.variational import opvi from pymc3.variational.opvi import collect_shared_to_list, node_property diff --git a/pymc3/variational/opvi.py b/pymc3/variational/opvi.py index ebf4a9cda84..814f49b3633 100644 --- a/pymc3/variational/opvi.py +++ b/pymc3/variational/opvi.py @@ -57,10 +57,14 @@ from pymc3.backends import NDArray from pymc3.blocking import ArrayOrdering, DictToArrayBijection, VarMap -from pymc3.memoize import WithMemoization, memoize from pymc3.model import modelcontext from pymc3.theanof import identity, tt_rng -from pymc3.util import get_default_varnames, get_transformed +from pymc3.util import ( + WithMemoization, + get_default_varnames, + get_transformed, + locally_cachedmethod, +) from pymc3.variational.updates import adagrad_window __all__ = ["ObjectiveFunction", "Operator", "TestFunction", "Group", "Approximation"] @@ -111,21 +115,18 @@ def inner(*args, **kwargs): def node_property(f): """A shortcut for wrapping method to accessible tensor""" + if isinstance(f, str): def wrapper(fn): - return property( - memoize( - theano.config.change_flags(compute_test_value="off")(append_name(f)(fn)), - bound=True, - ) - ) + ff = append_name(f)(fn) + f_ = theano.config.change_flags(compute_test_value="off")(ff) + return property(locally_cachedmethod(f_)) return wrapper else: - return property( - memoize(theano.config.change_flags(compute_test_value="off")(f), bound=True) - ) + f_ = theano.config.change_flags(compute_test_value="off")(f) + return property(locally_cachedmethod(f_)) @theano.config.change_flags(compute_test_value="ignore") @@ -1586,9 +1587,7 @@ def vars_names(vs): raise KeyError("%r not found" % name) return found - @property - @memoize(bound=True) - @theano.config.change_flags(compute_test_value="off") + @node_property def sample_dict_fn(self): s = tt.iscalar() names = [v.name for v in self.model.free_RVs] diff --git a/pymc3/variational/stein.py b/pymc3/variational/stein.py index ca9a9249106..216a9d8276b 100644 --- a/pymc3/variational/stein.py +++ b/pymc3/variational/stein.py @@ -15,8 +15,8 @@ import theano import theano.tensor as tt -from pymc3.memoize import WithMemoization, memoize from pymc3.theanof import floatX +from pymc3.util import WithMemoization, locally_cachedmethod from pymc3.variational.opvi import node_property from pymc3.variational.test_functions import rbf @@ -90,7 +90,6 @@ def logp_norm(self): ) return sized_symbolic_logp / self.approx.symbolic_normalizing_constant - @memoize - @theano.config.change_flags(compute_test_value="off") + @locally_cachedmethod def _kernel(self): return self._kernel_f(self.input_joint_matrix) diff --git a/requirements.txt b/requirements.txt index 93cb80ebc13..a103801473a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ arviz>=0.11.0 +cachetools>=4.2.1 dill fastprogress>=0.2.0 numpy>=1.15.0