diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index 55fe5428..dc3504a3 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -27,8 +27,13 @@ from pydantic import BaseModel from datamodel_code_generator.format import CodeFormatter, PythonVersion -from datamodel_code_generator.imports import IMPORT_ANNOTATIONS, Import, Imports +from datamodel_code_generator.imports import ( + IMPORT_ANNOTATIONS, + Import, + Imports, +) from datamodel_code_generator.model import pydantic as pydantic_model +from datamodel_code_generator.model import pydantic_v2 as pydantic_model_v2 from datamodel_code_generator.model.base import ( ALL_MODEL, UNDEFINED, @@ -722,6 +727,77 @@ def __extract_inherited_enum(cls, models: List[DataModel]) -> None: ) models.remove(model) + def __apply_discriminator_type( + self, + models: List[DataModel], + imports: Imports, + ) -> None: + for model in models: + for field in model.fields: + discriminator = field.extras.get('discriminator') + if not discriminator or not isinstance(discriminator, dict): + continue + property_name = discriminator.get('propertyName') + if not property_name: # pragma: no cover + continue + mapping = discriminator.get('mapping', {}) + for data_type in field.data_type.data_types: + if not data_type.reference: # pragma: no cover + continue + discriminator_model = data_type.reference.source + if not isinstance( # pragma: no cover + discriminator_model, + (pydantic_model.BaseModel, pydantic_model_v2.BaseModel), + ): + continue # pragma: no cover + type_name = None + if mapping: + for name, path in mapping.items(): + if ( + discriminator_model.path.split('#/')[-1] + != path.split('#/')[-1] + ): + # TODO: support external reference + continue + type_name = name + else: + type_name = discriminator_model.path.split('/')[-1] + if not type_name: # pragma: no cover + raise RuntimeError( + f'Discriminator type is not found. {data_type.reference.path}' + ) + has_one_literal = False + for discriminator_field in discriminator_model.fields: + if ( + discriminator_field.original_name + or discriminator_field.name + ) != property_name: + continue + literals = discriminator_field.data_type.literals + if len(literals) == 1 and literals[0] == type_name: + has_one_literal = True + continue + for ( + field_data_type + ) in discriminator_field.data_type.all_data_types: + if field_data_type.reference: # pragma: no cover + field_data_type.remove_reference() + discriminator_field.data_type = self.data_type( + literals=[type_name] + ) + discriminator_field.data_type.parent = discriminator_field + discriminator_field.required = True + imports.append(discriminator_field.imports) + has_one_literal = True + if not has_one_literal: + discriminator_model.fields.append( + self.data_model_field_type( + name=property_name, + data_type=self.data_type(literals=[type_name]), + required=True, + ) + ) + @classmethod def _create_set_from_list(cls, data_type: DataType) -> Optional[DataType]: if data_type.is_list: @@ -1155,6 +1231,7 @@ class Processed(NamedTuple): self.__override_required_field(models) self.__sort_models(models, imports) self.__set_one_literal_on_default(models) + self.__apply_discriminator_type(models, imports) processed_models.append( Processed(module, models, init, imports, scoped_model_resolver) diff --git a/tests/data/expected/main/main_openapi_discriminator/output.py b/tests/data/expected/main/main_openapi_discriminator/output.py index 3e582759..9cd6e45f 100644 --- a/tests/data/expected/main/main_openapi_discriminator/output.py +++ b/tests/data/expected/main/main_openapi_discriminator/output.py @@ -8,25 +8,27 @@ from typing import Optional, Union from pydantic import BaseModel, Field +from typing_extensions import Literal class Type(Enum): my_first_object = 'my_first_object' my_second_object = 'my_second_object' + my_third_object = 'my_third_object' class ObjectBase(BaseModel): name: Optional[str] = Field(None, description='Name of the object') - type: Optional[Type] = Field(None, description='Object type') + type: Literal['type1'] = Field(..., description='Object type') class CreateObjectRequest(ObjectBase): name: str = Field(..., description='Name of the object') - type: Type = Field(..., description='Object type') + type: Literal['type2'] = Field(..., description='Object type') class UpdateObjectRequest(ObjectBase): - pass + type: Literal['type3'] class Demo(BaseModel): diff --git a/tests/data/expected/main/main_openapi_discriminator_without_mapping/output.py b/tests/data/expected/main/main_openapi_discriminator_without_mapping/output.py new file mode 100644 index 00000000..2b91db08 --- /dev/null +++ b/tests/data/expected/main/main_openapi_discriminator_without_mapping/output.py @@ -0,0 +1,37 @@ +# generated by datamodel-codegen: +# filename: discriminator_without_mapping.yaml +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseModel, Field +from typing_extensions import Literal + + +class Type(Enum): + my_first_object = 'my_first_object' + my_second_object = 'my_second_object' + my_third_object = 'my_third_object' + + +class ObjectBase(BaseModel): + name: Optional[str] = Field(None, description='Name of the object') + type: Literal['ObjectBase'] = Field(..., description='Object type') + + +class CreateObjectRequest(ObjectBase): + name: str = Field(..., description='Name of the object') + type: Literal['CreateObjectRequest'] = Field(..., description='Object type') + + +class UpdateObjectRequest(ObjectBase): + type: Literal['UpdateObjectRequest'] + + +class Demo(BaseModel): + __root__: Union[ObjectBase, CreateObjectRequest, UpdateObjectRequest] = Field( + ..., discriminator='type' + ) diff --git a/tests/data/openapi/discriminator.yaml b/tests/data/openapi/discriminator.yaml index 9a611ae1..fd7c7e67 100644 --- a/tests/data/openapi/discriminator.yaml +++ b/tests/data/openapi/discriminator.yaml @@ -14,6 +14,7 @@ components: enum: - my_first_object - my_second_object + - my_third_object CreateObjectRequest: description: Request schema for object creation type: object @@ -35,5 +36,7 @@ components: discriminator: propertyName: type mapping: - type1: "#/components/schemas/Schema1" - type2: "#/components/schemas/Schema2" + type1: "#/components/schemas/ObjectBase" + type2: "#/components/schemas/CreateObjectRequest" + type3: "#/components/schemas/UpdateObjectRequest" + diff --git a/tests/data/openapi/discriminator_without_mapping.yaml b/tests/data/openapi/discriminator_without_mapping.yaml new file mode 100644 index 00000000..451beba2 --- /dev/null +++ b/tests/data/openapi/discriminator_without_mapping.yaml @@ -0,0 +1,38 @@ +openapi: "3.0.0" +components: + schemas: + ObjectBase: + description: Object schema + type: object + properties: + name: + description: Name of the object + type: string + type: + description: Object type + type: string + enum: + - my_first_object + - my_second_object + - my_third_object + CreateObjectRequest: + description: Request schema for object creation + type: object + allOf: + - $ref: '#/components/schemas/ObjectBase' + required: + - name + - type + UpdateObjectRequest: + description: Request schema for object updates + type: object + allOf: + - $ref: '#/components/schemas/ObjectBase' + Demo: + oneOf: + - $ref: "#/components/schemas/ObjectBase" + - $ref: "#/components/schemas/CreateObjectRequest" + - $ref: "#/components/schemas/UpdateObjectRequest" + discriminator: + propertyName: type + diff --git a/tests/test_main.py b/tests/test_main.py index 781a0eb8..d3dd6311 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4875,14 +4875,27 @@ def test_main_disable_warnings(capsys: CaptureFixture): assert captured.err == '' +@pytest.mark.parametrize( + 'input,output', + [ + ( + 'discriminator.yaml', + 'main_openapi_discriminator', + ), + ( + 'discriminator_without_mapping.yaml', + 'main_openapi_discriminator_without_mapping', + ), + ], +) @freeze_time('2019-07-26') -def test_main_openapi_discriminator(): +def test_main_openapi_discriminator(input, output): with TemporaryDirectory() as output_dir: output_file: Path = Path(output_dir) / 'output.py' return_code: Exit = main( [ '--input', - str(OPEN_API_DATA_PATH / 'discriminator.yaml'), + str(OPEN_API_DATA_PATH / input), '--output', str(output_file), '--input-file-type', @@ -4892,9 +4905,7 @@ def test_main_openapi_discriminator(): assert return_code == Exit.OK assert ( output_file.read_text() - == ( - EXPECTED_MAIN_PATH / 'main_openapi_discriminator' / 'output.py' - ).read_text() + == (EXPECTED_MAIN_PATH / output / 'output.py').read_text() )