Skip to content

Commit

Permalink
Pydantic 1.1 under python 3.8 does not allow custom datatypes with Ge…
Browse files Browse the repository at this point in the history
…nerics

Re-implemented as a functional_validator
  • Loading branch information
RishiDiwanTT committed Sep 14, 2023
1 parent ca85918 commit 2ee3fe2
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 33 deletions.
8 changes: 7 additions & 1 deletion api/admin/controller/quicksight.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def generate_quicksight_url(self, dashboard_id) -> Dict:
)

try:
delimiter = "|"
client = boto3.client("quicksight", region_name=region)
response = client.generate_embed_url_for_anonymous_user(
AwsAccountId=aws_account_id,
Expand All @@ -81,7 +82,12 @@ def generate_quicksight_url(self, dashboard_id) -> Dict:
ExperienceConfiguration={
"Dashboard": {"InitialDashboardId": dashboard_id}
},
SessionTags=[dict(Key="library_name", Value=l.name) for l in libraries],
SessionTags=[
dict(
Key="library_name",
Value=delimiter.join([l.name for l in libraries]),
)
],
)
except Exception as ex:
log.error(f"Error while fetching the Quisksight Embed url: {ex}")
Expand Down
12 changes: 9 additions & 3 deletions api/admin/model/quicksight.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from pydantic import Field
from typing import List

from core.util.flask_util import CustomBaseModel, StrCommaList
from pydantic import Field, validator

from core.util.flask_util import CustomBaseModel, str_comma_list_validator


class QuicksightGenerateUrlRequest(CustomBaseModel):
library_ids: StrCommaList[int] = Field(
library_ids: List[int] = Field(
description="The list of libraries to include in the dataset, an empty list is equivalent to all the libraries the user is allowed to access."
)

@validator("library_ids", pre=True)
def parse_library_ids(cls, value):
return str_comma_list_validator(value)


class QuicksightGenerateUrlResponse(CustomBaseModel):
embed_url: str = Field(description="The dashboard embed url.")
9 changes: 7 additions & 2 deletions api/admin/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from api.admin.config import Configuration as AdminClientConfig
from api.admin.dashboard_stats import generate_statistics
from api.admin.model.dashboard_statistics import StatisticsResponse
from api.admin.model.quicksight import QuicksightGenerateUrlResponse
from api.admin.model.quicksight import (
QuicksightGenerateUrlRequest,
QuicksightGenerateUrlResponse,
)
from api.app import api_spec, app
from api.routes import allows_library, has_library, library_route
from core.app_server import ensure_pydantic_after_problem_detail, returns_problem_detail
Expand Down Expand Up @@ -347,7 +350,9 @@ def stats():

@app.route("/admin/quicksight_embed/<dashboard_id>")
@api_spec.validate(
resp=SpecResponse(HTTP_200=QuicksightGenerateUrlResponse), tags=["admin.quicksight"]
resp=SpecResponse(HTTP_200=QuicksightGenerateUrlResponse),
tags=["admin.quicksight"],
query=QuicksightGenerateUrlRequest,
)
@returns_json_or_response_or_problem_detail
@requires_admin
Expand Down
32 changes: 9 additions & 23 deletions core/util/flask_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Utilities for Flask applications."""
import datetime
import time
from typing import Any, Dict, Generic, TypeVar
from typing import Any, Dict
from wsgiref.handlers import format_date_time

from flask import Response as FlaskResponse
Expand Down Expand Up @@ -208,26 +208,12 @@ def api_dict(
return self.dict(*args, by_alias=by_alias, **kwargs)


T = TypeVar("T")
def str_comma_list_validator(value):
"""Validate a comma separated string and parse it into a list, generally used for query parameters"""
if isinstance(value, (int, float)):
# A single number shows up as an int
value = str(value)

Check warning on line 215 in core/util/flask_util.py

View check run for this annotation

Codecov / codecov/patch

core/util/flask_util.py#L215

Added line #L215 was not covered by tests
elif not isinstance(value, str):
raise TypeError("string required")

Check warning on line 217 in core/util/flask_util.py

View check run for this annotation

Codecov / codecov/patch

core/util/flask_util.py#L217

Added line #L217 was not covered by tests


class StrCommaList(list, Generic[T]):
"""A list of comma separated values, generally received as query parameters in a URL.
We expect pydantic to do the type coercion with respect to the Generic Type.
The final value is expected to be a List of the right Type.
Usage: StrCommaList[Type], just like a List[Type] defintion.
"""

@classmethod
def __get_validators__(cls):
"""Pydantic specific API"""
yield cls.validate

@classmethod
def validate(cls, comma_separated_str):
"""Validate the data type and split the string by commas"""
if not isinstance(comma_separated_str, str):
raise TypeError("String required")
# Pydantic wil typecast the values based on the Generic type to the List[...]
return [value for value in comma_separated_str.split(",")]
return value.split(",")
6 changes: 2 additions & 4 deletions tests/api/admin/controller/test_quicksight.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def test_generate_quicksight_url(
"Dashboard": {"InitialDashboardId": "uuid1"}
},
SessionTags=[
dict(Key="library_name", Value=name)
for name in [default.name, library1.name]
dict(Key="library_name", Value="|".join([default.name, library1.name])) # type: ignore[list-item]
],
)

Expand All @@ -92,8 +91,7 @@ def test_generate_quicksight_url(
"Dashboard": {"InitialDashboardId": "uuid2"}
},
SessionTags=[
dict(Key="library_name", Value=name)
for name in [library1.name] # Only the Admin authorized library
dict(Key="library_name", Value="|".join([library1.name])) # type: ignore[list-item]
],
)

Expand Down

0 comments on commit 2ee3fe2

Please sign in to comment.