diff --git a/pydantic_xml/serializers/factories/heterogeneous.py b/pydantic_xml/serializers/factories/heterogeneous.py index 367ae81..91dc4ab 100644 --- a/pydantic_xml/serializers/factories/heterogeneous.py +++ b/pydantic_xml/serializers/factories/heterogeneous.py @@ -93,6 +93,7 @@ def from_core_schema(schema: pcs.TupleSchema, ctx: Serializer.Context) -> Serial SchemaTypeFamily.MAPPING, SchemaTypeFamily.TYPED_MAPPING, SchemaTypeFamily.UNION, + SchemaTypeFamily.TAGGED_UNION, SchemaTypeFamily.IS_INSTANCE, SchemaTypeFamily.CALL, ): diff --git a/pydantic_xml/serializers/factories/homogeneous.py b/pydantic_xml/serializers/factories/homogeneous.py index c96575d..5259c7f 100644 --- a/pydantic_xml/serializers/factories/homogeneous.py +++ b/pydantic_xml/serializers/factories/homogeneous.py @@ -111,6 +111,7 @@ def from_core_schema(schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Co SchemaTypeFamily.MAPPING, SchemaTypeFamily.TYPED_MAPPING, SchemaTypeFamily.UNION, + SchemaTypeFamily.TAGGED_UNION, SchemaTypeFamily.IS_INSTANCE, SchemaTypeFamily.CALL, SchemaTypeFamily.TUPLE, @@ -122,6 +123,7 @@ def from_core_schema(schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Co if items_type_family not in ( SchemaTypeFamily.MODEL, SchemaTypeFamily.UNION, + SchemaTypeFamily.TAGGED_UNION, SchemaTypeFamily.TUPLE, SchemaTypeFamily.CALL, ) and ctx.entity_location is None: diff --git a/pydantic_xml/serializers/factories/named_tuple.py b/pydantic_xml/serializers/factories/named_tuple.py index 4504481..27cbca0 100644 --- a/pydantic_xml/serializers/factories/named_tuple.py +++ b/pydantic_xml/serializers/factories/named_tuple.py @@ -63,6 +63,7 @@ def from_core_schema(schema: pcs.CallSchema, ctx: Serializer.Context) -> Seriali SchemaTypeFamily.MAPPING, SchemaTypeFamily.TYPED_MAPPING, SchemaTypeFamily.UNION, + SchemaTypeFamily.TAGGED_UNION, SchemaTypeFamily.IS_INSTANCE, SchemaTypeFamily.CALL, ): diff --git a/tests/test_unions.py b/tests/test_unions.py index c88dd98..b55f6d1 100644 --- a/tests/test_unions.py +++ b/tests/test_unions.py @@ -390,6 +390,46 @@ class TestModel(RootXmlModel, tag='model'): assert_xml_equal(actual_xml, xml) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 and above") +def test_tagged_union_collection(): + from typing import Annotated + + class SubModel1(BaseXmlModel): + type: Literal['type1'] = attr() + data: int + + class SubModel2(BaseXmlModel): + type: Literal['type2'] = attr() + data: str + + class TestModel(BaseXmlModel, tag='model'): + collection: List[Annotated[Union[SubModel1, SubModel2], Field(discriminator='type')]] = element('submodel') + + xml = ''' + + 1 + a + b + 2 + + ''' + + actual_obj = TestModel.from_xml(xml) + expected_obj = TestModel( + collection=[ + SubModel1(type='type1', data='1'), + SubModel2(type='type2', data='a'), + SubModel2(type='type2', data='b'), + SubModel1(type='type1', data='2'), + ], + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + def test_union_snapshot(): class SubModel1(BaseXmlModel, tag='submodel'): attr1: int = attr()