Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mypy plugin added. #153

Merged
merged 1 commit into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[run]
omit =
pydantic_xml/mypy.py
42 changes: 38 additions & 4 deletions pydantic_xml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,23 +162,42 @@ def __init__(
utils.register_nsmap(nsmap)


def attr(name: Optional[str] = None, ns: Optional[str] = None, **kwargs: Any) -> Any:
_Unset: Any = pdc.PydanticUndefined


def attr(
name: Optional[str] = None,
ns: Optional[str] = None,
*,
default: Any = pdc.PydanticUndefined,
default_factory: Optional[Callable[[], Any]] = _Unset,
**kwargs: Any,
) -> Any:
"""
Marks a pydantic field as an xml attribute.

:param name: attribute name
:param ns: attribute xml namespace
:param default: the default value of the field.
:param default_factory: the factory function used to construct the default for the field.
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(EntityLocation.ATTRIBUTE, path=name, ns=ns, **kwargs)
return XmlEntityInfo(
EntityLocation.ATTRIBUTE,
path=name, ns=ns, default=default, default_factory=default_factory,
**kwargs,
)


def element(
tag: Optional[str] = None,
ns: Optional[str] = None,
nsmap: Optional[NsMap] = None,
nillable: bool = False,
*,
default: Any = pdc.PydanticUndefined,
default_factory: Optional[Callable[[], Any]] = _Unset,
**kwargs: Any,
) -> Any:
"""
Expand All @@ -188,17 +207,26 @@ def element(
:param ns: element xml namespace
:param nsmap: element xml namespace map
:param nillable: is element nillable. See https://www.w3.org/TR/xmlschema-1/#xsi_nil.
:param default: the default value of the field.
:param default_factory: the factory function used to construct the default for the field.
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, nillable=nillable, **kwargs)
return XmlEntityInfo(
EntityLocation.ELEMENT,
path=tag, ns=ns, nsmap=nsmap, nillable=nillable, default=default, default_factory=default_factory,
**kwargs,
)


def wrapped(
path: str,
entity: Optional[pd.fields.FieldInfo] = None,
ns: Optional[str] = None,
nsmap: Optional[NsMap] = None,
*,
default: Any = pdc.PydanticUndefined,
default_factory: Optional[Callable[[], Any]] = _Unset,
**kwargs: Any,
) -> Any:
"""
Expand All @@ -208,10 +236,16 @@ def wrapped(
:param path: entity path
:param ns: element xml namespace
:param nsmap: element xml namespace map
:param default: the default value of the field.
:param default_factory: the factory function used to construct the default for the field.
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(EntityLocation.WRAPPED, path=path, ns=ns, nsmap=nsmap, wrapped=entity, **kwargs)
return XmlEntityInfo(
EntityLocation.WRAPPED,
path=path, ns=ns, nsmap=nsmap, wrapped=entity, default=default, default_factory=default_factory,
**kwargs,
)


class XmlModelMeta(ModelMetaclass):
Expand Down
96 changes: 96 additions & 0 deletions pydantic_xml/mypy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Callable, Optional, Tuple, Union

from mypy import nodes
from mypy.plugin import ClassDefContext, FunctionContext, Plugin, Type
from pydantic.mypy import PydanticModelTransformer, PydanticPlugin

ATTR_FULLNAME = 'pydantic_xml.model.attr'
ELEMENT_FULLNAME = 'pydantic_xml.model.element'
WRAPPED_FULLNAME = 'pydantic_xml.model.wrapped'
ENTITIES_FULLNAME = (ATTR_FULLNAME, ELEMENT_FULLNAME, WRAPPED_FULLNAME)


def plugin(version: str) -> type[Plugin]:
return PydanticXmlPlugin


class PydanticXmlPlugin(PydanticPlugin):
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]:
sym = self.lookup_fully_qualified(fullname)
if sym and sym.fullname == ATTR_FULLNAME:
return self._attribute_callback
elif sym and sym.fullname == ELEMENT_FULLNAME:
return self._element_callback
elif sym and sym.fullname == WRAPPED_FULLNAME:
return self._wrapped_callback

return super().get_function_hook(fullname)

def _attribute_callback(self, ctx: FunctionContext) -> Type:
return super()._pydantic_field_callback(self._pop_first_args(ctx, 2))

def _element_callback(self, ctx: FunctionContext) -> Type:
return super()._pydantic_field_callback(self._pop_first_args(ctx, 4))

def _wrapped_callback(self, ctx: FunctionContext) -> Type:
return super()._pydantic_field_callback(self._pop_first_args(ctx, 4))

def _pop_first_args(self, ctx: FunctionContext, num: int) -> FunctionContext:
return FunctionContext(
arg_types=ctx.arg_types[num:],
arg_kinds=ctx.arg_kinds[num:],
callee_arg_names=ctx.callee_arg_names[num:],
arg_names=ctx.arg_names[num:],
default_return_type=ctx.default_return_type,
args=ctx.args[num:],
context=ctx.context,
api=ctx.api,
)

def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool:
transformer = PydanticXmlModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config)
return transformer.transform()


class PydanticXmlModelTransformer(PydanticModelTransformer):
@staticmethod
def get_has_default(stmt: nodes.AssignmentStmt) -> bool:
expr = stmt.rvalue
if isinstance(expr, nodes.TempNode):
return False

if (
isinstance(expr, nodes.CallExpr) and
isinstance(expr.callee, nodes.RefExpr) and
expr.callee.fullname in ENTITIES_FULLNAME
):
for arg, name in zip(expr.args, expr.arg_names):
if name == 'default':
return arg.__class__ is not nodes.EllipsisExpr
if name == 'default_factory':
return not (isinstance(arg, nodes.NameExpr) and arg.fullname == 'builtins.None')

return False

return PydanticModelTransformer.get_has_default(stmt)

@staticmethod
def get_alias_info(stmt: nodes.AssignmentStmt) -> Tuple[Union[str, None], bool]:
expr = stmt.rvalue
if isinstance(expr, nodes.TempNode):
return None, False

if (
isinstance(expr, nodes.CallExpr) and
isinstance(expr.callee, nodes.RefExpr) and
expr.callee.fullname in ENTITIES_FULLNAME
):
for arg, arg_name in zip(expr.args, expr.arg_names):
if arg_name != 'alias':
continue
if isinstance(arg, nodes.StrExpr):
return arg.value, False
else:
return None, True

return PydanticModelTransformer.get_alias_info(stmt)
Loading