Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: upgrade mypy and add type guards #16227

Merged
merged 3 commits into from
Aug 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ repos:
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.790
rev: v0.910
hooks:
- id: mypy
additional_dependencies: [types-all]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the cleanest solution I could come up with to automatically populate stubs for mypy pre-commit hooks: pre-commit-ci/issues#69 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@villebro this is what we did at Airbnb as well. It's a real shame that Mypy doesn't install these by default if needed.

- repo: https://github.com/peterdemin/pip-compile-multi
rev: v2.4.1
hooks:
Expand Down
4 changes: 2 additions & 2 deletions RELEASING/changelog.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,12 @@ def change_log(
with open(csv, "w") as csv_file:
log_items = list(logs)
field_names = log_items[0].keys()
writer = lib_csv.DictWriter(
writer = lib_csv.DictWriter( # type: ignore
csv_file,
delimiter=",",
quotechar='"',
quoting=lib_csv.QUOTE_ALL,
fieldnames=field_names,
fieldnames=field_names, # type: ignore
)
writer.writeheader()
for log in logs:
Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ sqlparse==0.3.0
# via apache-superset
tabulate==0.8.9
# via apache-superset
typing-extensions==3.7.4.3
typing-extensions==3.10.0.0
# via
# aiohttp
# apache-superset
Expand Down
10 changes: 7 additions & 3 deletions scripts/benchmark_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@ def import_migration_script(filepath: Path) -> ModuleType:
Import migration script as if it were a module.
"""
spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
if spec:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
raise Exception(
"No module spec found in location: `{path}`".format(path=str(filepath))
)


def extract_modified_tables(module: ModuleType) -> Set[str]:
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ def get_git_sha() -> str:
"simplejson>=3.15.0",
"slackclient==2.5.0", # PINNED! slack changes file upload api in the future versions
"sqlalchemy>=1.3.16, <1.4, !=1.3.21",
"sqlalchemy-utils>=0.36.6,<0.37",
"sqlalchemy-utils>=0.36.6, <0.37",
"sqlparse==0.3.0", # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562
"tabulate==0.8.9",
"typing-extensions>=3.7.4.3,<4", # needed to support typing.Literal on py37
"typing-extensions>=3.10, <4", # needed to support Literal (3.8) and TypeGuard (3.10)
"wtforms-json",
],
extras_require={
"athena": ["pyathena>=1.10.8,<1.11"],
"athena": ["pyathena>=1.10.8, <1.11"],
"bigquery": [
"pandas_gbq>=0.10.0",
"pybigquery>=0.4.10",
Expand Down
2 changes: 1 addition & 1 deletion superset/tasks/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
WEBDRIVER_BASEURL_USER_FRIENDLY = config["WEBDRIVER_BASEURL_USER_FRIENDLY"]

ReportContent = namedtuple(
"EmailContent",
"ReportContent",
[
"body", # email body
"data", # attachments
Expand Down
4 changes: 2 additions & 2 deletions superset/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from flask import Flask
from flask_caching import Cache
from typing_extensions import TypedDict
from typing_extensions import Literal, TypedDict
from werkzeug.wrappers import Response

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,7 +57,7 @@ class AdhocMetricColumn(TypedDict, total=False):
class AdhocMetric(TypedDict, total=False):
aggregate: str
column: Optional[AdhocMetricColumn]
expressionType: str
expressionType: Literal["SIMPLE", "SQL"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is screaming out to be an enum. Definitely a future TODO.

Copy link
Member Author

@villebro villebro Aug 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. In the next iteration I'll see if I can convert some of these Literal's to Enums. But it may well be that this is the most accurate representation of this type, as we want to indicate that this is a string that can only take on two values (equivalent of TypeScript's expressionType: 'SIMPLE' | 'SQL')

label: Optional[str]
sqlExpression: Optional[str]

Expand Down
4 changes: 2 additions & 2 deletions superset/utils/async_query_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class AsyncQueryManager:

def __init__(self) -> None:
super().__init__()
self._redis: redis.Redis
self._redis: redis.Redis # type: ignore
self._stream_prefix: str = ""
self._stream_limit: Optional[int]
self._stream_limit_firehose: Optional[int]
Expand All @@ -100,7 +100,7 @@ def init_app(self, app: Flask) -> None:
"Please provide a JWT secret at least 32 bytes long"
)

self._redis = redis.Redis( # type: ignore
self._redis = redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
Expand Down
7 changes: 3 additions & 4 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
from typing_extensions import TypedDict
from typing_extensions import TypedDict, TypeGuard

import _thread # pylint: disable=C0411
from superset.constants import (
Expand Down Expand Up @@ -1275,7 +1275,7 @@ def backend() -> str:
return get_example_database().backend


def is_adhoc_metric(metric: Metric) -> bool:
def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]:
return isinstance(metric, dict) and "expressionType" in metric


Expand All @@ -1288,7 +1288,6 @@ def get_metric_name(metric: Metric) -> str:
:raises ValueError: if metric object is invalid
"""
if is_adhoc_metric(metric):
metric = cast(AdhocMetric, metric)
label = metric.get("label")
if label:
return label
Expand All @@ -1306,7 +1305,7 @@ def get_metric_name(metric: Metric) -> str:
if column_name:
return column_name
raise ValueError(__("Invalid metric object"))
return cast(str, metric)
return metric # type: ignore
Comment on lines -1309 to +1308
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason mypy didn't pick up that Union[str, AdhocMetric] becomes str after we do is_adhoc_metric(metric) above



def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
Expand Down
7 changes: 7 additions & 0 deletions tests/unit_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GenericDataType,
get_metric_name,
get_metric_names,
is_adhoc_metric,
)

STR_METRIC = "my_metric"
Expand Down Expand Up @@ -91,3 +92,9 @@ def test_get_metric_names():
assert get_metric_names(
[STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC]
) == ["my_metric", "my SUM", "my_sql"]


def test_is_adhoc_metric():
assert is_adhoc_metric(STR_METRIC) is False
assert is_adhoc_metric(SIMPLE_SUM_ADHOC_METRIC) is True
assert is_adhoc_metric(SQL_ADHOC_METRIC) is True