Skip to content

Commit

Permalink
fix: model generation simplified. Closes #746
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosschroh committed Oct 16, 2024
1 parent 1207e06 commit db0e088
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,9 @@ class MyModel(AvroBaseModel):
pydantic_class = field.get("pydantic-class")

if pydantic_class is not None:
self.imports.add("import pydantic")
return f"pydantic.{pydantic_class}"
return None

def render_dataclass_field(self, properties: str) -> str:
self.imports.add("from pydantic import Field")
return super().render_dataclass_field(properties=properties)

def add_class_imports(self) -> None:
self.imports.add("import pydantic")
self.imports.add("from dataclasses_avroschema.pydantic import AvroBaseModel")
350 changes: 191 additions & 159 deletions dataclasses_avroschema/model_generator/lang/python/base.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class PydanticModelGenerator(BaseGenerator):
def __post_init__(self) -> None:
super().__post_init__()

self.base_class = "BaseModel"
self.base_class = "pydantic.BaseModel"
self.imports_dict = {
"dataclass_field": "from pydantic import Field",
}
Expand Down Expand Up @@ -41,13 +41,8 @@ class MyModel(AvroBaseModel):
pydantic_class = field.get("pydantic-class")

if pydantic_class is not None:
self.imports.add("import pydantic")
return f"pydantic.{pydantic_class}"
return None

def render_dataclass_field(self, properties: str) -> str:
self.imports.add("from pydantic import Field")
return super().render_dataclass_field(properties=properties)

def add_class_imports(self) -> None:
self.imports.add("from pydantic import BaseModel")
self.imports.add("import pydantic")
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class Meta:
"""

# Pydanntic specific
PYDANTIC_FIELD = "Field($properties)"
PYDANTIC_FIELD = "pydantic.Field($properties)"

field_type_template = Template(FIELD_TYPE_TEMPLATE)
metaclass_field_template = Template(METACLASS_FIELD_TEMPLATE)
Expand Down
97 changes: 97 additions & 0 deletions tests/model_generator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def schema_primitive_types_as_defined_types() -> Dict:
"name": "expirience",
"type": {"type": "int", "unit": "years", "default": 10},
},
{"name": "second_street", "type": {"type": "string", "avro.java.string": "String"}, "default": "Batman"},
{"name": "reason", "type": ["null", {"type": "string", "avro.java.string": "String"}], "default": None},
],
}

Expand Down Expand Up @@ -259,6 +261,14 @@ def schema_with_enum_types() -> Dict:
],
"default": None,
},
{
"name": "limit_type",
"type": [
{"type": "enum", "name": "LimitTypes", "symbols": ["MIN_LIMIT", "MAX_LIMIT", "EXACT_LIMIT"]},
"null",
],
"default": "MIN_LIMIT",
},
],
}

Expand Down Expand Up @@ -304,6 +314,19 @@ def schema_with_enum_types_with_inner_default() -> Dict:
],
"default": None,
},
{
"name": "limit_type",
"type": [
{
"type": "enum",
"name": "LimitTypes",
"symbols": ["MIN_LIMIT", "MAX_LIMIT", "EXACT_LIMIT"],
"default": "MIN_LIMIT",
},
"null",
],
"default": "MIN_LIMIT",
},
],
}

Expand Down Expand Up @@ -558,6 +581,78 @@ def schema_one_to_self_relationship() -> JsonDict:
}


@pytest.fixture
def schema_with_multiple_levels_of_relationship() -> JsonDict:
return {
"type": "record",
"name": "User",
"fields": [
{"name": "name", "type": "string"},
{"name": "age", "type": "long"},
{
"name": "addresses",
"type": {
"type": "map",
"values": {
"type": "record",
"name": "Address",
"fields": [
{"name": "street", "type": "string"},
{"name": "street_number", "type": "long"},
{
"name": "house_main_event",
"type": {
"type": "record",
"name": "HouseMainEvent",
"fields": [
{
"name": "birthday",
"type": {"type": "int", "logicalType": "date"},
"default": 18181,
},
{
"name": "meeting_time",
"type": {"type": "int", "logicalType": "time-millis"},
"default": 64662000,
},
{
"name": "release_datetime",
"type": {"type": "long", "logicalType": "timestamp-millis"},
"default": 1570903062000,
},
{
"name": "event_uuid",
"type": {"type": "string", "logicalType": "uuid"},
"default": "09f00184-7721-4266-a955-21048a5cc235",
},
],
},
},
],
"doc": "An Address",
},
"name": "address",
},
},
{
"name": "crazy_union",
"type": [
"string",
{"type": "map", "values": "Address", "name": "optional_address"},
],
},
{
"name": "optional_addresses",
"type": [
"null",
{"type": "map", "values": "Address", "name": "optional_address"},
],
"default": None,
},
],
}


@pytest.fixture
def schema_with_decimal_field() -> JsonDict:
return {
Expand Down Expand Up @@ -937,12 +1032,14 @@ def with_fields_with_metadata() -> JsonDict:
"name": "fieldwithdefault",
"type": "string",
"default": "some default value",
"avro.java.string": "String",
},
{
"name": "someotherfield",
"type": "long",
"aliases": ["oldname"],
"doc": "test",
"avro.java.string": "String",
},
],
}
Expand Down
65 changes: 63 additions & 2 deletions tests/model_generator/test_model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class Address(AvroModel):
weight: types.Int32 = dataclasses.field(metadata={'unit': 'kg'})
pet_age: types.Int32 = 1
expirience: types.Int32 = dataclasses.field(metadata={'unit': 'years'}, default=10)
second_street: str = dataclasses.field(metadata={'avro.java.string': 'String'}, default="Batman")
reason: typing.Optional[str] = dataclasses.field(metadata={'avro.java.string': 'String'}, default=None)
"""
model_generator = ModelGenerator()
Expand Down Expand Up @@ -231,12 +233,19 @@ class Cars({templates.ENUM_PYTHON_VERSION}):
DUNA = "duna"
class LimitTypes({templates.ENUM_PYTHON_VERSION}):
MIN_LIMIT = "MIN_LIMIT"
MAX_LIMIT = "MAX_LIMIT"
EXACT_LIMIT = "EXACT_LIMIT"
@dataclasses.dataclass
class User(AvroModel):
favorite_color: FavoriteColor
primary_color: FavoriteColor
superheros: Superheros = Superheros.BATMAN
cars: typing.Optional[Cars] = None
limit_type: typing.Optional[LimitTypes] = LimitTypes.MIN_LIMIT
"""
model_generator = ModelGenerator()
result = model_generator.render(schema=schema_with_enum_types)
Expand Down Expand Up @@ -283,11 +292,21 @@ class Cars({templates.ENUM_PYTHON_VERSION}):
class Meta:
default = "ferrary"
class LimitTypes({templates.ENUM_PYTHON_VERSION}):
MIN_LIMIT = "MIN_LIMIT"
MAX_LIMIT = "MAX_LIMIT"
EXACT_LIMIT = "EXACT_LIMIT"
{templates.METACLASS_DECORATOR}
class Meta:
default = "MIN_LIMIT"
@dataclasses.dataclass
class User(AvroModel):
favorite_color: FavoriteColor
superheros: Superheros = Superheros.BATMAN
cars: typing.Optional[Cars] = None
limit_type: typing.Optional[LimitTypes] = LimitTypes.MIN_LIMIT
"""
model_generator = ModelGenerator()
result = model_generator.render(schema=schema_with_enum_types_with_inner_default)
Expand Down Expand Up @@ -534,6 +553,48 @@ class User(AvroModel):
assert result.strip() == expected_result.strip()


