Skip to content

Commit

Permalink
Avoid sharing extra template state between models (#2215)
Browse files Browse the repository at this point in the history
* Avoid sharing extra template state between models

* Update black version check in test skips to >= 24

---------

Co-authored-by: Koudai Aono <[email protected]>
  • Loading branch information
ncoghlan and koxudaxi authored Dec 14, 2024
1 parent 7abc1cb commit 8eccc15
Show file tree
Hide file tree
Showing 13 changed files with 313 additions and 10 deletions.
9 changes: 7 additions & 2 deletions datamodel_code_generator/model/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from functools import lru_cache
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -316,6 +317,8 @@ def __init__(
self.reference.source = self

self.extra_template_data = (
# The supplied defaultdict will either create a new entry,
# or already contain a predefined entry for this type
extra_template_data[self.name]
if extra_template_data is not None
else defaultdict(dict)
Expand All @@ -327,10 +330,12 @@ def __init__(
if base_class.reference:
base_class.reference.children.append(self)

if extra_template_data:
if extra_template_data is not None:
all_model_extra_template_data = extra_template_data.get(ALL_MODEL)
if all_model_extra_template_data:
self.extra_template_data.update(all_model_extra_template_data)
# The deepcopy is needed here to ensure that different models don't
# end up inadvertently sharing state (such as "base_class_kwargs")
self.extra_template_data.update(deepcopy(all_model_extra_template_data))

self.methods: List[str] = methods or []

Expand Down
9 changes: 4 additions & 5 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,10 +868,8 @@ def check_paths(
) != property_name:
continue
literals = discriminator_field.data_type.literals
if (
len(literals) == 1 and literals[0] == type_names[0]
if type_names
else None
if len(literals) == 1 and literals[0] == (
type_names[0] if type_names else None
):
has_one_literal = True
if isinstance(
Expand All @@ -884,7 +882,8 @@ def check_paths(
'tag', discriminator_field.represented_default
)
discriminator_field.extras['is_classvar'] = True
continue
# Found the discriminator field, no need to keep looking
break
for (
field_data_type
) in discriminator_field.data_type.all_data_types:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Literal, Union
from typing import Literal, Optional, Union

from pydantic import BaseModel, Field

Expand All @@ -17,5 +17,12 @@ class Type2(BaseModel):
type_: Literal['b'] = Field('b', title='Type ')


class UnrelatedType(BaseModel):
info: Optional[str] = Field(
'Unrelated type, not involved in the discriminated union',
title='A way to check for side effects',
)


class Response(BaseModel):
inner: Union[Type1, Type2] = Field(..., discriminator='type_', title='Inner')
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import ClassVar, Literal, Union
from typing import ClassVar, Literal, Optional, Union

from msgspec import Meta, Struct
from typing_extensions import Annotated
Expand All @@ -18,5 +18,11 @@ class Type2(Struct, tag_field='type_', tag='b'):
type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b'


class UnrelatedType(Struct):
info: Optional[Annotated[str, Meta(title='A way to check for side effects')]] = (
'Unrelated type, not involved in the discriminated union'
)


class Response(Struct):
inner: Annotated[Union[Type1, Type2], Meta(title='Inner')]
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# generated by datamodel-codegen:
# filename: discriminator_literals.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal, Optional, Union

from msgspec import Meta, Struct
from typing_extensions import Annotated


class Type1(Struct, kw_only=True, tag_field='type_', tag='a'):
type_: ClassVar[Annotated[Literal['a'], Meta(title='Type ')]] = 'a'


class Type2(Struct, kw_only=True, tag_field='type_', tag='b'):
type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b'


class UnrelatedType(Struct, kw_only=True):
info: Optional[Annotated[str, Meta(title='A way to check for side effects')]] = (
'Unrelated type, not involved in the discriminated union'
)


class Response(Struct, kw_only=True):
inner: Annotated[Union[Type1, Type2], Meta(title='Inner')]
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# generated by datamodel-codegen:
# filename: discriminator_literals.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal, Optional, Union

from msgspec import Meta, Struct
from typing_extensions import Annotated


class Type1(Struct, omit_defaults=True, kw_only=True, tag_field='type_', tag='a'):
type_: ClassVar[Annotated[Literal['a'], Meta(title='Type ')]] = 'a'


class Type2(Struct, omit_defaults=True, kw_only=True, tag_field='type_', tag='b'):
type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b'


class UnrelatedType(Struct, omit_defaults=True, kw_only=True):
info: Optional[Annotated[str, Meta(title='A way to check for side effects')]] = (
'Unrelated type, not involved in the discriminated union'
)


class Response(Struct, omit_defaults=True, kw_only=True):
inner: Annotated[Union[Type1, Type2], Meta(title='Inner')]
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# generated by datamodel-codegen:
# filename: inheritance.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import Optional

from msgspec import Struct


class Base(Struct, omit_defaults=True, kw_only=True):
id: str
createdAt: Optional[str] = None
version: Optional[float] = 1


class Child(Base, omit_defaults=True, kw_only=True):
title: str
url: Optional[str] = 'https://example.com'
11 changes: 11 additions & 0 deletions tests/data/jsonschema/discriminator_literals.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
},
"title": "Type2",
"type": "object"
},
"UnrelatedType": {
"properties": {
"info": {
"default": "Unrelated type, not involved in the discriminated union",
"title": "A way to check for side effects",
"type": "string"
}
},
"title": "UnrelatedType",
"type": "object"
}
},
"properties": {
Expand Down
7 changes: 7 additions & 0 deletions tests/data/jsonschema/extra_data_msgspec.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"#all#": {
"base_class_kwargs": {
"omit_defaults": true
}
}
}
7 changes: 7 additions & 0 deletions tests/data/openapi/extra_data_msgspec.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"#all#": {
"base_class_kwargs": {
"omit_defaults": true
}
}
}
106 changes: 106 additions & 0 deletions tests/main/jsonschema/test_main_jsonschema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import shutil
from argparse import Namespace
from collections import defaultdict
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import call
Expand All @@ -14,6 +15,7 @@
from datamodel_code_generator import (
DataModelType,
InputFileType,
PythonVersion,
chdir,
generate,
)
Expand Down Expand Up @@ -3167,6 +3169,10 @@ def test_main_typed_dict_not_required_nullable():
],
)
@freeze_time('2019-07-26')
@pytest.mark.skipif(
int(black.__version__.split('.')[0]) < 24,
reason="Installed black doesn't support the new style",
)
def test_main_jsonschema_discriminator_literals(output_model, expected_output):
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
Expand Down Expand Up @@ -3597,3 +3603,103 @@ def test_main_jsonschema_duration(output_model, expected_output):
output_file.read_text()
== (EXPECTED_JSON_SCHEMA_PATH / expected_output).read_text()
)


