Skip to content

Commit

Permalink
Fix strict optional handling in dataclasses (#15571)
Browse files Browse the repository at this point in the history
There were few cases when someone forgot to call `strict_optional_set()`
in dataclasses plugin, let's move the calls directly to two places where
they are needed for typeops. This may cause a tiny perf regression, but
is much more robust in terms of preventing bugs.
  • Loading branch information
ilevkivskyi authored Jul 3, 2023
1 parent 2e9c9b4 commit 5e4d097
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
30 changes: 17 additions & 13 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
info: TypeInfo,
kw_only: bool,
is_neither_frozen_nor_nonfrozen: bool,
api: SemanticAnalyzerPluginInterface,
) -> None:
self.name = name
self.alias = alias
Expand All @@ -116,6 +117,7 @@ def __init__(
self.info = info
self.kw_only = kw_only
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
self._api = api

def to_argument(self, current_info: TypeInfo) -> Argument:
arg_kind = ARG_POS
Expand All @@ -138,7 +140,10 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]:
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
with state.strict_optional_set(self._api.options.strict_optional):
return expand_type(
self.type, {self.info.self_type.id: fill_typevars(current_info)}
)
return self.type

def to_var(self, current_info: TypeInfo) -> Var:
Expand All @@ -165,13 +170,14 @@ def deserialize(
) -> DataclassAttribute:
data = data.copy()
typ = deserialize_and_fixup_type(data.pop("type"), api)
return cls(type=typ, info=info, **data)
return cls(type=typ, info=info, **data, api=api)

def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type."""
if self.type is not None:
self.type = map_type_from_supertype(self.type, sub_type, self.info)
with state.strict_optional_set(self._api.options.strict_optional):
self.type = map_type_from_supertype(self.type, sub_type, self.info)


class DataclassTransformer:
Expand Down Expand Up @@ -230,12 +236,11 @@ def transform(self) -> bool:
and ("__init__" not in info.names or info.names["__init__"].plugin_generated)
and attributes
):
with state.strict_optional_set(self._api.options.strict_optional):
args = [
attr.to_argument(info)
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]
args = [
attr.to_argument(info)
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]

if info.fallback_to_any:
# Make positional args optional since we don't know their order.
Expand Down Expand Up @@ -355,8 +360,7 @@ def transform(self) -> bool:
self._add_dataclass_fields_magic_attribute()

if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
with state.strict_optional_set(self._api.options.strict_optional):
self._add_internal_replace_method(attributes)
self._add_internal_replace_method(attributes)
if "__post_init__" in info.names:
self._add_internal_post_init_method(attributes)

Expand Down Expand Up @@ -546,8 +550,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
with state.strict_optional_set(self._api.options.strict_optional):
attr.expand_typevar_from_subtype(cls.info)
attr.expand_typevar_from_subtype(cls.info)
found_attrs[name] = attr

sym_node = cls.info.names.get(name)
Expand Down Expand Up @@ -693,6 +696,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass(
cls.info
),
api=self._api,
)

all_attrs = list(found_attrs.values())
Expand Down
18 changes: 15 additions & 3 deletions test-data/unit/pythoneval.test
Original file line number Diff line number Diff line change
Expand Up @@ -2094,7 +2094,6 @@ grouped = groupby(pairs, key=fst)
[out]

[case testDataclassReplaceOptional]
# flags: --strict-optional
from dataclasses import dataclass, replace
from typing import Optional

Expand All @@ -2107,5 +2106,18 @@ reveal_type(a)
a2 = replace(a, x=None) # OK
reveal_type(a2)
[out]
_testDataclassReplaceOptional.py:10: note: Revealed type is "_testDataclassReplaceOptional.A"
_testDataclassReplaceOptional.py:12: note: Revealed type is "_testDataclassReplaceOptional.A"
_testDataclassReplaceOptional.py:9: note: Revealed type is "_testDataclassReplaceOptional.A"
_testDataclassReplaceOptional.py:11: note: Revealed type is "_testDataclassReplaceOptional.A"

[case testDataclassStrictOptionalAlwaysSet]
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class Description:
name_fn: Callable[[Optional[int]], Optional[str]]

def f(d: Description) -> None:
reveal_type(d.name_fn)
[out]
_testDataclassStrictOptionalAlwaysSet.py:9: note: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]"

0 comments on commit 5e4d097

Please sign in to comment.