def test_schema_with_multiple_levels_of_relationship(
schema_with_multiple_levels_of_relationship: types.JsonDict,
) -> None:
expected_result = """
from dataclasses_avroschema import AvroModel
import dataclasses
import datetime
import typing
import uuid
@dataclasses.dataclass
class HouseMainEvent(AvroModel):
birthday: datetime.date = datetime.date(2019, 10, 12)
meeting_time: datetime.time = datetime.time(17, 57, 42)
release_datetime: datetime.datetime = datetime.datetime(2019, 10, 12, 17, 57, 42, tzinfo=datetime.timezone.utc)
event_uuid: uuid.UUID = "09f00184-7721-4266-a955-21048a5cc235"
@dataclasses.dataclass
class Address(AvroModel):
\"""
An Address
\"""
street: str
street_number: int
house_main_event: HouseMainEvent
@dataclasses.dataclass
class User(AvroModel):
name: str
age: int
addresses: typing.Dict[str, Address]
crazy_union: typing.Union[str, typing.Dict[str, Address]]
optional_addresses: typing.Optional[typing.Dict[str, Address]] = None
"""
model_generator = ModelGenerator()
result = model_generator.render(schema=schema_with_multiple_levels_of_relationship)
assert result.strip() == expected_result.strip()


def test_decimal_field(schema_with_decimal_field: types.JsonDict) -> None:
expected_result = """
from dataclasses_avroschema import AvroModel
Expand Down Expand Up @@ -717,8 +778,8 @@ def test_model_generator_with_fields_with_metadata(
@dataclasses.dataclass
class Message(AvroModel):
someotherfield: int = dataclasses.field(metadata={'aliases': ['oldname'], 'doc': 'test'})
fieldwithdefault: str = "some default value"
someotherfield: int = dataclasses.field(metadata={'aliases': ['oldname'], 'doc': 'test', 'avro.java.string': 'String'})
fieldwithdefault: str = dataclasses.field(metadata={'avro.java.string': 'String'}, default="some default value")
class Meta:
Expand Down
Loading

0 comments on commit db0e088

Please sign in to comment.