From 51c7106cc79db0df751f9a81e006a9d964920fb2 Mon Sep 17 00:00:00 2001 From: Daan Rademaker Date: Thu, 8 Aug 2024 09:42:19 +0200 Subject: [PATCH] allow processing of optional enum values --- src/pydantic_avro/from_avro/types.py | 13 +++++- tests/test_from_avro.py | 61 +++++++++++++++++++++++++--- 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/src/pydantic_avro/from_avro/types.py b/src/pydantic_avro/from_avro/types.py index 52d6dd0..07bbdcd 100644 --- a/src/pydantic_avro/from_avro/types.py +++ b/src/pydantic_avro/from_avro/types.py @@ -56,10 +56,19 @@ def logical_type_handler(t: dict) -> str: def enum_type_handler(t: dict) -> str: """Gets the enum type of a given Avro enum type and adds it to the class registry""" - name = t["type"].get("name") + if t["type"] == "enum": + # comes from a unioned enum (e.g. ["null", "enum"]) + type_info = t + else: + # comes from a direct enum + type_info = t["type"] + + name = type_info["name"] + symbols = type_info["symbols"] + if not ClassRegistry().has_class(name): enum_class = f"class {name}(str, Enum):\n" - for s in t["type"].get("symbols"): + for s in symbols: enum_class += f' {s} = "{s}"\n' ClassRegistry().add_class(name, enum_class) return name diff --git a/tests/test_from_avro.py b/tests/test_from_avro.py index d2626c3..59d1f7c 100644 --- a/tests/test_from_avro.py +++ b/tests/test_from_avro.py @@ -59,7 +59,10 @@ def test_avsc_to_pydantic_map(): "name": "Test", "type": "record", "fields": [ - {"name": "col1", "type": {"type": "map", "values": "string", "default": {}}}, + { + "name": "col1", + "type": {"type": "map", "values": "string", "default": {}}, + }, ], } ) @@ -73,7 +76,10 @@ def test_avsc_to_pydantic_map_missing_values(): "name": "Test", "type": "record", "fields": [ - {"name": "col1", "type": {"type": "map", "values": None, "default": {}}}, + { + "name": "col1", + "type": {"type": "map", "values": None, "default": {}}, + }, ], } ) @@ -215,7 +221,11 @@ def test_default(): "fields": [ {"name": "col1", "type": "string", "default": "test"}, {"name": "col2_1", "type": ["null", "string"], "default": None}, - {"name": "col2_2", "type": ["string", "null"], "default": "default_str"}, + { + "name": "col2_2", + "type": ["string", "null"], + "default": "default_str", + }, { "name": "col3", "type": {"type": "map", "values": "string"}, @@ -245,7 +255,11 @@ def test_enums(): "fields": [ { "name": "c1", - "type": {"type": "enum", "symbols": ["passed", "failed"], "name": "Status"}, + "type": { + "type": "enum", + "symbols": ["passed", "failed"], + "name": "Status", + }, }, ], } @@ -256,6 +270,32 @@ def test_enums(): assert "class Status(str, Enum):\n" ' passed = "passed"\n' ' failed = "failed"' in pydantic_code +def test_enums_nullable(): + pydantic_code = avsc_to_pydantic( + { + "name": "Test", + "type": "record", + "fields": [ + { + "name": "c1", + "type": [ + "null", + { + "type": "enum", + "symbols": ["passed", "failed"], + "name": "Status", + }, + ], + }, + ], + } + ) + + assert "class Test(BaseModel):\n" " c1: Optional[Status]" in pydantic_code + + assert "class Status(str, Enum):\n" ' passed = "passed"\n' ' failed = "failed"' in pydantic_code + + def test_enums_reuse(): pydantic_code = avsc_to_pydantic( { @@ -264,7 +304,11 @@ def test_enums_reuse(): "fields": [ { "name": "c1", - "type": {"type": "enum", "symbols": ["passed", "failed"], "name": "Status"}, + "type": { + "type": "enum", + "symbols": ["passed", "failed"], + "name": "Status", + }, }, {"name": "c2", "type": "Status"}, ], @@ -291,7 +335,12 @@ def test_unions(): { "type": "record", "name": "ARecord", - "fields": [{"name": "values", "type": {"type": "map", "values": "string"}}], + "fields": [ + { + "name": "values", + "type": {"type": "map", "values": "string"}, + } + ], }, ], },