@freeze_time('2019-07-26')
@pytest.mark.skipif(
int(black.__version__.split('.')[0]) < 24,
reason="Installed black doesn't support the new style",
)
def test_main_jsonschema_keyword_only_msgspec() -> None:
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(JSON_SCHEMA_DATA_PATH / 'discriminator_literals.json'),
'--output',
str(output_file),
'--input-file-type',
'jsonschema',
'--output-model-type',
'msgspec.Struct',
'--keyword-only',
'--target-python-version',
'3.8',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (
EXPECTED_JSON_SCHEMA_PATH
/ 'discriminator_literals_msgspec_keyword_only.py'
).read_text()
)


@freeze_time('2019-07-26')
@pytest.mark.skipif(
int(black.__version__.split('.')[0]) < 24,
reason="Installed black doesn't support the new style",
)
def test_main_jsonschema_keyword_only_msgspec_with_extra_data() -> None:
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(JSON_SCHEMA_DATA_PATH / 'discriminator_literals.json'),
'--output',
str(output_file),
'--input-file-type',
'jsonschema',
'--output-model-type',
'msgspec.Struct',
'--keyword-only',
'--target-python-version',
'3.8',
'--extra-template-data',
str(JSON_SCHEMA_DATA_PATH / 'extra_data_msgspec.json'),
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (
EXPECTED_JSON_SCHEMA_PATH
/ 'discriminator_literals_msgspec_keyword_only_omit_defaults.py'
).read_text()
)


@freeze_time('2019-07-26')
@pytest.mark.skipif(
int(black.__version__.split('.')[0]) < 24,
reason="Installed black doesn't support the new style",
)
def test_main_jsonschema_openapi_keyword_only_msgspec_with_extra_data() -> None:
extra_data = json.loads(
(JSON_SCHEMA_DATA_PATH / 'extra_data_msgspec.json').read_text()
)
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
generate(
input_=JSON_SCHEMA_DATA_PATH / 'discriminator_literals.json',
output=output_file,
input_file_type=InputFileType.JsonSchema,
output_model_type=DataModelType.MsgspecStruct,
keyword_only=True,
target_python_version=PythonVersion.PY_38,
extra_template_data=defaultdict(dict, extra_data),
# Following values are implied by `msgspec.Struct` in the CLI
use_annotated=True,
field_constraints=True,
)
assert (
output_file.read_text()
== (
EXPECTED_JSON_SCHEMA_PATH
/ 'discriminator_literals_msgspec_keyword_only_omit_defaults.py'
).read_text()
)
Loading

0 comments on commit 8eccc15

Please sign in to comment.