diff --git a/api/admin/controller/quicksight.py b/api/admin/controller/quicksight.py index 4cd1b2505e..4ebb9d5f9a 100644 --- a/api/admin/controller/quicksight.py +++ b/api/admin/controller/quicksight.py @@ -73,6 +73,7 @@ def generate_quicksight_url(self, dashboard_id) -> Dict: ) try: + delimiter = "|" client = boto3.client("quicksight", region_name=region) response = client.generate_embed_url_for_anonymous_user( AwsAccountId=aws_account_id, @@ -81,7 +82,12 @@ def generate_quicksight_url(self, dashboard_id) -> Dict: ExperienceConfiguration={ "Dashboard": {"InitialDashboardId": dashboard_id} }, - SessionTags=[dict(Key="library_name", Value=l.name) for l in libraries], + SessionTags=[ + dict( + Key="library_name", + Value=delimiter.join([l.name for l in libraries]), + ) + ], ) except Exception as ex: log.error(f"Error while fetching the Quisksight Embed url: {ex}") diff --git a/api/admin/model/quicksight.py b/api/admin/model/quicksight.py index eecacabee6..de6533e317 100644 --- a/api/admin/model/quicksight.py +++ b/api/admin/model/quicksight.py @@ -1,13 +1,19 @@ -from pydantic import Field +from typing import List -from core.util.flask_util import CustomBaseModel, StrCommaList +from pydantic import Field, validator + +from core.util.flask_util import CustomBaseModel, str_comma_list_validator class QuicksightGenerateUrlRequest(CustomBaseModel): - library_ids: StrCommaList[int] = Field( + library_ids: List[int] = Field( description="The list of libraries to include in the dataset, an empty list is equivalent to all the libraries the user is allowed to access." ) + @validator("library_ids", pre=True) + def parse_library_ids(cls, value): + return str_comma_list_validator(value) + class QuicksightGenerateUrlResponse(CustomBaseModel): embed_url: str = Field(description="The dashboard embed url.") diff --git a/api/admin/routes.py b/api/admin/routes.py index 14d4bc890e..21c1ae05c8 100644 --- a/api/admin/routes.py +++ b/api/admin/routes.py @@ -10,7 +10,10 @@ from api.admin.config import Configuration as AdminClientConfig from api.admin.dashboard_stats import generate_statistics from api.admin.model.dashboard_statistics import StatisticsResponse -from api.admin.model.quicksight import QuicksightGenerateUrlResponse +from api.admin.model.quicksight import ( + QuicksightGenerateUrlRequest, + QuicksightGenerateUrlResponse, +) from api.app import api_spec, app from api.routes import allows_library, has_library, library_route from core.app_server import ensure_pydantic_after_problem_detail, returns_problem_detail @@ -347,7 +350,9 @@ def stats(): @app.route("/admin/quicksight_embed/") @api_spec.validate( - resp=SpecResponse(HTTP_200=QuicksightGenerateUrlResponse), tags=["admin.quicksight"] + resp=SpecResponse(HTTP_200=QuicksightGenerateUrlResponse), + tags=["admin.quicksight"], + query=QuicksightGenerateUrlRequest, ) @returns_json_or_response_or_problem_detail @requires_admin diff --git a/core/util/flask_util.py b/core/util/flask_util.py index a7634b7ede..1198ea51da 100644 --- a/core/util/flask_util.py +++ b/core/util/flask_util.py @@ -1,7 +1,7 @@ """Utilities for Flask applications.""" import datetime import time -from typing import Any, Dict, Generic, TypeVar +from typing import Any, Dict from wsgiref.handlers import format_date_time from flask import Response as FlaskResponse @@ -208,26 +208,12 @@ def api_dict( return self.dict(*args, by_alias=by_alias, **kwargs) -T = TypeVar("T") +def str_comma_list_validator(value): + """Validate a comma separated string and parse it into a list, generally used for query parameters""" + if isinstance(value, (int, float)): + # A single number shows up as an int + value = str(value) + elif not isinstance(value, str): + raise TypeError("string required") - -class StrCommaList(list, Generic[T]): - """A list of comma separated values, generally received as query parameters in a URL. - We expect pydantic to do the type coercion with respect to the Generic Type. - The final value is expected to be a List of the right Type. - - Usage: StrCommaList[Type], just like a List[Type] defintion. - """ - - @classmethod - def __get_validators__(cls): - """Pydantic specific API""" - yield cls.validate - - @classmethod - def validate(cls, comma_separated_str): - """Validate the data type and split the string by commas""" - if not isinstance(comma_separated_str, str): - raise TypeError("String required") - # Pydantic wil typecast the values based on the Generic type to the List[...] - return [value for value in comma_separated_str.split(",")] + return value.split(",") diff --git a/tests/api/admin/controller/test_quicksight.py b/tests/api/admin/controller/test_quicksight.py index 29e567ad6c..2078d11a90 100644 --- a/tests/api/admin/controller/test_quicksight.py +++ b/tests/api/admin/controller/test_quicksight.py @@ -68,8 +68,7 @@ def test_generate_quicksight_url( "Dashboard": {"InitialDashboardId": "uuid1"} }, SessionTags=[ - dict(Key="library_name", Value=name) - for name in [default.name, library1.name] + dict(Key="library_name", Value="|".join([default.name, library1.name])) # type: ignore[list-item] ], ) @@ -92,8 +91,7 @@ def test_generate_quicksight_url( "Dashboard": {"InitialDashboardId": "uuid2"} }, SessionTags=[ - dict(Key="library_name", Value=name) - for name in [library1.name] # Only the Admin authorized library + dict(Key="library_name", Value="|".join([library1.name])) # type: ignore[list-item] ], )