Skip to content

Commit

Permalink
mypy plugin added.
Browse files Browse the repository at this point in the history
  • Loading branch information
dapper91 committed Dec 23, 2023
1 parent 73228d7 commit 463aa4d
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 4 deletions.
42 changes: 38 additions & 4 deletions pydantic_xml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,23 +162,42 @@ def __init__(
utils.register_nsmap(nsmap)


def attr(name: Optional[str] = None, ns: Optional[str] = None, **kwargs: Any) -> Any:
_Unset: Any = pdc.PydanticUndefined


def attr(
name: Optional[str] = None,
ns: Optional[str] = None,
*,
default: Any = pdc.PydanticUndefined,
default_factory: Optional[Callable[[], Any]] = _Unset,
**kwargs: Any,
) -> Any:
"""
Marks a pydantic field as an xml attribute.
:param name: attribute name
:param ns: attribute xml namespace
:param default: the default value of the field.
:param default_factory: the factory function used to construct the default for the field.
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(EntityLocation.ATTRIBUTE, path=name, ns=ns, **kwargs)
return XmlEntityInfo(
EntityLocation.ATTRIBUTE,
path=name, ns=ns, default=default, default_factory=default_factory,
**kwargs,
)


def element(
tag: Optional[str] = None,
ns: Optional[str] = None,
nsmap: Optional[NsMap] = None,
nillable: bool = False,
*,
default: Any = pdc.PydanticUndefined,
default_factory: Optional[Callable[[], Any]] = _Unset,
**kwargs: Any,
) -> Any:
"""
Expand All @@ -188,17 +207,26 @@ def element(
:param ns: element xml namespace
:param nsmap: element xml namespace map
:param nillable: is element nillable. See https://www.w3.org/TR/xmlschema-1/#xsi_nil.
:param default: the default value of the field.
:param default_factory: the factory function used to construct the default for the field.
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, nillable=nillable, **kwargs)
return XmlEntityInfo(
EntityLocation.ELEMENT,
path=tag, ns=ns, nsmap=nsmap, nillable=nillable, default=default, default_factory=default_factory,
**kwargs,
)


def wrapped(
path: str,
entity: Optional[pd.fields.FieldInfo] = None,
ns: Optional[str] = None,
nsmap: Optional[NsMap] = None,
*,
default: Any = pdc.PydanticUndefined,
default_factory: Optional[Callable[[], Any]] = _Unset,
**kwargs: Any,
) -> Any:
"""
Expand All @@ -208,10 +236,16 @@ def wrapped(
:param path: entity path
:param ns: element xml namespace
:param nsmap: element xml namespace map
:param default: the default value of the field.
:param default_factory: the factory function used to construct the default for the field.
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(EntityLocation.WRAPPED, path=path, ns=ns, nsmap=nsmap, wrapped=entity, **kwargs)
return XmlEntityInfo(
EntityLocation.WRAPPED,
path=path, ns=ns, nsmap=nsmap, wrapped=entity, default=default, default_factory=default_factory,
**kwargs,
)


class XmlModelMeta(ModelMetaclass):
Expand Down
96 changes: 96 additions & 0 deletions pydantic_xml/mypy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Callable, Optional, Tuple, Union

from mypy import nodes
from mypy.plugin import ClassDefContext, FunctionContext, Plugin, Type
from pydantic.mypy import PydanticModelTransformer, PydanticPlugin

ATTR_FULLNAME = 'pydantic_xml.model.attr'
ELEMENT_FULLNAME = 'pydantic_xml.model.element'
WRAPPED_FULLNAME = 'pydantic_xml.model.wrapped'
ENTITIES_FULLNAME = (ATTR_FULLNAME, ELEMENT_FULLNAME, WRAPPED_FULLNAME)


def plugin(version: str) -> type[Plugin]:
return PydanticXmlPlugin


class PydanticXmlPlugin(PydanticPlugin):
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]:
sym = self.lookup_fully_qualified(fullname)
if sym and sym.fullname == ATTR_FULLNAME:
return self._attribute_callback
elif sym and sym.fullname == ELEMENT_FULLNAME:
return self._element_callback
elif sym and sym.fullname == WRAPPED_FULLNAME:
return self._wrapped_callback

return super().get_function_hook(fullname)

def _attribute_callback(self, ctx: FunctionContext) -> Type:
return super()._pydantic_field_callback(self._pop_first_args(ctx, 2))

def _element_callback(self, ctx: FunctionContext) -> Type:
return super()._pydantic_field_callback(self._pop_first_args(ctx, 4))

def _wrapped_callback(self, ctx: FunctionContext) -> Type:
return super()._pydantic_field_callback(self._pop_first_args(ctx, 4))

def _pop_first_args(self, ctx: FunctionContext, num: int) -> FunctionContext:
return FunctionContext(
arg_types=ctx.arg_types[num:],
arg_kinds=ctx.arg_kinds[num:],
callee_arg_names=ctx.callee_arg_names[num:],
arg_names=ctx.arg_names[num:],
default_return_type=ctx.default_return_type,
args=ctx.args[num:],
context=ctx.context,
api=ctx.api,
)

def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool:
transformer = PydanticXmlModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config)
return transformer.transform()


class PydanticXmlModelTransformer(PydanticModelTransformer):
@staticmethod
def get_has_default(stmt: nodes.AssignmentStmt) -> bool:
expr = stmt.rvalue
if isinstance(expr, nodes.TempNode):
return False

if (
isinstance(expr, nodes.CallExpr) and
isinstance(expr.callee, nodes.RefExpr) and
expr.callee.fullname in ENTITIES_FULLNAME
):
for arg, name in zip(expr.args, expr.arg_names):
if name == 'default':
return arg.__class__ is not nodes.EllipsisExpr
if name == 'default_factory':
return not (isinstance(arg, nodes.NameExpr) and arg.fullname == 'builtins.None')

return False

return PydanticModelTransformer.get_has_default(stmt)

@staticmethod
def get_alias_info(stmt: nodes.AssignmentStmt) -> Tuple[Union[str, None], bool]:
expr = stmt.rvalue
if isinstance(expr, nodes.TempNode):
return None, False

if (
isinstance(expr, nodes.CallExpr) and
isinstance(expr.callee, nodes.RefExpr) and
expr.callee.fullname in ENTITIES_FULLNAME
):
for arg, arg_name in zip(expr.args, expr.arg_names):
if arg_name != 'alias':
continue
if isinstance(arg, nodes.StrExpr):
return arg.value, False
else:
return None, True

return PydanticModelTransformer.get_alias_info(stmt)

0 comments on commit 463aa4d

Please sign in to comment.