diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..63af5ab --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +omit = + pydantic_xml/mypy.py diff --git a/pydantic_xml/model.py b/pydantic_xml/model.py index 2fb2cdf..3c99983 100644 --- a/pydantic_xml/model.py +++ b/pydantic_xml/model.py @@ -162,16 +162,32 @@ def __init__( utils.register_nsmap(nsmap) -def attr(name: Optional[str] = None, ns: Optional[str] = None, **kwargs: Any) -> Any: +_Unset: Any = pdc.PydanticUndefined + + +def attr( + name: Optional[str] = None, + ns: Optional[str] = None, + *, + default: Any = pdc.PydanticUndefined, + default_factory: Optional[Callable[[], Any]] = _Unset, + **kwargs: Any, +) -> Any: """ Marks a pydantic field as an xml attribute. :param name: attribute name :param ns: attribute xml namespace + :param default: the default value of the field. + :param default_factory: the factory function used to construct the default for the field. :param kwargs: pydantic field arguments. See :py:class:`pydantic.Field` """ - return XmlEntityInfo(EntityLocation.ATTRIBUTE, path=name, ns=ns, **kwargs) + return XmlEntityInfo( + EntityLocation.ATTRIBUTE, + path=name, ns=ns, default=default, default_factory=default_factory, + **kwargs, + ) def element( @@ -179,6 +195,9 @@ def element( ns: Optional[str] = None, nsmap: Optional[NsMap] = None, nillable: bool = False, + *, + default: Any = pdc.PydanticUndefined, + default_factory: Optional[Callable[[], Any]] = _Unset, **kwargs: Any, ) -> Any: """ @@ -188,10 +207,16 @@ def element( :param ns: element xml namespace :param nsmap: element xml namespace map :param nillable: is element nillable. See https://www.w3.org/TR/xmlschema-1/#xsi_nil. + :param default: the default value of the field. + :param default_factory: the factory function used to construct the default for the field. :param kwargs: pydantic field arguments. See :py:class:`pydantic.Field` """ - return XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, nillable=nillable, **kwargs) + return XmlEntityInfo( + EntityLocation.ELEMENT, + path=tag, ns=ns, nsmap=nsmap, nillable=nillable, default=default, default_factory=default_factory, + **kwargs, + ) def wrapped( @@ -199,6 +224,9 @@ def wrapped( entity: Optional[pd.fields.FieldInfo] = None, ns: Optional[str] = None, nsmap: Optional[NsMap] = None, + *, + default: Any = pdc.PydanticUndefined, + default_factory: Optional[Callable[[], Any]] = _Unset, **kwargs: Any, ) -> Any: """ @@ -208,10 +236,16 @@ def wrapped( :param path: entity path :param ns: element xml namespace :param nsmap: element xml namespace map + :param default: the default value of the field. + :param default_factory: the factory function used to construct the default for the field. :param kwargs: pydantic field arguments. See :py:class:`pydantic.Field` """ - return XmlEntityInfo(EntityLocation.WRAPPED, path=path, ns=ns, nsmap=nsmap, wrapped=entity, **kwargs) + return XmlEntityInfo( + EntityLocation.WRAPPED, + path=path, ns=ns, nsmap=nsmap, wrapped=entity, default=default, default_factory=default_factory, + **kwargs, + ) class XmlModelMeta(ModelMetaclass): diff --git a/pydantic_xml/mypy.py b/pydantic_xml/mypy.py new file mode 100644 index 0000000..84f22b6 --- /dev/null +++ b/pydantic_xml/mypy.py @@ -0,0 +1,96 @@ +from typing import Callable, Optional, Tuple, Union + +from mypy import nodes +from mypy.plugin import ClassDefContext, FunctionContext, Plugin, Type +from pydantic.mypy import PydanticModelTransformer, PydanticPlugin + +ATTR_FULLNAME = 'pydantic_xml.model.attr' +ELEMENT_FULLNAME = 'pydantic_xml.model.element' +WRAPPED_FULLNAME = 'pydantic_xml.model.wrapped' +ENTITIES_FULLNAME = (ATTR_FULLNAME, ELEMENT_FULLNAME, WRAPPED_FULLNAME) + + +def plugin(version: str) -> type[Plugin]: + return PydanticXmlPlugin + + +class PydanticXmlPlugin(PydanticPlugin): + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: + sym = self.lookup_fully_qualified(fullname) + if sym and sym.fullname == ATTR_FULLNAME: + return self._attribute_callback + elif sym and sym.fullname == ELEMENT_FULLNAME: + return self._element_callback + elif sym and sym.fullname == WRAPPED_FULLNAME: + return self._wrapped_callback + + return super().get_function_hook(fullname) + + def _attribute_callback(self, ctx: FunctionContext) -> Type: + return super()._pydantic_field_callback(self._pop_first_args(ctx, 2)) + + def _element_callback(self, ctx: FunctionContext) -> Type: + return super()._pydantic_field_callback(self._pop_first_args(ctx, 4)) + + def _wrapped_callback(self, ctx: FunctionContext) -> Type: + return super()._pydantic_field_callback(self._pop_first_args(ctx, 4)) + + def _pop_first_args(self, ctx: FunctionContext, num: int) -> FunctionContext: + return FunctionContext( + arg_types=ctx.arg_types[num:], + arg_kinds=ctx.arg_kinds[num:], + callee_arg_names=ctx.callee_arg_names[num:], + arg_names=ctx.arg_names[num:], + default_return_type=ctx.default_return_type, + args=ctx.args[num:], + context=ctx.context, + api=ctx.api, + ) + + def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool: + transformer = PydanticXmlModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config) + return transformer.transform() + + +class PydanticXmlModelTransformer(PydanticModelTransformer): + @staticmethod + def get_has_default(stmt: nodes.AssignmentStmt) -> bool: + expr = stmt.rvalue + if isinstance(expr, nodes.TempNode): + return False + + if ( + isinstance(expr, nodes.CallExpr) and + isinstance(expr.callee, nodes.RefExpr) and + expr.callee.fullname in ENTITIES_FULLNAME + ): + for arg, name in zip(expr.args, expr.arg_names): + if name == 'default': + return arg.__class__ is not nodes.EllipsisExpr + if name == 'default_factory': + return not (isinstance(arg, nodes.NameExpr) and arg.fullname == 'builtins.None') + + return False + + return PydanticModelTransformer.get_has_default(stmt) + + @staticmethod + def get_alias_info(stmt: nodes.AssignmentStmt) -> Tuple[Union[str, None], bool]: + expr = stmt.rvalue + if isinstance(expr, nodes.TempNode): + return None, False + + if ( + isinstance(expr, nodes.CallExpr) and + isinstance(expr.callee, nodes.RefExpr) and + expr.callee.fullname in ENTITIES_FULLNAME + ): + for arg, arg_name in zip(expr.args, expr.arg_names): + if arg_name != 'alias': + continue + if isinstance(arg, nodes.StrExpr): + return arg.value, False + else: + return None, True + + return PydanticModelTransformer.get_alias_info(stmt)