Skip to content

Commit

Permalink
feat: simplify memoized_func (#19905)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Apr 29, 2022
1 parent 5f3e73c commit aff10a7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 10 deletions.
10 changes: 5 additions & 5 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def inspector(self) -> Inspector:
return sqla.inspect(engine)

@cache_util.memoized_func(
key=lambda self, *args, **kwargs: f"db:{self.id}:schema:None:table_list",
key="db:{self.id}:schema:None:table_list",
cache=cache_manager.data_cache,
)
def get_all_table_names_in_database( # pylint: disable=unused-argument
Expand All @@ -529,7 +529,7 @@ def get_all_table_names_in_database( # pylint: disable=unused-argument
return self.db_engine_spec.get_all_datasource_names(self, "table")

@cache_util.memoized_func(
key=lambda self, *args, **kwargs: f"db:{self.id}:schema:None:view_list",
key="db:{self.id}:schema:None:view_list",
cache=cache_manager.data_cache,
)
def get_all_view_names_in_database( # pylint: disable=unused-argument
Expand All @@ -544,7 +544,7 @@ def get_all_view_names_in_database( # pylint: disable=unused-argument
return self.db_engine_spec.get_all_datasource_names(self, "view")

@cache_util.memoized_func(
key=lambda self, schema, *args, **kwargs: f"db:{self.id}:schema:{schema}:table_list", # pylint: disable=line-too-long,useless-suppression
key="db:{self.id}:schema:{schema}:table_list",
cache=cache_manager.data_cache,
)
def get_all_table_names_in_schema( # pylint: disable=unused-argument
Expand Down Expand Up @@ -577,7 +577,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
return []

@cache_util.memoized_func(
key=lambda self, schema, *args, **kwargs: f"db:{self.id}:schema:{schema}:view_list", # pylint: disable=line-too-long,useless-suppression
key="db:{self.id}:schema:{schema}:view_list",
cache=cache_manager.data_cache,
)
def get_all_view_names_in_schema( # pylint: disable=unused-argument
Expand Down Expand Up @@ -608,7 +608,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
return []

@cache_util.memoized_func(
key=lambda self, *args, **kwargs: f"db:{self.id}:schema_list",
key="db:{self.id}:schema_list",
cache=cache_manager.data_cache,
)
def get_all_schema_names( # pylint: disable=unused-argument
Expand Down
19 changes: 14 additions & 5 deletions superset/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import inspect
import logging
from datetime import datetime, timedelta
from functools import wraps
Expand Down Expand Up @@ -94,7 +95,7 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-


def memoized_func(
key: Callable[..., str] = view_cache_key,
key: Optional[str] = None,
cache: Cache = cache_manager.cache,
) -> Callable[..., Any]:
"""Use this decorator to cache functions that have predefined first arg.
Expand All @@ -114,15 +115,23 @@ def memoized_func(
"""

def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
def wrapped_f(*args: Any, **kwargs: Any) -> Any:
if not kwargs.get("cache", True):
return f(self, *args, **kwargs)
return f(*args, **kwargs)

if key:
# format the key using args/kwargs passed to the decorated function
signature = inspect.signature(f)
bound_args = signature.bind(*args, **kwargs)
bound_args.apply_defaults()
cache_key = key.format(**bound_args.arguments)
else:
cache_key = view_cache_key(*args, **kwargs)

cache_key = key(self, *args, **kwargs)
obj = cache.get(cache_key)
if not kwargs.get("force") and obj is not None:
return obj
obj = f(self, *args, **kwargs)
obj = f(*args, **kwargs)
cache.set(cache_key, obj, timeout=kwargs.get("cache_timeout"))
return obj

Expand Down
52 changes: 52 additions & 0 deletions tests/unit_tests/utils/cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=import-outside-toplevel, unused-argument

from pytest_mock import MockerFixture


def test_memoized_func(app_context: None, mocker: MockerFixture) -> None:
"""
Test the ``memoized_func`` decorator.
"""
from superset.utils.cache import memoized_func

cache = mocker.MagicMock()

decorator = memoized_func("db:{self.id}:schema:{schema}:view_list", cache)
decorated = decorator(lambda self, schema, cache=False: 42)

self = mocker.MagicMock()
self.id = 1

# skip cache
result = decorated(self, "public", cache=False)
assert result == 42
cache.get.assert_not_called()

# check cache, no cached value
cache.get.return_value = None
result = decorated(self, "public", cache=True)
assert result == 42
cache.get.assert_called_with("db:1:schema:public:view_list")

# check cache, cached value
cache.get.return_value = 43
result = decorated(self, "public", cache=True)
assert result == 43

0 comments on commit aff10a7

Please sign in to comment.