Skip to content

Commit

Permalink
Improve discriminator (#1666)
Browse files Browse the repository at this point in the history
* Improve discriminator

* Improve discriminator

* Fix unittest

* Add unittest

* Fix coverage
  • Loading branch information
koxudaxi authored Nov 8, 2023
1 parent bb9a204 commit 979444c
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 11 deletions.
79 changes: 78 additions & 1 deletion datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@
from pydantic import BaseModel

from datamodel_code_generator.format import CodeFormatter, PythonVersion
from datamodel_code_generator.imports import IMPORT_ANNOTATIONS, Import, Imports
from datamodel_code_generator.imports import (
IMPORT_ANNOTATIONS,
Import,
Imports,
)
from datamodel_code_generator.model import pydantic as pydantic_model
from datamodel_code_generator.model import pydantic_v2 as pydantic_model_v2
from datamodel_code_generator.model.base import (
ALL_MODEL,
UNDEFINED,
Expand Down Expand Up @@ -722,6 +727,77 @@ def __extract_inherited_enum(cls, models: List[DataModel]) -> None:
)
models.remove(model)

def __apply_discriminator_type(
self,
models: List[DataModel],
imports: Imports,
) -> None:
for model in models:
for field in model.fields:
discriminator = field.extras.get('discriminator')
if not discriminator or not isinstance(discriminator, dict):
continue
property_name = discriminator.get('propertyName')
if not property_name: # pragma: no cover
continue
mapping = discriminator.get('mapping', {})
for data_type in field.data_type.data_types:
if not data_type.reference: # pragma: no cover
continue
discriminator_model = data_type.reference.source
if not isinstance( # pragma: no cover
discriminator_model,
(pydantic_model.BaseModel, pydantic_model_v2.BaseModel),
):
continue # pragma: no cover
type_name = None
if mapping:
for name, path in mapping.items():
if (
discriminator_model.path.split('#/')[-1]
!= path.split('#/')[-1]
):
# TODO: support external reference
continue
type_name = name
else:
type_name = discriminator_model.path.split('/')[-1]
if not type_name: # pragma: no cover
raise RuntimeError(
f'Discriminator type is not found. {data_type.reference.path}'
)
has_one_literal = False
for discriminator_field in discriminator_model.fields:
if (
discriminator_field.original_name
or discriminator_field.name
) != property_name:
continue
literals = discriminator_field.data_type.literals
if len(literals) == 1 and literals[0] == type_name:
has_one_literal = True
continue
for (
field_data_type
) in discriminator_field.data_type.all_data_types:
if field_data_type.reference: # pragma: no cover
field_data_type.remove_reference()
discriminator_field.data_type = self.data_type(
literals=[type_name]
)
discriminator_field.data_type.parent = discriminator_field
discriminator_field.required = True
imports.append(discriminator_field.imports)
has_one_literal = True
if not has_one_literal:
discriminator_model.fields.append(
self.data_model_field_type(
name=property_name,
data_type=self.data_type(literals=[type_name]),
required=True,
)
)

@classmethod
def _create_set_from_list(cls, data_type: DataType) -> Optional[DataType]:
if data_type.is_list:
Expand Down Expand Up @@ -1155,6 +1231,7 @@ class Processed(NamedTuple):
self.__override_required_field(models)
self.__sort_models(models, imports)
self.__set_one_literal_on_default(models)
self.__apply_discriminator_type(models, imports)

processed_models.append(
Processed(module, models, init, imports, scoped_model_resolver)
Expand Down
8 changes: 5 additions & 3 deletions tests/data/expected/main/main_openapi_discriminator/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,27 @@
from typing import Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Literal


class Type(Enum):
my_first_object = 'my_first_object'
my_second_object = 'my_second_object'
my_third_object = 'my_third_object'


class ObjectBase(BaseModel):
name: Optional[str] = Field(None, description='Name of the object')
type: Optional[Type] = Field(None, description='Object type')
type: Literal['type1'] = Field(..., description='Object type')


class CreateObjectRequest(ObjectBase):
name: str = Field(..., description='Name of the object')
type: Type = Field(..., description='Object type')
type: Literal['type2'] = Field(..., description='Object type')


class UpdateObjectRequest(ObjectBase):
pass
type: Literal['type3']


class Demo(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# generated by datamodel-codegen:
# filename: discriminator_without_mapping.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from enum import Enum
from typing import Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Literal


class Type(Enum):
my_first_object = 'my_first_object'
my_second_object = 'my_second_object'
my_third_object = 'my_third_object'


class ObjectBase(BaseModel):
name: Optional[str] = Field(None, description='Name of the object')
type: Literal['ObjectBase'] = Field(..., description='Object type')


class CreateObjectRequest(ObjectBase):
name: str = Field(..., description='Name of the object')
type: Literal['CreateObjectRequest'] = Field(..., description='Object type')


class UpdateObjectRequest(ObjectBase):
type: Literal['UpdateObjectRequest']


class Demo(BaseModel):
__root__: Union[ObjectBase, CreateObjectRequest, UpdateObjectRequest] = Field(
..., discriminator='type'
)
7 changes: 5 additions & 2 deletions tests/data/openapi/discriminator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ components:
enum:
- my_first_object
- my_second_object
- my_third_object
CreateObjectRequest:
description: Request schema for object creation
type: object
Expand All @@ -35,5 +36,7 @@ components:
discriminator:
propertyName: type
mapping:
type1: "#/components/schemas/Schema1"
type2: "#/components/schemas/Schema2"
type1: "#/components/schemas/ObjectBase"
type2: "#/components/schemas/CreateObjectRequest"
type3: "#/components/schemas/UpdateObjectRequest"

38 changes: 38 additions & 0 deletions tests/data/openapi/discriminator_without_mapping.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
openapi: "3.0.0"
components:
schemas:
ObjectBase:
description: Object schema
type: object
properties:
name:
description: Name of the object
type: string
type:
description: Object type
type: string
enum:
- my_first_object
- my_second_object
- my_third_object
CreateObjectRequest:
description: Request schema for object creation
type: object
allOf:
- $ref: '#/components/schemas/ObjectBase'
required:
- name
- type
UpdateObjectRequest:
description: Request schema for object updates
type: object
allOf:
- $ref: '#/components/schemas/ObjectBase'
Demo:
oneOf:
- $ref: "#/components/schemas/ObjectBase"
- $ref: "#/components/schemas/CreateObjectRequest"
- $ref: "#/components/schemas/UpdateObjectRequest"
discriminator:
propertyName: type

21 changes: 16 additions & 5 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4875,14 +4875,27 @@ def test_main_disable_warnings(capsys: CaptureFixture):
assert captured.err == ''


@pytest.mark.parametrize(
'input,output',
[
(
'discriminator.yaml',
'main_openapi_discriminator',
),
(
'discriminator_without_mapping.yaml',
'main_openapi_discriminator_without_mapping',
),
],
)
@freeze_time('2019-07-26')
def test_main_openapi_discriminator():
def test_main_openapi_discriminator(input, output):
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(OPEN_API_DATA_PATH / 'discriminator.yaml'),
str(OPEN_API_DATA_PATH / input),
'--output',
str(output_file),
'--input-file-type',
Expand All @@ -4892,9 +4905,7 @@ def test_main_openapi_discriminator():
assert return_code == Exit.OK
assert (
output_file.read_text()
== (
EXPECTED_MAIN_PATH / 'main_openapi_discriminator' / 'output.py'
).read_text()
== (EXPECTED_MAIN_PATH / output / 'output.py').read_text()
)


Expand Down

0 comments on commit 979444c

Please sign in to comment.