From e87a3c15dd03a8466c9114ac700d02f4011bcc2f Mon Sep 17 00:00:00 2001 From: Craig Date: Thu, 21 Sep 2023 10:34:40 -0700 Subject: [PATCH 01/10] First pass at refactoring async query job submission --- superset/async_events/api.py | 6 +-- superset/async_events/async_query_manager.py | 40 ++++++++++++++++++- .../data/commands/create_async_job_command.py | 11 ++--- superset/config.py | 1 + superset/views/core.py | 11 ++--- 5 files changed, 53 insertions(+), 16 deletions(-) diff --git a/superset/async_events/api.py b/superset/async_events/api.py index 0a6ceb9c5f4b1..376671cf364a7 100644 --- a/superset/async_events/api.py +++ b/superset/async_events/api.py @@ -88,9 +88,9 @@ def events(self) -> Response: $ref: '#/components/responses/500' """ try: - async_channel_id = async_query_manager.parse_jwt_from_request(request)[ - "channel" - ] + async_channel_id = async_query_manager.parse_channel_id_from_request( + request + ) last_event_id = request.args.get("last_id") events = async_query_manager.read_events(async_channel_id, last_event_id) diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py index d67d9ca0817ec..14868409cfc56 100644 --- a/superset/async_events/async_query_manager.py +++ b/superset/async_events/async_query_manager.py @@ -82,6 +82,8 @@ def __init__(self) -> None: self._jwt_cookie_domain: Optional[str] self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None self._jwt_secret: str + self._load_chart_data_into_cache_job: Any = None + self._load_explore_json_into_cache_job: Any = None def init_app(self, app: Flask) -> None: config = app.config @@ -115,6 +117,18 @@ def init_app(self, app: Flask) -> None: self._jwt_cookie_domain = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"] self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"] + if config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]: + self.register_request_handlers(app) + + from superset.tasks.async_queries import ( + load_chart_data_into_cache, + load_explore_json_into_cache, + ) + + self._load_chart_data_into_cache_job = load_chart_data_into_cache + self._load_explore_json_into_cache_job = load_explore_json_into_cache + + def register_request_handlers(self, app: Flask) -> None: @app.after_request def validate_session(response: Response) -> Response: user_id = get_user_id() @@ -149,7 +163,7 @@ def validate_session(response: Response) -> Response: return response - def parse_jwt_from_request(self, req: Request) -> dict[str, Any]: + def parse_channel_id_from_request(self, req: Request) -> str: token = req.cookies.get(self._jwt_cookie_name) if not token: raise AsyncQueryTokenException("Token not preset") @@ -166,6 +180,30 @@ def init_job(self, channel_id: str, user_id: Optional[int]) -> dict[str, Any]: channel_id, job_id, user_id, status=self.STATUS_PENDING ) + def submit_explore_json_job( + self, + channel_id: str, + form_data: dict[str, Any], + response_type: str, + force: Optional[bool] = False, + user_id: Optional[int] = None, + ) -> dict[str, Any]: + job_metadata = self.init_job(channel_id, user_id) + self._load_explore_json_into_cache_job.delay( + job_metadata, + form_data, + response_type, + force, + ) + return job_metadata + + def submit_chart_data_job( + self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int] + ) -> dict[str, Any]: + job_metadata = self.init_job(channel_id, user_id) + self._load_chart_data_into_cache_job.delay(job_metadata, form_data) + return job_metadata + def read_events( self, channel: str, last_id: Optional[str] ) -> list[Optional[dict[str, Any]]]: diff --git a/superset/charts/data/commands/create_async_job_command.py b/superset/charts/data/commands/create_async_job_command.py index fb6e3f3dbff34..82d23d2ad7bb2 100644 --- a/superset/charts/data/commands/create_async_job_command.py +++ b/superset/charts/data/commands/create_async_job_command.py @@ -29,10 +29,11 @@ class CreateAsyncChartDataJobCommand: _async_channel_id: str def validate(self, request: Request) -> None: - jwt_data = async_query_manager.parse_jwt_from_request(request) - self._async_channel_id = jwt_data["channel"] + self._async_channel_id = async_query_manager.parse_channel_id_from_request( + request + ) def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]: - job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) - load_chart_data_into_cache.delay(job_metadata, form_data) - return job_metadata + return async_query_manager.submit_chart_data_job( + self._async_channel_id, form_data, user_id + ) diff --git a/superset/config.py b/superset/config.py index 74f5df0e6e200..ed86c0f26cf7e 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1516,6 +1516,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-" GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000 GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000 +GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS = True GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token" GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | ( diff --git a/superset/views/core.py b/superset/views/core.py index 268c6fe333d74..7b749a2341735 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -320,14 +320,11 @@ def explore_json( # at which point they will call the /explore_json/data/ # endpoint to retrieve the results. try: - async_channel_id = async_query_manager.parse_jwt_from_request( - request - )["channel"] - job_metadata = async_query_manager.init_job( - async_channel_id, get_user_id() + async_channel_id = ( + async_query_manager.parse_channel_id_from_request(request) ) - load_explore_json_into_cache.delay( - job_metadata, form_data, response_type, force + job_metadata = async_query_manager.submit_explore_json_job( + async_channel_id, form_data, response_type, force, get_user_id() ) except AsyncQueryTokenException: return json_error_response("Not authorized", 401) From a6ac45365e246dad9bed843f3bf9c6644f217911 Mon Sep 17 00:00:00 2001 From: Craig Date: Fri, 29 Sep 2023 11:39:44 -0700 Subject: [PATCH 02/10] Adding async query unit tests --- superset/async_events/async_query_manager.py | 2 +- tests/unit_tests/async_events/__init__.py | 16 +++++ .../async_events/async_query_manager_tests.py | 67 +++++++++++++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 tests/unit_tests/async_events/__init__.py create mode 100644 tests/unit_tests/async_events/async_query_manager_tests.py diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py index 14868409cfc56..24ffdd7badc45 100644 --- a/superset/async_events/async_query_manager.py +++ b/superset/async_events/async_query_manager.py @@ -169,7 +169,7 @@ def parse_channel_id_from_request(self, req: Request) -> str: raise AsyncQueryTokenException("Token not preset") try: - return jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) + return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])["channel"] except Exception as ex: logger.warning("Parse jwt failed", exc_info=True) raise AsyncQueryTokenException("Failed to parse token") from ex diff --git a/tests/unit_tests/async_events/__init__.py b/tests/unit_tests/async_events/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/async_events/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/async_events/async_query_manager_tests.py b/tests/unit_tests/async_events/async_query_manager_tests.py new file mode 100644 index 0000000000000..b4ae06dfc3f6f --- /dev/null +++ b/tests/unit_tests/async_events/async_query_manager_tests.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import Mock + +from jwt import encode +from pytest import fixture, raises + +from superset.async_events.async_query_manager import ( + AsyncQueryManager, + AsyncQueryTokenException, +) + +JWT_TOKEN_SECRET = "some_secret" +JWT_TOKEN_COOKIE_NAME = "superset_async_jwt" + + +@fixture +def async_query_manager(): + query_manager = AsyncQueryManager() + query_manager._jwt_secret = JWT_TOKEN_SECRET + query_manager._jwt_cookie_name = JWT_TOKEN_COOKIE_NAME + + return query_manager + + +def test_parse_channel_id_from_request(async_query_manager): + encoded_token = encode( + {"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256" + ) + + request = Mock() + request.cookies = {"superset_async_jwt": encoded_token} + + assert ( + async_query_manager.parse_channel_id_from_request(request) == "test_channel_id" + ) + + +def test_parse_channel_id_from_request_no_cookie(async_query_manager): + request = Mock() + request.cookies = {} + + with raises(AsyncQueryTokenException): + async_query_manager.parse_channel_id_from_request(request) + + +def test_parse_channel_id_from_request_bad_jwt(async_query_manager): + request = Mock() + request.cookies = {"superset_async_jwt": "bad_jwt"} + + with raises(AsyncQueryTokenException): + async_query_manager.parse_channel_id_from_request(request) From 07ee26e2feffb716d9a68b221d667f7d8c4be157 Mon Sep 17 00:00:00 2001 From: Craig Date: Fri, 29 Sep 2023 11:55:32 -0700 Subject: [PATCH 03/10] Linting --- superset/async_events/async_query_manager.py | 3 +++ superset/charts/data/commands/create_async_job_command.py | 1 - superset/cli/lib.py | 1 - superset/views/core.py | 1 - 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py index 24ffdd7badc45..94941541fb4f9 100644 --- a/superset/async_events/async_query_manager.py +++ b/superset/async_events/async_query_manager.py @@ -83,6 +83,7 @@ def __init__(self) -> None: self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None self._jwt_secret: str self._load_chart_data_into_cache_job: Any = None + # pylint: disable=invalid-name self._load_explore_json_into_cache_job: Any = None def init_app(self, app: Flask) -> None: @@ -120,6 +121,7 @@ def init_app(self, app: Flask) -> None: if config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]: self.register_request_handlers(app) + # pylint: disable=import-outside-toplevel from superset.tasks.async_queries import ( load_chart_data_into_cache, load_explore_json_into_cache, @@ -180,6 +182,7 @@ def init_job(self, channel_id: str, user_id: Optional[int]) -> dict[str, Any]: channel_id, job_id, user_id, status=self.STATUS_PENDING ) + # pylint: disable=too-many-arguments def submit_explore_json_job( self, channel_id: str, diff --git a/superset/charts/data/commands/create_async_job_command.py b/superset/charts/data/commands/create_async_job_command.py index 82d23d2ad7bb2..da126277ee7da 100644 --- a/superset/charts/data/commands/create_async_job_command.py +++ b/superset/charts/data/commands/create_async_job_command.py @@ -20,7 +20,6 @@ from flask import Request from superset.extensions import async_query_manager -from superset.tasks.async_queries import load_chart_data_into_cache logger = logging.getLogger(__name__) diff --git a/superset/cli/lib.py b/superset/cli/lib.py index 9e14ab6aae025..843acbc92b361 100755 --- a/superset/cli/lib.py +++ b/superset/cli/lib.py @@ -26,7 +26,6 @@ feature_flags.update(config.FEATURE_FLAGS) feature_flags_func = config.GET_FEATURE_FLAGS_FUNC if feature_flags_func: - # pylint: disable=not-callable try: feature_flags = feature_flags_func(feature_flags) except Exception: # pylint: disable=broad-except diff --git a/superset/views/core.py b/superset/views/core.py index 7b749a2341735..e67a255da2850 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -74,7 +74,6 @@ from superset.models.user_attributes import UserAttribute from superset.sqllab.utils import bootstrap_sqllab_data from superset.superset_typing import FlaskResponse -from superset.tasks.async_queries import load_explore_json_into_cache from superset.utils import core as utils from superset.utils.cache import etag_cache from superset.utils.core import ( From 71035274f9135e18cb3b4749e2f85916cba68847 Mon Sep 17 00:00:00 2001 From: Craig Date: Mon, 2 Oct 2023 09:05:03 -0700 Subject: [PATCH 04/10] Linting stuff I didn't change --- superset/cli/lib.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/superset/cli/lib.py b/superset/cli/lib.py index 843acbc92b361..00bfd7695a8fc 100755 --- a/superset/cli/lib.py +++ b/superset/cli/lib.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. import logging +from typing import Callable, Any from superset import config @@ -24,7 +25,7 @@ feature_flags = config.DEFAULT_FEATURE_FLAGS.copy() feature_flags.update(config.FEATURE_FLAGS) -feature_flags_func = config.GET_FEATURE_FLAGS_FUNC +feature_flags_func: Callable[[Any], Any] = config.GET_FEATURE_FLAGS_FUNC if feature_flags_func: try: feature_flags = feature_flags_func(feature_flags) From 72fc19033679d213287b9051e05911738dba475a Mon Sep 17 00:00:00 2001 From: Craig Date: Mon, 2 Oct 2023 09:37:12 -0700 Subject: [PATCH 05/10] Linting stuff I didn't change again --- superset/cli/lib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/cli/lib.py b/superset/cli/lib.py index 00bfd7695a8fc..68f6f0383188a 100755 --- a/superset/cli/lib.py +++ b/superset/cli/lib.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Callable, Any from superset import config @@ -25,9 +24,10 @@ feature_flags = config.DEFAULT_FEATURE_FLAGS.copy() feature_flags.update(config.FEATURE_FLAGS) -feature_flags_func: Callable[[Any], Any] = config.GET_FEATURE_FLAGS_FUNC +feature_flags_func = config.GET_FEATURE_FLAGS_FUNC if feature_flags_func: try: + # pylint: disable=not-callable feature_flags = feature_flags_func(feature_flags) except Exception: # pylint: disable=broad-except # bypass any feature flags that depend on context From d040d2a5b88a39896d555544f43a0c350c3d5904 Mon Sep 17 00:00:00 2001 From: Craig Date: Mon, 2 Oct 2023 10:24:32 -0700 Subject: [PATCH 06/10] Fixing test --- .../tasks/async_queries_tests.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 50806ee677394..b7fe7b066d962 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -20,33 +20,26 @@ import pytest from celery.exceptions import SoftTimeLimitExceeded -from flask import g from superset.charts.commands.exceptions import ChartDataQueryFailedError from superset.charts.data.commands.get_data_command import ChartDataCommand from superset.exceptions import SupersetException from superset.extensions import async_query_manager, security_manager -from superset.tasks import async_queries -from superset.tasks.async_queries import ( - load_chart_data_into_cache, - load_explore_json_into_cache, -) -from superset.utils.core import get_user_id from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, - load_birth_names_data, ) from tests.integration_tests.fixtures.query_context import get_query_context -from tests.integration_tests.fixtures.tags import with_tagging_system_feature from tests.integration_tests.test_app import app class TestAsyncQueries(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") - @mock.patch.object(async_queries, "set_form_data") + @mock.patch("superset.tasks.async_queries.set_form_data") def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job): + from superset.tasks.async_queries import load_chart_data_into_cache + app._got_first_request = False async_query_manager.init_app(app) query_context = get_query_context("birth_names") @@ -70,6 +63,8 @@ def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job): ) @mock.patch.object(async_query_manager, "update_job") def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command): + from superset.tasks.async_queries import load_chart_data_into_cache + app._got_first_request = False async_query_manager.init_app(app) query_context = get_query_context("birth_names") @@ -93,6 +88,8 @@ def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_comman def test_soft_timeout_load_chart_data_into_cache( self, mock_update_job, mock_run_command ): + from superset.tasks.async_queries import load_chart_data_into_cache + app._got_first_request = False async_query_manager.init_app(app) user = security_manager.find_user("gamma") @@ -107,9 +104,8 @@ def test_soft_timeout_load_chart_data_into_cache( errors = ["A timeout occurred while loading chart data"] with pytest.raises(SoftTimeLimitExceeded): - with mock.patch.object( - async_queries, - "set_form_data", + with mock.patch( + "superset.tasks.async_queries.set_form_data" ) as set_form_data: set_form_data.side_effect = SoftTimeLimitExceeded() load_chart_data_into_cache(job_metadata, form_data) @@ -118,6 +114,8 @@ def test_soft_timeout_load_chart_data_into_cache( @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") def test_load_explore_json_into_cache(self, mock_update_job): + from superset.tasks.async_queries import load_explore_json_into_cache + app._got_first_request = False async_query_manager.init_app(app) table = self.get_table(name="birth_names") @@ -146,10 +144,12 @@ def test_load_explore_json_into_cache(self, mock_update_job): ) @mock.patch.object(async_query_manager, "update_job") - @mock.patch.object(async_queries, "set_form_data") + @mock.patch("superset.tasks.async_queries.set_form_data") def test_load_explore_json_into_cache_error( self, mock_set_form_data, mock_update_job ): + from superset.tasks.async_queries import load_explore_json_into_cache + app._got_first_request = False async_query_manager.init_app(app) user = security_manager.find_user("gamma") @@ -174,6 +174,8 @@ def test_load_explore_json_into_cache_error( def test_soft_timeout_load_explore_json_into_cache( self, mock_update_job, mock_run_command ): + from superset.tasks.async_queries import load_explore_json_into_cache + app._got_first_request = False async_query_manager.init_app(app) user = security_manager.find_user("gamma") @@ -188,9 +190,8 @@ def test_soft_timeout_load_explore_json_into_cache( errors = ["A timeout occurred while loading explore json, error"] with pytest.raises(SoftTimeLimitExceeded): - with mock.patch.object( - async_queries, - "set_form_data", + with mock.patch( + "superset.tasks.async_queries.set_form_data" ) as set_form_data: set_form_data.side_effect = SoftTimeLimitExceeded() load_explore_json_into_cache(job_metadata, form_data) From 3f40ab243830707e1bb269d5a9ee080ced17e8e0 Mon Sep 17 00:00:00 2001 From: Craig Date: Mon, 2 Oct 2023 10:41:55 -0700 Subject: [PATCH 07/10] Fixing test I didn't touch --- tests/unit_tests/db_engine_specs/test_clickhouse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py index 6dfeddaf37cfd..4dc583c4fc957 100644 --- a/tests/unit_tests/db_engine_specs/test_clickhouse.py +++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py @@ -16,6 +16,7 @@ # under the License. from datetime import datetime +from http.client import HTTPConnection from typing import Any, Optional from unittest.mock import Mock @@ -63,7 +64,7 @@ def test_execute_connection_error() -> None: cursor = Mock() cursor.execute.side_effect = NewConnectionError( - "Dummypool", "Exception with sensitive data" + HTTPConnection("localhost", 8080), "Exception with sensitive data" ) with pytest.raises(SupersetDBAPIDatabaseError) as ex: ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1") From 7765f51e2c7074b50a9443143e7a89c3accc3f3a Mon Sep 17 00:00:00 2001 From: Craig Date: Mon, 2 Oct 2023 10:52:33 -0700 Subject: [PATCH 08/10] Fixing import --- tests/unit_tests/db_engine_specs/test_clickhouse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py index 4dc583c4fc957..d5673a2e65755 100644 --- a/tests/unit_tests/db_engine_specs/test_clickhouse.py +++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py @@ -16,7 +16,6 @@ # under the License. from datetime import datetime -from http.client import HTTPConnection from typing import Any, Optional from unittest.mock import Mock @@ -31,6 +30,7 @@ String, TypeEngine, ) +from urllib3.connection import HTTPConnection from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( From 7343900ef8bccdf22deeb0609a5d49c8ed931db2 Mon Sep 17 00:00:00 2001 From: Craig Date: Mon, 2 Oct 2023 13:04:43 -0700 Subject: [PATCH 09/10] Fixing imports in test / fixtures --- tests/integration_tests/tasks/async_queries_tests.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index b7fe7b066d962..8e6e595757c4f 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -28,13 +28,17 @@ from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, + load_birth_names_data, ) from tests.integration_tests.fixtures.query_context import get_query_context +from tests.integration_tests.fixtures.tags import with_tagging_system_feature from tests.integration_tests.test_app import app class TestAsyncQueries(SupersetTestCase): - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + @pytest.mark.usefixtures( + "load_birth_names_data", "load_birth_names_dashboard_with_slices" + ) @mock.patch.object(async_query_manager, "update_job") @mock.patch("superset.tasks.async_queries.set_form_data") def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job): From fd74c05748d9068c8524529989fcecfe4b347431 Mon Sep 17 00:00:00 2001 From: Craig Date: Mon, 2 Oct 2023 14:02:50 -0700 Subject: [PATCH 10/10] Making JWT channelID configurable --- superset-websocket/src/config.ts | 2 ++ superset-websocket/src/index.ts | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/superset-websocket/src/config.ts b/superset-websocket/src/config.ts index 5d2642b4e9ac6..7d0fac323e975 100644 --- a/superset-websocket/src/config.ts +++ b/superset-websocket/src/config.ts @@ -38,6 +38,7 @@ type ConfigType = { redisStreamReadBlockMs: number; jwtSecret: string; jwtCookieName: string; + jwtChannelIdKey: string; socketResponseTimeoutMs: number; pingSocketsIntervalMs: number; gcChannelsIntervalMs: number; @@ -54,6 +55,7 @@ function defaultConfig(): ConfigType { redisStreamReadBlockMs: 5000, jwtSecret: '', jwtCookieName: 'async-token', + jwtChannelIdKey: 'channel', socketResponseTimeoutMs: 60 * 1000, pingSocketsIntervalMs: 20 * 1000, gcChannelsIntervalMs: 120 * 1000, diff --git a/superset-websocket/src/index.ts b/superset-websocket/src/index.ts index ecb20a4458c09..782275e5ca53a 100644 --- a/superset-websocket/src/index.ts +++ b/superset-websocket/src/index.ts @@ -53,7 +53,7 @@ interface EventValue { result_url?: string; } interface JwtPayload { - channel: string; + [key: string]: string; } interface FetchRangeFromStreamParams { sessionId: string; @@ -253,14 +253,20 @@ export const processStreamResults = (results: StreamResult[]): void => { /** * Verify and parse a JWT cookie from an HTTP request. - * Returns the JWT payload or throws an error on invalid token. + * Returns the channelId from the JWT payload found in the cookie + * configured via 'jwtCookieName' in the config. */ -const getJwtPayload = (request: http.IncomingMessage): JwtPayload => { +const readChannelId = (request: http.IncomingMessage): string => { const cookies = cookie.parse(request.headers.cookie || ''); const token = cookies[opts.jwtCookieName]; if (!token) throw new Error('JWT not present'); - return jwt.verify(token, opts.jwtSecret) as JwtPayload; + const jwtPayload = jwt.verify(token, opts.jwtSecret) as JwtPayload; + const channelId = jwtPayload[opts.jwtChannelIdKey]; + + if (!channelId) throw new Error('Channel ID not present in JWT'); + + return channelId; }; /** @@ -286,8 +292,7 @@ export const incrementId = (id: string): string => { * WebSocket `connection` event handler, called via wss */ export const wsConnection = (ws: WebSocket, request: http.IncomingMessage) => { - const jwtPayload: JwtPayload = getJwtPayload(request); - const channel: string = jwtPayload.channel; + const channel: string = readChannelId(request); const socketInstance: SocketInstance = { ws, channel, pongTs: Date.now() }; // add this ws instance to the internal registry @@ -351,8 +356,7 @@ export const httpUpgrade = ( head: Buffer, ) => { try { - const jwtPayload: JwtPayload = getJwtPayload(request); - if (!jwtPayload.channel) throw new Error('Channel ID not present'); + readChannelId(request); } catch (err) { // JWT invalid, do not establish a WebSocket connection logger.error(err);