From 31c88f9f5a830f41a7a7a5ea351abf8188f28a09 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Tue, 16 Feb 2021 12:13:37 -0500 Subject: [PATCH] Use updated Mashumaro code --- CHANGELOG.md | 1 + core/dbt/adapters/base/relation.py | 6 +-- core/dbt/compilation.py | 6 +-- core/dbt/config/profile.py | 8 ++-- core/dbt/config/project.py | 7 +-- core/dbt/config/runtime.py | 4 +- core/dbt/context/base.py | 1 + core/dbt/context/context_config.py | 2 +- core/dbt/context/docs.py | 1 + core/dbt/context/providers.py | 6 +-- core/dbt/contracts/connection.py | 8 ++-- core/dbt/contracts/graph/compiled.py | 2 +- core/dbt/contracts/graph/manifest.py | 8 ++-- core/dbt/contracts/graph/model_config.py | 16 +++---- core/dbt/contracts/graph/parsed.py | 8 ++-- core/dbt/contracts/graph/unparsed.py | 18 +++----- core/dbt/contracts/relation.py | 2 +- core/dbt/contracts/results.py | 10 ++--- core/dbt/contracts/rpc.py | 12 ++--- core/dbt/contracts/util.py | 12 ++--- core/dbt/dataclass_schema.py | 45 +++++++++---------- core/dbt/logger.py | 8 +++- core/dbt/parser/base.py | 4 +- core/dbt/parser/schemas.py | 2 +- core/dbt/parser/snapshots.py | 3 +- core/dbt/parser/sources.py | 4 +- core/dbt/rpc/builtins.py | 2 +- core/dbt/rpc/response_manager.py | 2 +- core/dbt/rpc/task_handler.py | 2 +- core/dbt/task/generate.py | 4 +- core/dbt/task/list.py | 2 +- core/dbt/task/printer.py | 2 +- core/dbt/task/run.py | 4 +- core/dbt/utils.py | 2 +- core/setup.py | 1 + dev_requirements.txt | 2 - test/unit/test_bigquery_adapter.py | 24 +++++----- test/unit/test_config.py | 2 +- test/unit/test_contracts_graph_parsed.py | 4 +- test/unit/test_docs_generate.py | 2 +- test/unit/test_manifest.py | 14 +++--- test/unit/utils.py | 4 +- third-party-stubs/mashumaro/config.pyi | 10 +++++ .../mashumaro/serializer/base/dict.pyi | 15 ++++--- 44 files changed, 159 insertions(+), 143 deletions(-) create mode 100644 third-party-stubs/mashumaro/config.pyi diff --git a/CHANGELOG.md b/CHANGELOG.md index f84b0618990..c25e0af1354 100644 
--- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ Contributors: ### Under the hood - Bump werkzeug upper bound dependency to ` Policy: @@ -185,10 +185,10 @@ def quoted(self, identifier): def create_from_source( cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any ) -> Self: - source_quoting = source.quoting.to_dict() + source_quoting = source.quoting.to_dict(omit_none=True) source_quoting.pop('column', None) quote_policy = deep_merge( - cls.get_default_quote_policy().to_dict(), + cls.get_default_quote_policy().to_dict(omit_none=True), source_quoting, kwargs.get('quote_policy', {}), ) diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 7afb3e20d74..0b36cd0664f 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -138,7 +138,7 @@ def write_graph(self, outfile: str, manifest: Manifest): """ out_graph = self.graph.copy() for node_id in self.graph.nodes(): - data = manifest.expect(node_id).to_dict() + data = manifest.expect(node_id).to_dict(omit_none=True) out_graph.add_node(node_id, **data) nx.write_gpickle(out_graph, outfile) @@ -339,7 +339,7 @@ def _recursively_prepend_ctes( model.compiled_sql = injected_sql model.extra_ctes_injected = True model.extra_ctes = prepended_ctes - model.validate(model.to_dict()) + model.validate(model.to_dict(omit_none=True)) manifest.update_node(model) @@ -388,7 +388,7 @@ def _compile_node( logger.debug("Compiling {}".format(node.unique_id)) - data = node.to_dict() + data = node.to_dict(omit_none=True) data.update({ 'compiled': False, 'compiled_sql': None, diff --git a/core/dbt/config/profile.py b/core/dbt/config/profile.py index 256198929ff..8ba78ad7aad 100644 --- a/core/dbt/config/profile.py +++ b/core/dbt/config/profile.py @@ -111,8 +111,8 @@ def to_profile_info( 'credentials': self.credentials, } if serialize_credentials: - result['config'] = self.config.to_dict() - result['credentials'] = self.credentials.to_dict() + result['config'] = self.config.to_dict(omit_none=True) + 
result['credentials'] = self.credentials.to_dict(omit_none=True) return result def to_target_dict(self) -> Dict[str, Any]: @@ -125,7 +125,7 @@ def to_target_dict(self) -> Dict[str, Any]: 'name': self.target_name, 'target_name': self.target_name, 'profile_name': self.profile_name, - 'config': self.config.to_dict(), + 'config': self.config.to_dict(omit_none=True), }) return target @@ -138,7 +138,7 @@ def __eq__(self, other: object) -> bool: def validate(self): try: if self.credentials: - dct = self.credentials.to_dict() + dct = self.credentials.to_dict(omit_none=True) self.credentials.validate(dct) dct = self.to_profile_info(serialize_credentials=True) ProfileConfig.validate(dct) diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index fe9ca743ce5..908199d36b6 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -347,7 +347,7 @@ def create_project(self, rendered: RenderComponents) -> 'Project': # break many things quoting: Dict[str, Any] = {} if cfg.quoting is not None: - quoting = cfg.quoting.to_dict() + quoting = cfg.quoting.to_dict(omit_none=True) models: Dict[str, Any] seeds: Dict[str, Any] @@ -578,10 +578,11 @@ def to_project_config(self, with_packages=False): 'config-version': self.config_version, }) if self.query_comment: - result['query-comment'] = self.query_comment.to_dict() + result['query-comment'] = \ + self.query_comment.to_dict(omit_none=True) if with_packages: - result.update(self.packages.to_dict()) + result.update(self.packages.to_dict(omit_none=True)) return result diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index ea3c028e506..a2380210334 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -78,7 +78,7 @@ def from_parts( get_relation_class_by_name(profile.credentials.type) .get_default_quote_policy() .replace_dict(_project_quoting_dict(project, profile)) - ).to_dict() + ).to_dict(omit_none=True) cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', 
'{}')) @@ -391,7 +391,7 @@ def __getattribute__(self, name): f"'UnsetConfig' object has no attribute {name}" ) - def __post_serialize__(self, dct, options=None): + def __post_serialize__(self, dct): return {} diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index ca23ecdaad1..26eb6241b0d 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -538,4 +538,5 @@ def flags(self) -> Any: def generate_base_context(cli_vars: Dict[str, Any]) -> Dict[str, Any]: ctx = BaseContext(cli_vars) + # This is not a Mashumaro to_dict call return ctx.to_dict() diff --git a/core/dbt/context/context_config.py b/core/dbt/context/context_config.py index 80467ea70d3..2176056159b 100644 --- a/core/dbt/context/context_config.py +++ b/core/dbt/context/context_config.py @@ -196,7 +196,7 @@ def calculate_node_config_dict( base=base, ) finalized = config.finalize_and_validate() - return finalized.to_dict() + return finalized.to_dict(omit_none=True) class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]): diff --git a/core/dbt/context/docs.py b/core/dbt/context/docs.py index 2ea70135688..bdf66614ae9 100644 --- a/core/dbt/context/docs.py +++ b/core/dbt/context/docs.py @@ -77,4 +77,5 @@ def generate_runtime_docs( current_project: str, ) -> Dict[str, Any]: ctx = DocsRuntimeContext(config, target, manifest, current_project) + # This is not a Mashumaro to_dict call return ctx.to_dict() diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 7515680ef19..2d9ee3bb208 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1115,7 +1115,7 @@ def graph(self) -> Dict[str, Any]: @contextproperty('model') def ctx_model(self) -> Dict[str, Any]: - return self.model.to_dict() + return self.model.to_dict(omit_none=True) @contextproperty def pre_hooks(self) -> Optional[List[Dict[str, Any]]]: @@ -1231,7 +1231,7 @@ def pre_hooks(self) -> List[Dict[str, Any]]: if isinstance(self.model, ParsedSourceDefinition): 
return [] return [ - h.to_dict() for h in self.model.config.pre_hook + h.to_dict(omit_none=True) for h in self.model.config.pre_hook ] @contextproperty @@ -1239,7 +1239,7 @@ def post_hooks(self) -> List[Dict[str, Any]]: if isinstance(self.model, ParsedSourceDefinition): return [] return [ - h.to_dict() for h in self.model.config.post_hook + h.to_dict(omit_none=True) for h in self.model.config.post_hook ] @contextproperty diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index 86c6978c7a1..1a35f882641 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -132,7 +132,7 @@ def connection_info( ) -> Iterable[Tuple[str, Any]]: """Return an ordered iterator of key/value pairs for pretty-printing. """ - as_dict = self.to_dict(options={'keep_none': True}) + as_dict = self.to_dict(omit_none=False) connection_keys = set(self._connection_keys()) aliases: List[str] = [] if with_aliases: @@ -148,8 +148,8 @@ def _connection_keys(self) -> Tuple[str, ...]: raise NotImplementedError @classmethod - def __pre_deserialize__(cls, data, options=None): - data = super().__pre_deserialize__(data, options=options) + def __pre_deserialize__(cls, data): + data = super().__pre_deserialize__(data) data = cls.translate_aliases(data) return data @@ -159,7 +159,7 @@ def translate_aliases( ) -> Dict[str, Any]: return translate_aliases(kwargs, cls._ALIASES, recurse) - def __post_serialize__(self, dct, options=None): + def __post_serialize__(self, dct): # no super() -- do we need it? 
if self._ALIASES: dct.update({ diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index c8f08710b94..177d041110b 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -178,7 +178,7 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource: raise ValueError('invalid resource_type: {}' .format(compiled.resource_type)) - return cls.from_dict(compiled.to_dict()) + return cls.from_dict(compiled.to_dict(omit_none=True)) NonSourceCompiledNode = Union[ diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index ee24e62db84..d6000c88fd7 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -240,7 +240,7 @@ def build_edges(nodes: List[ManifestNode]): def _deepcopy(value): - return value.from_dict(value.to_dict()) + return value.from_dict(value.to_dict(omit_none=True)) class Locality(enum.IntEnum): @@ -564,11 +564,11 @@ def build_flat_graph(self): """ self.flat_graph = { 'nodes': { - k: v.to_dict(options={'keep_none': True}) + k: v.to_dict(omit_none=False) for k, v in self.nodes.items() }, 'sources': { - k: v.to_dict(options={'keep_none': True}) + k: v.to_dict(omit_none=False) for k, v in self.sources.items() } } @@ -755,7 +755,7 @@ def writable_manifest(self): # When 'to_dict' is called on the Manifest, it substitues a # WritableManifest - def __pre_serialize__(self, options=None): + def __pre_serialize__(self): return self.writable_manifest() def write(self, path): diff --git a/core/dbt/contracts/graph/model_config.py b/core/dbt/contracts/graph/model_config.py index 09370e5f563..580b24f8f03 100644 --- a/core/dbt/contracts/graph/model_config.py +++ b/core/dbt/contracts/graph/model_config.py @@ -307,7 +307,7 @@ def update_from( """ # sadly, this is a circular import from dbt.adapters.factory import get_config_class_by_name - dct = self.to_dict(options={'keep_none': True}) + dct = self.to_dict(omit_none=False) 
adapter_config_cls = get_config_class_by_name(adapter_type) @@ -326,12 +326,12 @@ def update_from( return self.from_dict(dct) def finalize_and_validate(self: T) -> T: - dct = self.to_dict(options={'keep_none': True}) + dct = self.to_dict(omit_none=False) self.validate(dct) return self.from_dict(dct) def replace(self, **kwargs): - dct = self.to_dict() + dct = self.to_dict(omit_none=True) mapping = self.field_mapping() for key, value in kwargs.items(): @@ -396,8 +396,8 @@ class NodeConfig(BaseConfig): full_refresh: Optional[bool] = None @classmethod - def __pre_deserialize__(cls, data, options=None): - data = super().__pre_deserialize__(data, options=options) + def __pre_deserialize__(cls, data): + data = super().__pre_deserialize__(data) field_map = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'} # create a new dict because otherwise it gets overwritten in # tests @@ -414,8 +414,8 @@ def __pre_deserialize__(cls, data, options=None): data[new_name] = data.pop(field_name) return data - def __post_serialize__(self, dct, options=None): - dct = super().__post_serialize__(dct, options=options) + def __post_serialize__(self, dct): + dct = super().__post_serialize__(dct) field_map = {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'} for field_name in field_map: if field_name in dct: @@ -480,7 +480,7 @@ def validate(cls, data): # formerly supported with GenericSnapshotConfig def finalize_and_validate(self): - data = self.to_dict() + data = self.to_dict(omit_none=True) self.validate(data) return self.from_dict(data) diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index a4020e2a487..fb62d805e27 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -99,8 +99,8 @@ class HasRelationMetadata(dbtClassMixin, Replaceable): # because it messes up the subclasses and default parameters # so hack it here @classmethod - def __pre_deserialize__(cls, data, options=None): - data = super().__pre_deserialize__(data, 
options=options) + def __pre_deserialize__(cls, data): + data = super().__pre_deserialize__(data) if 'database' not in data: data['database'] = None return data @@ -141,7 +141,7 @@ def patch(self, patch: 'ParsedNodePatch'): # Maybe there should be validation or restrictions # elsewhere? assert isinstance(self, dbtClassMixin) - dct = self.to_dict(options={'keep_none': True}) + dct = self.to_dict(omit_none=False) self.validate(dct) def get_materialization(self): @@ -454,7 +454,7 @@ def patch(self, patch: ParsedMacroPatch): if flags.STRICT_MODE: # What does this actually validate? assert isinstance(self, dbtClassMixin) - dct = self.to_dict(options={'keep_none': True}) + dct = self.to_dict(omit_none=False) self.validate(dct) def same_contents(self, other: Optional['ParsedMacro']) -> bool: diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 165e22e4b1a..e6656c92a40 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -231,12 +231,9 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests): external: Optional[ExternalTable] = None tags: List[str] = field(default_factory=list) - def __post_serialize__(self, dct, options=None): + def __post_serialize__(self, dct): dct = super().__post_serialize__(dct) - keep_none = False - if options and 'keep_none' in options and options['keep_none']: - keep_none = True - if not keep_none and self.freshness is None: + if 'freshness' not in dct and self.freshness is None: dct['freshness'] = None return dct @@ -261,12 +258,9 @@ class UnparsedSourceDefinition(dbtClassMixin, Replaceable): def yaml_key(self) -> 'str': return 'sources' - def __post_serialize__(self, dct, options=None): + def __post_serialize__(self, dct): dct = super().__post_serialize__(dct) - keep_none = False - if options and 'keep_none' in options and options['keep_none']: - keep_none = True - if not keep_none and self.freshness is None: + if 'freshness' not in dct and
self.freshness is None: dct['freshness'] = None return dct @@ -290,7 +284,7 @@ class SourceTablePatch(dbtClassMixin): columns: Optional[Sequence[UnparsedColumn]] = None def to_patch_dict(self) -> Dict[str, Any]: - dct = self.to_dict() + dct = self.to_dict(omit_none=True) remove_keys = ('name') for key in remove_keys: if key in dct: @@ -327,7 +321,7 @@ class SourcePatch(dbtClassMixin, Replaceable): tags: Optional[List[str]] = None def to_patch_dict(self) -> Dict[str, Any]: - dct = self.to_dict() + dct = self.to_dict(omit_none=True) remove_keys = ('name', 'overrides', 'tables', 'path') for key in remove_keys: if key in dct: diff --git a/core/dbt/contracts/relation.py b/core/dbt/contracts/relation.py index 6bcb58c749e..41eb22c3569 100644 --- a/core/dbt/contracts/relation.py +++ b/core/dbt/contracts/relation.py @@ -52,7 +52,7 @@ def __len__(self): return len(fields(self.__class__)) def incorporate(self, **kwargs): - value = self.to_dict() + value = self.to_dict(omit_none=True) value = deep_merge(value, kwargs) return self.from_dict(value) diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index 43fc9117ce6..b021b767e2d 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -97,8 +97,8 @@ class BaseResult(dbtClassMixin): message: Optional[Union[str, int]] @classmethod - def __pre_deserialize__(cls, data, options=None): - data = super().__pre_deserialize__(data, options=options) + def __pre_deserialize__(cls, data): + data = super().__pre_deserialize__(data) if 'message' not in data: data['message'] = None return data @@ -206,7 +206,7 @@ def from_execution_results( ) def write(self, path: str): - write_json(path, self.to_dict(options={'keep_none': True})) + write_json(path, self.to_dict(omit_none=False)) @dataclass @@ -448,8 +448,8 @@ class CatalogResults(dbtClassMixin): errors: Optional[List[str]] = None _compile_results: Optional[Any] = None - def __post_serialize__(self, dct, options=None): - dct = 
super().__post_serialize__(dct, options=options) + def __post_serialize__(self, dct): + dct = super().__post_serialize__(dct) if '_compile_results' in dct: del dct['_compile_results'] return dct diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index 213f13717bb..83d248e84d2 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -38,8 +38,8 @@ class RPCParameters(dbtClassMixin): timeout: Optional[float] @classmethod - def __pre_deserialize__(cls, data, options=None): - data = super().__pre_deserialize__(data, options=options) + def __pre_deserialize__(cls, data, omit_none=True): + data = super().__pre_deserialize__(data) if 'timeout' not in data: data['timeout'] = None if 'task_tags' not in data: @@ -428,8 +428,8 @@ class TaskTiming(dbtClassMixin): # These ought to be defaults but superclass order doesn't # allow that to work @classmethod - def __pre_deserialize__(cls, data, options=None): - data = super().__pre_deserialize__(data, options=options) + def __pre_deserialize__(cls, data): + data = super().__pre_deserialize__(data) for field_name in ('start', 'end', 'elapsed'): if field_name not in data: data[field_name] = None @@ -496,8 +496,8 @@ class PollResult(RemoteResult, TaskTiming): # These ought to be defaults but superclass order doesn't # allow that to work @classmethod - def __pre_deserialize__(cls, data, options=None): - data = super().__pre_deserialize__(data, options=options) + def __pre_deserialize__(cls, data): + data = super().__pre_deserialize__(data) for field_name in ('start', 'end', 'elapsed'): if field_name not in data: data[field_name] = None diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py index c0743b62ecd..5dd329d193e 100644 --- a/core/dbt/contracts/util.py +++ b/core/dbt/contracts/util.py @@ -58,7 +58,7 @@ def merged(self, *args): class Writable: def write(self, path: str): write_json( - path, self.to_dict(options={'keep_none': True}) # type: ignore + path, self.to_dict(omit_none=False) 
# type: ignore ) @@ -74,7 +74,7 @@ class AdditionalPropertiesMixin: # not in the class definitions and puts them in an # _extra dict in the class @classmethod - def __pre_deserialize__(cls, data, options=None): + def __pre_deserialize__(cls, data): # dir() did not work because fields with # metadata settings are not found # The original version of this would create the @@ -93,18 +93,18 @@ def __pre_deserialize__(cls, data, options=None): else: new_dict[key] = value data = new_dict - data = super().__pre_deserialize__(data, options=options) + data = super().__pre_deserialize__(data) return data - def __post_serialize__(self, dct, options=None): - data = super().__post_serialize__(dct, options=options) + def __post_serialize__(self, dct): + data = super().__post_serialize__(dct) data.update(self.extra) if '_extra' in data: del data['_extra'] return data def replace(self, **kwargs): - dct = self.to_dict(options={'keep_none': True}) + dct = self.to_dict(omit_none=False) dct.update(kwargs) return self.from_dict(dct) diff --git a/core/dbt/dataclass_schema.py b/core/dbt/dataclass_schema.py index 1382d489c50..630ce90e343 100644 --- a/core/dbt/dataclass_schema.py +++ b/core/dbt/dataclass_schema.py @@ -1,5 +1,5 @@ from typing import ( - Type, ClassVar, Dict, cast, TypeVar + Type, ClassVar, cast, ) import re from dataclasses import fields @@ -9,29 +9,28 @@ from hologram import JsonSchemaMixin, FieldEncoder, ValidationError +# type: ignore from mashumaro import DataClassDictMixin -from mashumaro.types import SerializableEncoder, SerializableType +from mashumaro.config import ( + TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig +) +from mashumaro.types import SerializableType, SerializationStrategy -class DateTimeSerializableEncoder(SerializableEncoder[datetime]): - @classmethod - def _serialize(cls, value: datetime) -> str: +class DateTimeSerialization(SerializationStrategy): + def serialize(self, value): out = value.isoformat() # Assume UTC if timezone is missing if 
value.tzinfo is None: out = out + "Z" return out - @classmethod - def _deserialize(cls, value: str) -> datetime: + def deserialize(self, value): return ( value if isinstance(value, datetime) else parse(cast(str, value)) ) -TV = TypeVar("TV") - - # This class pulls in both JsonSchemaMixin from Hologram and # DataClassDictMixin from our fork of Mashumaro. The 'to_dict' # and 'from_dict' methods come from Mashumaro. Building @@ -43,23 +42,21 @@ class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin): against the schema """ - _serializable_encoders: ClassVar[Dict[str, SerializableEncoder]] = { - 'datetime.datetime': DateTimeSerializableEncoder(), - } + class Config(MashBaseConfig): + code_generation_options = [ + TO_DICT_ADD_OMIT_NONE_FLAG, + ] + serialization_strategy = { + datetime: DateTimeSerialization(), + } + _hyphenated: ClassVar[bool] = False ADDITIONAL_PROPERTIES: ClassVar[bool] = False # This is called by the mashumaro to_dict in order to handle # nested classes. # Munges the dict that's returned. - def __post_serialize__(self, dct, options=None): - keep_none = False - if options and 'keep_none' in options and options['keep_none']: - keep_none = True - if not keep_none: # remove attributes that are None - new_dict = {k: v for k, v in dct.items() if v is not None} - dct = new_dict - + def __post_serialize__(self, dct): if self._hyphenated: new_dict = {} for key in dct: @@ -75,7 +72,7 @@ def __post_serialize__(self, dct, options=None): # This is called by the mashumaro _from_dict method, before # performing the conversion to a dict @classmethod - def __pre_deserialize__(cls, data, options=None): + def __pre_deserialize__(cls, data): if cls._hyphenated: new_dict = {} for key in data: @@ -92,8 +89,8 @@ def __pre_deserialize__(cls, data, options=None): # hologram and in mashumaro. 
def _local_to_dict(self, **kwargs): args = {} - if 'omit_none' in kwargs and kwargs['omit_none'] is False: - args['options'] = {'keep_none': True} + if 'omit_none' in kwargs: + args['omit_none'] = kwargs['omit_none'] return self.to_dict(**args) diff --git a/core/dbt/logger.py b/core/dbt/logger.py index 1916f49020e..95b9699ec01 100644 --- a/core/dbt/logger.py +++ b/core/dbt/logger.py @@ -95,7 +95,8 @@ def __call__(self, record, handler): # utils imports exceptions which imports logger... import dbt.utils log_message = super().__call__(record, handler) - return json.dumps(log_message.to_dict(), cls=dbt.utils.JSONEncoder) + dct = log_message.to_dict(omit_none=True) + return json.dumps(dct, cls=dbt.utils.JSONEncoder) class FormatterMixin: @@ -127,6 +128,7 @@ class OutputHandler(logbook.StreamHandler, FormatterMixin): The `format_string` parameter only changes the default text output, not debug mode or json. """ + def __init__( self, stream, @@ -220,7 +222,8 @@ def __init__(self, timing_info: Optional[dbtClassMixin] = None): def process(self, record): if self.timing_info is not None: - record.extra['timing_info'] = self.timing_info.to_dict() + record.extra['timing_info'] = self.timing_info.to_dict( + omit_none=True) class DbtProcessState(logbook.Processor): @@ -349,6 +352,7 @@ def make_log_dir_if_missing(log_dir): class DebugWarnings(logbook.compat.redirected_warnings): """Log warnings, except send them to 'debug' instead of 'warning' level. 
""" + def make_record(self, message, exception, filename, lineno): rv = super().make_record(message, exception, filename, lineno) rv.level = logbook.DEBUG diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index 7f603d02d61..d2b9d9cd100 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -252,7 +252,7 @@ def _create_parsetime_node( 'raw_sql': block.contents, 'unique_id': self.generate_unique_id(name), 'config': self.config_dict(config), - 'checksum': block.file.checksum.to_dict(), + 'checksum': block.file.checksum.to_dict(omit_none=True), } dct.update(kwargs) try: @@ -301,7 +301,7 @@ def update_parsed_node_config( self, parsed_node: IntermediateNode, config_dict: Dict[str, Any] ) -> None: # Overwrite node config - final_config_dict = parsed_node.config.to_dict() + final_config_dict = parsed_node.config.to_dict(omit_none=True) final_config_dict.update(config_dict) # re-mangle hooks, in case we got new ones self._mangle_hooks(final_config_dict) diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index f176a8e1731..1a165b7a452 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -385,7 +385,7 @@ def create_test_node( 'config': self.config_dict(config), 'test_metadata': test_metadata, 'column_name': column_name, - 'checksum': FileHash.empty().to_dict(), + 'checksum': FileHash.empty().to_dict(omit_none=True), } try: ParsedSchemaTestNode.validate(dct) diff --git a/core/dbt/parser/snapshots.py b/core/dbt/parser/snapshots.py index 16e5c52ab0d..db977ab4c1d 100644 --- a/core/dbt/parser/snapshots.py +++ b/core/dbt/parser/snapshots.py @@ -68,7 +68,8 @@ def get_fqn(self, path: str, name: str) -> List[str]: def transform(self, node: IntermediateSnapshotNode) -> ParsedSnapshotNode: try: - parsed_node = ParsedSnapshotNode.from_dict(node.to_dict()) + dct = node.to_dict(omit_none=True) + parsed_node = ParsedSnapshotNode.from_dict(dct) self.set_snapshot_attributes(parsed_node) return parsed_node except 
ValidationError as exc: diff --git a/core/dbt/parser/sources.py b/core/dbt/parser/sources.py index 30f018f6029..b6ef559ffd2 100644 --- a/core/dbt/parser/sources.py +++ b/core/dbt/parser/sources.py @@ -49,8 +49,8 @@ def patch_source( if patch is None: return unpatched - source_dct = unpatched.source.to_dict() - table_dct = unpatched.table.to_dict() + source_dct = unpatched.source.to_dict(omit_none=True) + table_dct = unpatched.table.to_dict(omit_none=True) patch_path: Optional[Path] = None source_table_patch: Optional[SourceTablePatch] = None diff --git a/core/dbt/rpc/builtins.py b/core/dbt/rpc/builtins.py index 7e900ca5357..a7ac1d77879 100644 --- a/core/dbt/rpc/builtins.py +++ b/core/dbt/rpc/builtins.py @@ -177,7 +177,7 @@ def poll_complete( def _dict_logs(logs: List[LogMessage]) -> List[Dict[str, Any]]: - return [log.to_dict() for log in logs] + return [log.to_dict(omit_none=True) for log in logs] class Poll(RemoteBuiltinMethod[PollParameters, PollResult]): diff --git a/core/dbt/rpc/response_manager.py b/core/dbt/rpc/response_manager.py index 1d44f7e0cbe..a1beb1810f6 100644 --- a/core/dbt/rpc/response_manager.py +++ b/core/dbt/rpc/response_manager.py @@ -97,7 +97,7 @@ def _get_responses(cls, requests, dispatcher): # Note: errors in to_dict do not show up anywhere in # the output and all you get is a generic 500 error output.result = \ - output.result.to_dict(options={'keep_none': True}) + output.result.to_dict(omit_none=False) yield output @classmethod diff --git a/core/dbt/rpc/task_handler.py b/core/dbt/rpc/task_handler.py index c657f4e3de7..ce2f320fc1f 100644 --- a/core/dbt/rpc/task_handler.py +++ b/core/dbt/rpc/task_handler.py @@ -391,7 +391,7 @@ def get_result(self) -> RemoteResult: except RPCException as exc: # RPC Exceptions come already preserialized for the jsonrpc # framework - exc.logs = [log.to_dict() for log in self.logs] + exc.logs = [log.to_dict(omit_none=True) for log in self.logs] exc.tags = self.tags raise diff --git a/core/dbt/task/generate.py 
b/core/dbt/task/generate.py index 86c3ac57b85..e37041e39a1 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -114,8 +114,8 @@ def make_unique_id_map( if unique_id in sources: dbt.exceptions.raise_ambiguous_catalog_match( unique_id, - sources[unique_id].to_dict(), - table.to_dict(), + sources[unique_id].to_dict(omit_none=True), + table.to_dict(omit_none=True), ) else: sources[unique_id] = table.replace(unique_id=unique_id) diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index 5f91cf052a7..d78223eef8b 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -110,7 +110,7 @@ def generate_json(self): for node in self._iterate_selected_nodes(): yield json.dumps({ k: v - for k, v in node.to_dict(options={'keep_none': True}).items() + for k, v in node.to_dict(omit_none=False).items() if k in self.ALLOWED_KEYS }) diff --git a/core/dbt/task/printer.py b/core/dbt/task/printer.py index 7313e4c9b0b..917d545d77f 100644 --- a/core/dbt/task/printer.py +++ b/core/dbt/task/printer.py @@ -169,7 +169,7 @@ def print_snapshot_result_line( info, status, logger_fn = get_printable_result( result, 'snapshotted', 'snapshotting') - cfg = model.config.to_dict() + cfg = model.config.to_dict(omit_none=True) msg = "{info} {description}".format( info=info, description=description, **cfg) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 0fbbdde0e4d..e397551cd8f 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -117,7 +117,7 @@ def track_model_run(index, num_nodes, run_model_result): "hashed_contents": utils.get_hashed_contents( run_model_result.node ), - "timing": [t.to_dict() for t in run_model_result.timing], + "timing": [t.to_dict(omit_none=True) for t in run_model_result.timing], }) @@ -193,7 +193,7 @@ def _build_run_model_result(self, model, context): result = context['load_result']('main') adapter_response = {} if isinstance(result.response, dbtClassMixin): - adapter_response = result.response.to_dict() + 
adapter_response = result.response.to_dict(omit_none=True) return RunResult( node=model, status=RunStatus.Success, diff --git a/core/dbt/utils.py b/core/dbt/utils.py index c4618a7ef27..031a53c6e7a 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -320,7 +320,7 @@ def default(self, obj): if hasattr(obj, 'to_dict'): # if we have a to_dict we should try to serialize the result of # that! - return obj.to_dict() + return obj.to_dict(omit_none=True) return super().default(obj) diff --git a/core/setup.py b/core/setup.py index 7e2ec6c1713..5eb29bdc0b5 100644 --- a/core/setup.py +++ b/core/setup.py @@ -72,6 +72,7 @@ def read(fname): 'dataclasses==0.6;python_version<"3.7"', 'hologram==0.0.13', 'logbook>=1.5,<1.6', + 'mashumaro==2.0', 'typing-extensions>=3.7.4,<3.8', # the following are all to match snowflake-connector-python 'requests>=2.18.0,<2.24.0', diff --git a/dev_requirements.txt b/dev_requirements.txt index 2feffaf4b9e..a7887c64529 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -13,5 +13,3 @@ mypy==0.782 wheel twine pytest-logbook>=1.2.0,<1.3 -git+https://github.com/fishtown-analytics/dbt-mashumaro.git@dbt-customizations -jsonschema diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index 00a4ef3b885..9e74afdd569 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -659,7 +659,7 @@ def test_parse_partition_by(self): self.assertEqual( adapter.parse_partition_by({ "field": "ts", - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", "data_type": "date", "granularity": "day" @@ -670,7 +670,7 @@ def test_parse_partition_by(self): adapter.parse_partition_by({ "field": "ts", "data_type": "date", - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", "data_type": "date", "granularity": "day" @@ -683,7 +683,7 @@ def test_parse_partition_by(self): "data_type": "date", "granularity": "MONTH" - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", 
"data_type": "date", "granularity": "MONTH" @@ -696,7 +696,7 @@ def test_parse_partition_by(self): "data_type": "date", "granularity": "YEAR" - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", "data_type": "date", "granularity": "YEAR" @@ -709,7 +709,7 @@ def test_parse_partition_by(self): "data_type": "timestamp", "granularity": "HOUR" - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", "data_type": "timestamp", "granularity": "HOUR" @@ -722,7 +722,8 @@ def test_parse_partition_by(self): "data_type": "timestamp", "granularity": "MONTH" - }).to_dict(), { + }).to_dict(omit_none=True + ), { "field": "ts", "data_type": "timestamp", "granularity": "MONTH" @@ -735,7 +736,7 @@ def test_parse_partition_by(self): "data_type": "timestamp", "granularity": "YEAR" - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", "data_type": "timestamp", "granularity": "YEAR" @@ -748,7 +749,7 @@ def test_parse_partition_by(self): "data_type": "datetime", "granularity": "HOUR" - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", "data_type": "datetime", "granularity": "HOUR" @@ -761,7 +762,7 @@ def test_parse_partition_by(self): "data_type": "datetime", "granularity": "MONTH" - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", "data_type": "datetime", "granularity": "MONTH" @@ -774,7 +775,7 @@ def test_parse_partition_by(self): "data_type": "datetime", "granularity": "YEAR" - }).to_dict(), { + }).to_dict(omit_none=True), { "field": "ts", "data_type": "datetime", "granularity": "YEAR" @@ -795,7 +796,8 @@ def test_parse_partition_by(self): "end": 100, "interval": 20 } - }).to_dict(), { + }).to_dict(omit_none=True + ), { "field": "id", "data_type": "int64", "granularity": "day", diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 60754702c31..12858baeec4 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -776,7 +776,7 @@ def test_all_overrides(self): LocalPackage(local='foo'), 
GitPackage(git='git@example.com:fishtown-analytics/dbt-utils.git', revision='test-rev') ])) - str(project) + str(project) # this does the equivalent of project.to_project_config(with_packages=True) json.dumps(project.to_project_config()) def test_string_run_hooks(self): diff --git a/test/unit/test_contracts_graph_parsed.py b/test/unit/test_contracts_graph_parsed.py index 7e968ed9fe5..f2fb11b3abd 100644 --- a/test/unit/test_contracts_graph_parsed.py +++ b/test/unit/test_contracts_graph_parsed.py @@ -1618,7 +1618,7 @@ def test_timestamp_snapshot_ok(basic_timestamp_snapshot_dict, basic_timestamp_sn assert_symmetric(node, node_dict, ParsedSnapshotNode) assert_symmetric(inter, node_dict, IntermediateSnapshotNode) - assert ParsedSnapshotNode.from_dict(inter.to_dict()) == node + assert ParsedSnapshotNode.from_dict(inter.to_dict(omit_none=True)) == node assert node.is_refable is True assert node.is_ephemeral is False pickle.loads(pickle.dumps(node)) @@ -1631,7 +1631,7 @@ def test_check_snapshot_ok(basic_check_snapshot_dict, basic_check_snapshot_objec assert_symmetric(node, node_dict, ParsedSnapshotNode) assert_symmetric(inter, node_dict, IntermediateSnapshotNode) - assert ParsedSnapshotNode.from_dict(inter.to_dict()) == node + assert ParsedSnapshotNode.from_dict(inter.to_dict(omit_none=True)) == node assert node.is_refable is True assert node.is_ephemeral is False pickle.loads(pickle.dumps(node)) diff --git a/test/unit/test_docs_generate.py b/test/unit/test_docs_generate.py index 96dc568ff88..49576e9e100 100644 --- a/test/unit/test_docs_generate.py +++ b/test/unit/test_docs_generate.py @@ -32,7 +32,7 @@ def generate_catalog_dict(self, columns): sources=sources, errors=None, ) - return result.to_dict(options={'keep_none': True})['nodes'] + return result.to_dict(omit_none=False)['nodes'] def test__unflatten_empty(self): columns = {} diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index c88c6c2178e..a776837e609 100644 --- a/test/unit/test_manifest.py +++ 
b/test/unit/test_manifest.py @@ -212,9 +212,9 @@ def setUp(self): ), } for node in self.nested_nodes.values(): - node.validate(node.to_dict()) + node.validate(node.to_dict(omit_none=True)) for source in self.sources.values(): - source.validate(source.to_dict()) + source.validate(source.to_dict(omit_none=True)) os.environ['DBT_ENV_CUSTOM_ENV_key'] = 'value' @@ -229,7 +229,7 @@ def test__no_nodes(self): metadata=ManifestMetadata(generated_at=datetime.utcnow()), ) self.assertEqual( - manifest.writable_manifest().to_dict(), + manifest.writable_manifest().to_dict(omit_none=True), { 'nodes': {}, 'sources': {}, @@ -258,7 +258,7 @@ def test__nested_nodes(self): exposures={}, selectors={}, metadata=ManifestMetadata(generated_at=datetime.utcnow()), ) - serialized = manifest.writable_manifest().to_dict() + serialized = manifest.writable_manifest().to_dict(omit_none=True) self.assertEqual(serialized['metadata']['generated_at'], '2018-02-14T09:15:13Z') self.assertEqual(serialized['docs'], {}) self.assertEqual(serialized['disabled'], []) @@ -371,7 +371,7 @@ def test_no_nodes_with_metadata(self, mock_user): metadata=metadata, files={}, exposures={}) self.assertEqual( - manifest.writable_manifest().to_dict(), + manifest.writable_manifest().to_dict(omit_none=True), { 'nodes': {}, 'sources': {}, @@ -612,7 +612,7 @@ def test__no_nodes(self): manifest = Manifest(nodes={}, sources={}, macros={}, docs={}, selectors={}, disabled=[], metadata=metadata, files={}, exposures={}) self.assertEqual( - manifest.writable_manifest().to_dict(), + manifest.writable_manifest().to_dict(omit_none=True), { 'nodes': {}, 'macros': {}, @@ -640,7 +640,7 @@ def test__nested_nodes(self): disabled=[], selectors={}, metadata=ManifestMetadata(generated_at=datetime.utcnow()), files={}, exposures={}) - serialized = manifest.writable_manifest().to_dict() + serialized = manifest.writable_manifest().to_dict(omit_none=True) self.assertEqual(serialized['metadata']['generated_at'], '2018-02-14T09:15:13Z') 
self.assertEqual(serialized['disabled'], []) parent_map = serialized['parent_map'] diff --git a/test/unit/utils.py b/test/unit/utils.py index 96bb0e2f140..fe26e472042 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -145,7 +145,7 @@ def setUp(self): super().setUp() def assert_to_dict(self, obj, dct): - self.assertEqual(obj.to_dict(), dct) + self.assertEqual(obj.to_dict(omit_none=True), dct) def assert_from_dict(self, obj, dct, cls=None): if cls is None: @@ -185,7 +185,7 @@ def compare_dicts(dict1, dict2): def assert_to_dict(obj, dct): - assert obj.to_dict() == dct + assert obj.to_dict(omit_none=True) == dct def assert_from_dict(obj, dct, cls=None): diff --git a/third-party-stubs/mashumaro/config.pyi b/third-party-stubs/mashumaro/config.pyi new file mode 100644 index 00000000000..04a5cb1070b --- /dev/null +++ b/third-party-stubs/mashumaro/config.pyi @@ -0,0 +1,10 @@ +from mashumaro.types import SerializationStrategy as SerializationStrategy +from typing import Any, Callable, Dict, List, Union + +TO_DICT_ADD_OMIT_NONE_FLAG: str +SerializationStrategyValueType = Union[SerializationStrategy, Dict[str, Union[str, Callable]]] + +class BaseConfig: + debug: bool = ... + code_generation_options: List[str] = ... + serialization_strategy: Dict[Any, SerializationStrategyValueType] = ... diff --git a/third-party-stubs/mashumaro/serializer/base/dict.pyi b/third-party-stubs/mashumaro/serializer/base/dict.pyi index 568079dac8e..6a1e53f18c8 100644 --- a/third-party-stubs/mashumaro/serializer/base/dict.pyi +++ b/third-party-stubs/mashumaro/serializer/base/dict.pyi @@ -2,10 +2,15 @@ from typing import Any, Mapping, Dict, Optional class DataClassDictMixin: def __init_subclass__(cls, **kwargs: Any) -> None: ... - def __pre_serialize__(self, options: Optional[Dict[str, Any]]) -> Any: ... - def __post_serialize__(self, dct: Mapping, options: Optional[Dict[str, Any]]) -> Any: ... + def __pre_serialize__(self) -> Any: ... + def __post_serialize__(self, dct: Mapping) -> Any: ... 
@classmethod - def __pre_deserialize__(cls: Any, dct: Mapping, options: Optional[Dict[str, Any]]) -> Any: ... - def to_dict( self, use_bytes: bool = False, use_enum: bool = False, use_datetime: bool = False, options: Optional[Dict[str, Any]] = None) -> dict: ... + def __pre_deserialize__(cls: Any, dct: Mapping) -> Any: ... + # This is absolutely totally wrong. This is *not* the signature of the Mashumaro to_dict + # But mypy insists that the DataClassDictMixin to_dict and the JsonSchemaMixin to_dict + # must have the same signatures now that we have an 'omit_none' flag on the Mashumaro to_dict. + # There is no 'validate = False' in Mashumaro. + # Could not find a way to tell mypy to ignore it. + def to_dict( self, omit_none = False, validate = False) -> dict: ... @classmethod - def from_dict( cls, d: Mapping, use_bytes: bool = False, use_enum: bool = False, use_datetime: bool = False, options: Optional[Dict[str, Any]] = None) -> Any: ... + def from_dict( cls, d: Mapping, use_bytes: bool = False, use_enum: bool = False, use_datetime: bool = False) -> Any: ...