Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix openapi schema fields to_snake, tag for empty list/dict #6036

Merged
merged 2 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions openbb_platform/core/openbb_core/app/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def model_t_discriminator(v: Any) -> str:
if isinstance(v, dict):
return v.get("provider", "openbb")

if isinstance(v, list):
if isinstance(v, list) and v:
return model_t_discriminator(v[0])

return getattr(v, "provider", "openbb")
Expand Down Expand Up @@ -397,7 +397,7 @@ def inject_return_type(
Also updates __name__ and __doc__ for API schemas.
"""

union_models = [Annotated[None, Tag("openbb")]]
union_models = [Annotated[Union[list, dict], Tag("openbb")]]

for provider, return_type in return_map.items():
union_models.append(Annotated[return_type, Tag(provider)])
Expand All @@ -406,15 +406,17 @@ def inject_return_type(
f"OBBject_{model}",
__base__=OBBject,
results=(
Annotated[
Union[tuple(union_models)], # type: ignore
Field(
...,
description="Serializable results.",
discriminator=Discriminator(model_t_discriminator),
),
Optional[
Annotated[
Union[tuple(union_models)], # type: ignore
Field(
None,
description="Serializable results.",
discriminator=Discriminator(model_t_discriminator),
),
]
],
Field(..., description="Serializable results."),
Field(None, description="Serializable results."),
),
)

Expand Down
29 changes: 21 additions & 8 deletions openbb_platform/core/openbb_core/provider/registry_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,30 @@ def extract_data_model(fetcher: Fetcher, provider_str: str) -> BaseModel:
"""Extract info (fields and docstring) from fetcher query params or data."""
model: BaseModel = RegistryMap._get_model(fetcher, "data")

class DataModel(model):
model_config = ConfigDict(alias_generator=alias_generators.to_snake)
fields = {}
for field_name, field in model.model_fields.items():
field.alias_priority = None
fields[field_name] = (field.annotation, field)

provider: Literal[provider_str, "openbb"] = Field( # type: ignore
default=provider_str,
description="The data provider for the data.",
exclude=True,
)
fields.pop("provider", None)

return create_model(
model.__name__, __base__=DataModel, __module__=model.__module__
model.__name__.replace("Data", ""),
__doc__=model.__doc__,
__config__=ConfigDict(
extra="allow",
alias_generator=alias_generators.to_snake,
populate_by_name=True,
),
provider=(
Literal[provider_str, "openbb"], # type: ignore
Field(
default=provider_str,
description="The data provider for the data.",
exclude=True,
),
),
**fields,
)

@staticmethod
Expand Down
Loading