Skip to content

Commit

Permalink
upgrade mypy and fix all errors
Browse files Browse the repository at this point in the history
  • Loading branch information
emmyoop committed Jun 14, 2023
1 parent 58b9760 commit 9269e79
Show file tree
Hide file tree
Showing 17 changed files with 43 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ repos:
alias: flake8-check
stages: [manual]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.942
rev: v1.3.0
hooks:
- id: mypy
# N.B.: Mypy is... a bit fragile.
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
return info_schema_name_map

def _relations_cache_for_schemas(
self, manifest: Manifest, cache_schemas: Set[BaseRelation] = None
self, manifest: Manifest, cache_schemas: Optional[Set[BaseRelation]] = None
) -> None:
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
Expand Down Expand Up @@ -417,7 +417,7 @@ def set_relations_cache(
self,
manifest: Manifest,
clear: bool = False,
required_schemas: Set[BaseRelation] = None,
required_schemas: Optional[Set[BaseRelation]] = None,
) -> None:
"""Run a query that gets a populated cache of the relations in the
database and set the cache on this adapter.
Expand Down Expand Up @@ -940,7 +940,7 @@ def execute_macro(
manifest: Optional[Manifest] = None,
project: Optional[str] = None,
context_override: Optional[Dict[str, Any]] = None,
kwargs: Dict[str, Any] = None,
kwargs: Optional[Dict[str, Any]] = None,
text_only_columns: Optional[Iterable[str]] = None,
) -> AttrDict:
"""Look macro_name up in the manifest and execute its results.
Expand Down
6 changes: 4 additions & 2 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from importlib import import_module
from multiprocessing import get_context
from pprint import pformat as pf
from typing import Set, List
from typing import Set, List, Optional

from click import Context, get_current_context, BadOptionUsage
from click.core import ParameterSource, Command, Group
Expand Down Expand Up @@ -78,7 +78,9 @@ def args_to_context(args: List[str]) -> Context:

@dataclass(frozen=True)
class Flags:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:
def __init__(
self, ctx: Optional[Context] = None, user_config: Optional[UserConfig] = None
) -> None:

# set the default flags
for key, value in FLAGS_DEFAULTS.items():
Expand Down
5 changes: 4 additions & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ class dbtInternalException(Exception):
# Programmatic invocation
class dbtRunner:
def __init__(
self, project: Project = None, profile: Profile = None, manifest: Manifest = None
self,
project: Optional[Project] = None,
profile: Optional[Profile] = None,
manifest: Optional[Manifest] = None,
):
self.project = project
self.profile = profile
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/context/configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def var(self) -> ConfiguredVar:


def generate_schema_yml_context(
config: AdapterRequiredConfig, project_name: str, schema_yaml_vars: SchemaYamlVars = None
config: AdapterRequiredConfig,
project_name: str,
schema_yaml_vars: Optional[SchemaYamlVars] = None,
) -> Dict[str, Any]:
ctx = SchemaYamlContext(config, project_name, schema_yaml_vars)
return ctx.to_dict()
Expand Down
16 changes: 10 additions & 6 deletions core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Iterator, Dict, Any, TypeVar, Generic
from typing import List, Iterator, Dict, Any, TypeVar, Generic, Optional

from dbt.config import RuntimeConfig, Project, IsFQNResource
from dbt.contracts.graph.model_config import BaseConfig, get_config_for, _listify
Expand Down Expand Up @@ -130,7 +130,7 @@ def calculate_node_config(
resource_type: NodeType,
project_name: str,
base: bool,
patch_config_dict: Dict[str, Any] = None,
patch_config_dict: Optional[Dict[str, Any]] = None,
) -> BaseConfig:
own_config = self.get_node_project(project_name)

