From eaf22aa259a1f7e425b908f10bd61a13ca01151f Mon Sep 17 00:00:00 2001 From: Alyssa Coghlan Date: Tue, 10 Dec 2024 17:00:42 +1000 Subject: [PATCH] Avoid sharing extra template state between models --- datamodel_code_generator/model/base.py | 9 +- datamodel_code_generator/parser/base.py | 9 +- .../main/jsonschema/discriminator_literals.py | 9 +- .../discriminator_literals_msgspec.py | 8 +- ...riminator_literals_msgspec_keyword_only.py | 28 +++++ ...rals_msgspec_keyword_only_omit_defaults.py | 28 +++++ .../msgspec_keyword_only_omit_defaults.py | 20 ++++ .../jsonschema/discriminator_literals.json | 11 ++ tests/data/jsonschema/extra_data_msgspec.json | 7 ++ tests/data/openapi/extra_data_msgspec.json | 7 ++ tests/main/jsonschema/test_main_jsonschema.py | 102 ++++++++++++++++++ tests/main/openapi/test_main_openapi.py | 72 +++++++++++++ tests/test_infer_input_type.py | 9 +- 13 files changed, 309 insertions(+), 10 deletions(-) create mode 100644 tests/data/expected/main/jsonschema/discriminator_literals_msgspec_keyword_only.py create mode 100644 tests/data/expected/main/jsonschema/discriminator_literals_msgspec_keyword_only_omit_defaults.py create mode 100644 tests/data/expected/main/openapi/msgspec_keyword_only_omit_defaults.py create mode 100644 tests/data/jsonschema/extra_data_msgspec.json create mode 100644 tests/data/openapi/extra_data_msgspec.json diff --git a/datamodel_code_generator/model/base.py b/datamodel_code_generator/model/base.py index 8ffc63236..d391f2eb1 100644 --- a/datamodel_code_generator/model/base.py +++ b/datamodel_code_generator/model/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections import defaultdict +from copy import deepcopy from functools import lru_cache from pathlib import Path from typing import ( @@ -316,6 +317,8 @@ def __init__( self.reference.source = self self.extra_template_data = ( + # The supplied defaultdict will either create a new entry, + # or already contain a predefined entry for this type extra_template_data[self.name] if extra_template_data is not None else defaultdict(dict) @@ -327,10 +330,12 @@ def __init__( if base_class.reference: base_class.reference.children.append(self) - if extra_template_data: + if extra_template_data is not None: all_model_extra_template_data = extra_template_data.get(ALL_MODEL) if all_model_extra_template_data: - self.extra_template_data.update(all_model_extra_template_data) + # The deepcopy is needed here to ensure that different models don't + # end up inadvertently sharing state (such as "base_class_kwargs") + self.extra_template_data.update(deepcopy(all_model_extra_template_data)) self.methods: List[str] = methods or [] diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index 070c82d80..c89a25735 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -868,10 +868,8 @@ def check_paths( ) != property_name: continue literals = discriminator_field.data_type.literals - if ( - len(literals) == 1 and literals[0] == type_names[0] - if type_names - else None + if len(literals) == 1 and literals[0] == ( + type_names[0] if type_names else None ): has_one_literal = True if isinstance( @@ -884,7 +882,8 @@ def check_paths( 'tag', discriminator_field.represented_default ) discriminator_field.extras['is_classvar'] = True - continue + # Found the discriminator field, no need to keep looking + break for ( field_data_type ) in discriminator_field.data_type.all_data_types: diff --git a/tests/data/expected/main/jsonschema/discriminator_literals.py b/tests/data/expected/main/jsonschema/discriminator_literals.py index da91f1ac9..8cc9bf6e5 100644 --- a/tests/data/expected/main/jsonschema/discriminator_literals.py +++ b/tests/data/expected/main/jsonschema/discriminator_literals.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Literal, Union +from typing import Literal, Optional, Union from pydantic import BaseModel, Field @@ -17,5 +17,12 @@ class Type2(BaseModel): type_: Literal['b'] = Field('b', title='Type ') +class UnrelatedType(BaseModel): + info: Optional[str] = Field( + 'Unrelated type, not involved in the discriminated union', + title='A way to check for side effects', + ) + + class Response(BaseModel): inner: Union[Type1, Type2] = Field(..., discriminator='type_', title='Inner') diff --git a/tests/data/expected/main/jsonschema/discriminator_literals_msgspec.py b/tests/data/expected/main/jsonschema/discriminator_literals_msgspec.py index 223152199..f5b4cabe6 100644 --- a/tests/data/expected/main/jsonschema/discriminator_literals_msgspec.py +++ b/tests/data/expected/main/jsonschema/discriminator_literals_msgspec.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import ClassVar, Literal, Union +from typing import ClassVar, Literal, Optional, Union from msgspec import Meta, Struct from typing_extensions import Annotated @@ -18,5 +18,11 @@ class Type2(Struct, tag_field='type_', tag='b'): type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b' +class UnrelatedType(Struct): + info: Optional[Annotated[str, Meta(title='A way to check for side effects')]] = ( + 'Unrelated type, not involved in the discriminated union' + ) + + class Response(Struct): inner: Annotated[Union[Type1, Type2], Meta(title='Inner')] diff --git a/tests/data/expected/main/jsonschema/discriminator_literals_msgspec_keyword_only.py b/tests/data/expected/main/jsonschema/discriminator_literals_msgspec_keyword_only.py new file mode 100644 index 000000000..589147336 --- /dev/null +++ b/tests/data/expected/main/jsonschema/discriminator_literals_msgspec_keyword_only.py @@ -0,0 +1,28 @@ +# generated by datamodel-codegen: +# filename: discriminator_literals.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import ClassVar, Literal, Optional, Union + +from msgspec import Meta, Struct +from typing_extensions import Annotated + + +class Type1(Struct, kw_only=True, tag_field='type_', tag='a'): + type_: ClassVar[Annotated[Literal['a'], Meta(title='Type ')]] = 'a' + + +class Type2(Struct, kw_only=True, tag_field='type_', tag='b'): + type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b' + + +class UnrelatedType(Struct, kw_only=True): + info: Optional[Annotated[str, Meta(title='A way to check for side effects')]] = ( + 'Unrelated type, not involved in the discriminated union' + ) + + +class Response(Struct, kw_only=True): + inner: Annotated[Union[Type1, Type2], Meta(title='Inner')] diff --git a/tests/data/expected/main/jsonschema/discriminator_literals_msgspec_keyword_only_omit_defaults.py b/tests/data/expected/main/jsonschema/discriminator_literals_msgspec_keyword_only_omit_defaults.py new file mode 100644 index 000000000..20490b165 --- /dev/null +++ b/tests/data/expected/main/jsonschema/discriminator_literals_msgspec_keyword_only_omit_defaults.py @@ -0,0 +1,28 @@ +# generated by datamodel-codegen: +# filename: discriminator_literals.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import ClassVar, Literal, Optional, Union + +from msgspec import Meta, Struct +from typing_extensions import Annotated + + +class Type1(Struct, omit_defaults=True, kw_only=True, tag_field='type_', tag='a'): + type_: ClassVar[Annotated[Literal['a'], Meta(title='Type ')]] = 'a' + + +class Type2(Struct, omit_defaults=True, kw_only=True, tag_field='type_', tag='b'): + type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b' + + +class UnrelatedType(Struct, omit_defaults=True, kw_only=True): + info: Optional[Annotated[str, Meta(title='A way to check for side effects')]] = ( + 'Unrelated type, not involved in the discriminated union' + ) + + +class Response(Struct, omit_defaults=True, kw_only=True): + inner: Annotated[Union[Type1, Type2], Meta(title='Inner')] diff --git a/tests/data/expected/main/openapi/msgspec_keyword_only_omit_defaults.py b/tests/data/expected/main/openapi/msgspec_keyword_only_omit_defaults.py new file mode 100644 index 000000000..389045fdd --- /dev/null +++ b/tests/data/expected/main/openapi/msgspec_keyword_only_omit_defaults.py @@ -0,0 +1,20 @@ +# generated by datamodel-codegen: +# filename: inheritance.yaml +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Optional + +from msgspec import Struct + + +class Base(Struct, omit_defaults=True, kw_only=True): + id: str + createdAt: Optional[str] = None + version: Optional[float] = 1 + + +class Child(Base, omit_defaults=True, kw_only=True): + title: str + url: Optional[str] = 'https://example.com' diff --git a/tests/data/jsonschema/discriminator_literals.json b/tests/data/jsonschema/discriminator_literals.json index a3580cadd..2cebab8b6 100644 --- a/tests/data/jsonschema/discriminator_literals.json +++ b/tests/data/jsonschema/discriminator_literals.json @@ -21,6 +21,17 @@ }, "title": "Type2", "type": "object" + }, + "UnrelatedType": { + "properties": { + "info": { + "default": "Unrelated type, not involved in the discriminated union", + "title": "A way to check for side effects", + "type": "string" + } + }, + "title": "UnrelatedType", + "type": "object" } }, "properties": { diff --git a/tests/data/jsonschema/extra_data_msgspec.json b/tests/data/jsonschema/extra_data_msgspec.json new file mode 100644 index 000000000..def380a82 --- /dev/null +++ b/tests/data/jsonschema/extra_data_msgspec.json @@ -0,0 +1,7 @@ +{ + "#all#": { + "base_class_kwargs": { + "omit_defaults": true + } + } +} diff --git a/tests/data/openapi/extra_data_msgspec.json b/tests/data/openapi/extra_data_msgspec.json new file mode 100644 index 000000000..def380a82 --- /dev/null +++ b/tests/data/openapi/extra_data_msgspec.json @@ -0,0 +1,7 @@ +{ + "#all#": { + "base_class_kwargs": { + "omit_defaults": true + } + } +} diff --git a/tests/main/jsonschema/test_main_jsonschema.py b/tests/main/jsonschema/test_main_jsonschema.py index 19bf17df2..1efbe15c5 100644 --- a/tests/main/jsonschema/test_main_jsonschema.py +++ b/tests/main/jsonschema/test_main_jsonschema.py @@ -1,6 +1,7 @@ import json import shutil from argparse import Namespace +from collections import defaultdict from pathlib import Path from tempfile import TemporaryDirectory from unittest.mock import call @@ -14,6 +15,7 @@ from datamodel_code_generator import ( DataModelType, InputFileType, + PythonVersion, chdir, generate, ) @@ -3597,3 +3599,103 @@ def test_main_jsonschema_duration(output_model, expected_output): output_file.read_text() == (EXPECTED_JSON_SCHEMA_PATH / expected_output).read_text() ) + + +@freeze_time('2019-07-26') +@pytest.mark.skipif( + black.__version__.split('.')[0] == '19', + reason="Installed black doesn't support the old style", +) +def test_main_jsonschema_keyword_only_msgspec() -> None: + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(JSON_SCHEMA_DATA_PATH / 'discriminator_literals.json'), + '--output', + str(output_file), + '--input-file-type', + 'jsonschema', + '--output-model-type', + 'msgspec.Struct', + '--keyword-only', + '--target-python-version', + '3.8', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_JSON_SCHEMA_PATH + / 'discriminator_literals_msgspec_keyword_only.py' + ).read_text() + ) + + +@freeze_time('2019-07-26') +@pytest.mark.skipif( + black.__version__.split('.')[0] == '19', + reason="Installed black doesn't support the old style", +) +def test_main_jsonschema_keyword_only_msgspec_with_extra_data() -> None: + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(JSON_SCHEMA_DATA_PATH / 'discriminator_literals.json'), + '--output', + str(output_file), + '--input-file-type', + 'jsonschema', + '--output-model-type', + 'msgspec.Struct', + '--keyword-only', + '--target-python-version', + '3.8', + '--extra-template-data', + str(JSON_SCHEMA_DATA_PATH / 'extra_data_msgspec.json'), + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_JSON_SCHEMA_PATH + / 'discriminator_literals_msgspec_keyword_only_omit_defaults.py' + ).read_text() + ) + + +@freeze_time('2019-07-26') +@pytest.mark.skipif( + black.__version__.split('.')[0] == '19', + reason="Installed black doesn't support the old style", +) +def test_main_jsonschema_openapi_keyword_only_msgspec_with_extra_data() -> None: + extra_data = json.loads( + (JSON_SCHEMA_DATA_PATH / 'extra_data_msgspec.json').read_text() + ) + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + generate( + input_=JSON_SCHEMA_DATA_PATH / 'discriminator_literals.json', + output=output_file, + input_file_type=InputFileType.JsonSchema, + output_model_type=DataModelType.MsgspecStruct, + keyword_only=True, + target_python_version=PythonVersion.PY_38, + extra_template_data=defaultdict(dict, extra_data), + # Following values are implied by `msgspec.Struct` in the CLI + use_annotated=True, + field_constraints=True, + ) + assert ( + output_file.read_text() + == ( + EXPECTED_JSON_SCHEMA_PATH + / 'discriminator_literals_msgspec_keyword_only_omit_defaults.py' + ).read_text() + ) diff --git a/tests/main/openapi/test_main_openapi.py b/tests/main/openapi/test_main_openapi.py index ce752e34b..de0df79dc 100644 --- a/tests/main/openapi/test_main_openapi.py +++ b/tests/main/openapi/test_main_openapi.py @@ -1,6 +1,8 @@ +import json import platform import shutil from argparse import Namespace +from collections import defaultdict from pathlib import Path from tempfile import TemporaryDirectory from typing import List @@ -19,7 +21,10 @@ from _pytest.tmpdir import TempdirFactory from datamodel_code_generator import ( + DataModelType, InputFileType, + OpenAPIScope, + PythonVersion, chdir, generate, inferred_message, @@ -3023,3 +3028,70 @@ def test_main_openapi_keyword_only_msgspec(): output_file.read_text() == (EXPECTED_OPENAPI_PATH / 'msgspec_keyword_only.py').read_text() ) + + +@freeze_time('2019-07-26') +@pytest.mark.skipif( + black.__version__.split('.')[0] == '19', + reason="Installed black doesn't support the old style", +) +def test_main_openapi_keyword_only_msgspec_with_extra_data() -> None: + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(OPEN_API_DATA_PATH / 'inheritance.yaml'), + '--output', + str(output_file), + '--input-file-type', + 'openapi', + '--output-model-type', + 'msgspec.Struct', + '--keyword-only', + '--target-python-version', + '3.8', + '--extra-template-data', + str(OPEN_API_DATA_PATH / 'extra_data_msgspec.json'), + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_OPENAPI_PATH / 'msgspec_keyword_only_omit_defaults.py' + ).read_text() + ) + + +@freeze_time('2019-07-26') +@pytest.mark.skipif( + black.__version__.split('.')[0] == '19', + reason="Installed black doesn't support the old style", +) +def test_main_generate_openapi_keyword_only_msgspec_with_extra_data() -> None: + extra_data = json.loads( + (OPEN_API_DATA_PATH / 'extra_data_msgspec.json').read_text() + ) + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + generate( + input_=OPEN_API_DATA_PATH / 'inheritance.yaml', + output=output_file, + input_file_type=InputFileType.OpenAPI, + output_model_type=DataModelType.MsgspecStruct, + keyword_only=True, + target_python_version=PythonVersion.PY_38, + extra_template_data=defaultdict(dict, extra_data), + # Following values are defaults in the CLI, but not in the API + openapi_scopes=[OpenAPIScope.Schemas], + # Following values are implied by `msgspec.Struct` in the CLI + use_annotated=True, + field_constraints=True, + ) + assert ( + output_file.read_text() + == ( + EXPECTED_OPENAPI_PATH / 'msgspec_keyword_only_omit_defaults.py' + ).read_text() + ) diff --git a/tests/test_infer_input_type.py b/tests/test_infer_input_type.py index 05a082aab..573852ab1 100644 --- a/tests/test_infer_input_type.py +++ b/tests/test_infer_input_type.py @@ -20,7 +20,13 @@ def assert_infer_input_type(file: Path, raw_data_type: InputFileType) -> None: continue assert_infer_input_type(file, InputFileType.Json) for file in (DATA_PATH / 'jsonschema').rglob('*'): - if str(file).endswith(('external_child.json', 'external_child.yaml')): + if str(file).endswith( + ( + 'external_child.json', + 'external_child.yaml', + 'extra_data_msgspec.json', + ) + ): continue assert_infer_input_type(file, InputFileType.JsonSchema) for file in (DATA_PATH / 'openapi').rglob('*'): @@ -32,6 +38,7 @@ def assert_infer_input_type(file: Path, raw_data_type: InputFileType) -> None: ( 'aliases.json', 'extra_data.json', + 'extra_data_msgspec.json', 'invalid.yaml', 'list.json', 'empty_data.json',