From 0996cbf7b480090327a47bc0246a936c5b18e83e Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 29 Apr 2022 14:26:44 -0700 Subject: [PATCH] feat: simplify memoized_func --- superset/models/core.py | 10 +++--- superset/utils/cache.py | 19 +++++++--- tests/unit_tests/utils/cache_test.py | 52 ++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 10 deletions(-) create mode 100644 tests/unit_tests/utils/cache_test.py diff --git a/superset/models/core.py b/superset/models/core.py index c2052749ad8a0..d90aa2569625c 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -514,7 +514,7 @@ def inspector(self) -> Inspector: return sqla.inspect(engine) @cache_util.memoized_func( - key=lambda self, *args, **kwargs: f"db:{self.id}:schema:None:table_list", + key="db:{self.id}:schema:None:table_list", cache=cache_manager.data_cache, ) def get_all_table_names_in_database( # pylint: disable=unused-argument @@ -529,7 +529,7 @@ def get_all_table_names_in_database( # pylint: disable=unused-argument return self.db_engine_spec.get_all_datasource_names(self, "table") @cache_util.memoized_func( - key=lambda self, *args, **kwargs: f"db:{self.id}:schema:None:view_list", + key="db:{self.id}:schema:None:view_list", cache=cache_manager.data_cache, ) def get_all_view_names_in_database( # pylint: disable=unused-argument @@ -544,7 +544,7 @@ def get_all_view_names_in_database( # pylint: disable=unused-argument return self.db_engine_spec.get_all_datasource_names(self, "view") @cache_util.memoized_func( - key=lambda self, schema, *args, **kwargs: f"db:{self.id}:schema:{schema}:table_list", # pylint: disable=line-too-long,useless-suppression + key="db:{self.id}:schema:{schema}:table_list", cache=cache_manager.data_cache, ) def get_all_table_names_in_schema( # pylint: disable=unused-argument @@ -577,7 +577,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument return [] @cache_util.memoized_func( - key=lambda self, schema, *args, **kwargs: f"db:{self.id}:schema:{schema}:view_list", # pylint: disable=line-too-long,useless-suppression + key="db:{self.id}:schema:{schema}:view_list", cache=cache_manager.data_cache, ) def get_all_view_names_in_schema( # pylint: disable=unused-argument @@ -608,7 +608,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument return [] @cache_util.memoized_func( - key=lambda self, *args, **kwargs: f"db:{self.id}:schema_list", + key="db:{self.id}:schema_list", cache=cache_manager.data_cache, ) def get_all_schema_names( # pylint: disable=unused-argument diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 02a4cdfecc0ee..d86f92398b570 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import inspect import logging from datetime import datetime, timedelta from functools import wraps @@ -94,7 +95,7 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused- def memoized_func( - key: Callable[..., str] = view_cache_key, + key: Optional[str] = None, cache: Cache = cache_manager.cache, ) -> Callable[..., Any]: """Use this decorator to cache functions that have predefined first arg. @@ -114,15 +115,23 @@ def memoized_func( """ def wrap(f: Callable[..., Any]) -> Callable[..., Any]: - def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any: + def wrapped_f(*args: Any, **kwargs: Any) -> Any: if not kwargs.get("cache", True): - return f(self, *args, **kwargs) + return f(*args, **kwargs) + + if key: + # format the key using args/kwargs passed to the decorated function + signature = inspect.signature(f) + bound_args = signature.bind(*args, **kwargs) + bound_args.apply_defaults() + cache_key = key.format(**bound_args.arguments) + else: + cache_key = view_cache_key(*args, **kwargs) - cache_key = key(self, *args, **kwargs) obj = cache.get(cache_key) if not kwargs.get("force") and obj is not None: return obj - obj = f(self, *args, **kwargs) + obj = f(*args, **kwargs) cache.set(cache_key, obj, timeout=kwargs.get("cache_timeout")) return obj diff --git a/tests/unit_tests/utils/cache_test.py b/tests/unit_tests/utils/cache_test.py new file mode 100644 index 0000000000000..7c1354aa3cb39 --- /dev/null +++ b/tests/unit_tests/utils/cache_test.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel, unused-argument + +from pytest_mock import MockerFixture + + +def test_memoized_func(app_context: None, mocker: MockerFixture) -> None: + """ + Test the ``memoized_func`` decorator. + """ + from superset.utils.cache import memoized_func + + cache = mocker.MagicMock() + + decorator = memoized_func("db:{self.id}:schema:{schema}:view_list", cache) + decorated = decorator(lambda self, schema, cache=False: 42) + + self = mocker.MagicMock() + self.id = 1 + + # skip cache + result = decorated(self, "public", cache=False) + assert result == 42 + cache.get.assert_not_called() + + # check cache, no cached value + cache.get.return_value = None + result = decorated(self, "public", cache=True) + assert result == 42 + cache.get.assert_called_with("db:1:schema:public:view_list") + + # check cache, cached value + cache.get.return_value = 43 + result = decorated(self, "public", cache=True) + assert result == 43