Expand Down Expand Up @@ -166,7 +166,7 @@ def calculate_node_config_dict(
resource_type: NodeType,
project_name: str,
base: bool,
patch_config_dict: Dict[str, Any],
patch_config_dict: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
...

Expand Down Expand Up @@ -200,7 +200,7 @@ def calculate_node_config_dict(
resource_type: NodeType,
project_name: str,
base: bool,
patch_config_dict: dict = None,
patch_config_dict: Optional[dict] = None,
) -> Dict[str, Any]:
config = self.calculate_node_config(
config_call_dict=config_call_dict,
Expand All @@ -225,7 +225,7 @@ def calculate_node_config_dict(
resource_type: NodeType,
project_name: str,
base: bool,
patch_config_dict: dict = None,
patch_config_dict: Optional[dict] = None,
) -> Dict[str, Any]:
# TODO CT-211
return self.calculate_node_config(
Expand Down Expand Up @@ -318,7 +318,11 @@ def _add_config_call(cls, config_call_dict, opts: Dict[str, Any]) -> None:
config_call_dict[k] = v

def build_config_dict(
self, base: bool = False, *, rendered: bool = True, patch_config_dict: dict = None
self,
base: bool = False,
*,
rendered: bool = True,
patch_config_dict: Optional[dict] = None,
) -> Dict[str, Any]:
if rendered:
# TODO CT-211
Expand Down
1 change: 1 addition & 0 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def __call__(self, *args: str) -> RelationProxy:


class BaseMetricResolver(BaseResolver):
@abc.abstractmethod
def resolve(self, name: str, package: Optional[str] = None) -> MetricReference:
...

Expand Down
3 changes: 2 additions & 1 deletion core/dbt/events/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
import dbt.events.proto_types as pt
import sys
from typing import Optional

if sys.version_info >= (3, 8):
from typing import Protocol
Expand Down Expand Up @@ -95,7 +96,7 @@ class EventMsg(Protocol):
data: BaseEvent


def msg_from_base_event(event: BaseEvent, level: EventLevel = None):
def msg_from_base_event(event: BaseEvent, level: Optional[EventLevel] = None):

msg_class_name = f"{type(event).__name__}Msg"
msg_cls = getattr(pt, msg_class_name)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/events/eventmgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __init__(self) -> None:
self.callbacks: List[Callable[[EventMsg], None]] = []
self.invocation_id: str = str(uuid4())

def fire_event(self, e: BaseEvent, level: EventLevel = None) -> None:
def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
msg = msg_from_base_event(e, level=level)

if os.environ.get("DBT_TEST_BINARY_SERIALIZATION"):
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/events/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def warn_or_error(event, node=None):
# an alternative to fire_event which only creates and logs the event value
# if the condition is met. Does nothing otherwise.
def fire_event_if(
conditional: bool, lazy_e: Callable[[], BaseEvent], level: EventLevel = None
conditional: bool, lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None
) -> None:
if conditional:
fire_event(lazy_e(), level=level)
Expand All @@ -212,7 +212,7 @@ def fire_event_if(
# this is where all the side effects happen branched by event type
# (i.e. - mutating the event history, printing to stdout, logging
# to files, etc.)
def fire_event(e: BaseEvent, level: EventLevel = None) -> None:
def fire_event(e: BaseEvent, level: Optional[EventLevel] = None) -> None:
EVENT_MANAGER.fire_event(e, level=level)


Expand Down
4 changes: 2 additions & 2 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ class DbtProfileError(DbtConfigError):


class SemverError(Exception):
def __init__(self, msg: str = None):
def __init__(self, msg: Optional[str] = None):
self.msg = msg
if msg is not None:
super().__init__(msg)
Expand Down Expand Up @@ -2168,7 +2168,7 @@ class RPCCompiling(DbtRuntimeError):
CODE = 10010
MESSAGE = 'RPC server is compiling the project, call the "status" method for' " compile status"

def __init__(self, msg: str = None, node=None):
def __init__(self, msg: Optional[str] = None, node=None):
if msg is None:
msg = "compile in progress"
super().__init__(msg, node)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ class ListLogHandler(LogMessageHandler):
def __init__(
self,
level: int = logbook.NOTSET,
filter: Callable = None,
filter: Optional[Callable] = None,
bubble: bool = False,
lst: Optional[List[LogMessage]] = None,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/generic_test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __init__(
target: Testable,
package_name: str,
render_ctx: Dict[str, Any],
column_name: str = None,
column_name: Optional[str] = None,
) -> None:
test_name, test_args = self.extract_test_args(test, column_name)
self.args: Dict[str, Any] = test_args
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def parse_tests(self, block: TestBlock) -> None:
for test in block.tests:
self.parse_test(block, test, None)

def parse_file(self, block: FileBlock, dct: Dict = None) -> None:
def parse_file(self, block: FileBlock, dct: Optional[Dict] = None) -> None:
assert isinstance(block.file, SchemaSourceFile)
if not dct:
dct = yaml_from_file(block.file)
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,9 @@ def _mark_dependent_errors(self, node_id, result, cause):
for dep_node_id in self.graph.get_dependent_nodes(node_id):
self._skipped_children[dep_node_id] = cause

def populate_adapter_cache(self, adapter, required_schemas: Set[BaseRelation] = None):
def populate_adapter_cache(
self, adapter, required_schemas: Optional[Set[BaseRelation]] = None
):
start_populate_cache = time.perf_counter()
if get_flags().CACHE_SELECTED_ONLY is True:
adapter.set_relations_cache(self.manifest, required_schemas=required_schemas)
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import warnings
from datetime import datetime
from typing import Dict, List
from typing import Dict, List, Optional
from contextlib import contextmanager
from dbt.adapters.factory import Adapter

Expand Down Expand Up @@ -67,7 +67,7 @@
# run_dbt(["run", "--vars", "seed_name: base"])
# If the command is expected to fail, pass in "expect_pass=False"):
# run_dbt("test"], expect_pass=False)
def run_dbt(args: List[str] = None, expect_pass=True):
def run_dbt(args: Optional[List[str]] = None, expect_pass=True):
# Ignore logbook warnings
warnings.filterwarnings("ignore", category=DeprecationWarning, module="logbook")

Expand Down Expand Up @@ -103,7 +103,7 @@ def run_dbt(args: List[str] = None, expect_pass=True):
# If you want the logs that are normally written to a file, you must
# start with the "--debug" flag. The structured schema log CI test
# will turn the logs into json, so you have to be prepared for that.
def run_dbt_and_capture(args: List[str] = None, expect_pass=True):
def run_dbt_and_capture(args: Optional[List[str]] = None, expect_pass=True):
try:
stringbuf = StringIO()
capture_stdout_logs(stringbuf)
Expand Down
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ flake8
flaky
freezegun==0.3.12
ipdb
mypy==1.0.1
mypy==1.3.0
pip-tools
pre-commit
protobuf
Expand Down

0 comments on commit 9269e79

Please sign in to comment.