From 0a704c57761332933e00cd7526548bc15b9f6d33 Mon Sep 17 00:00:00 2001 From: tehcoderer Date: Mon, 5 Feb 2024 16:41:13 -0500 Subject: [PATCH] fix openapi schema fields `to_snake` --- .../core/openbb_core/app/router.py | 22 +++++++------- .../core/openbb_core/provider/registry_map.py | 29 ++++++++++++++----- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index bb626caba4de..6e2e23de9b72 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -46,7 +46,7 @@ def model_t_discriminator(v: Any) -> str: if isinstance(v, dict): return v.get("provider", "openbb") - if isinstance(v, list): + if isinstance(v, list) and v: return model_t_discriminator(v[0]) return getattr(v, "provider", "openbb") @@ -397,7 +397,7 @@ def inject_return_type( Also updates __name__ and __doc__ for API schemas. """ - union_models = [Annotated[None, Tag("openbb")]] + union_models = [Annotated[Union[list, dict], Tag("openbb")]] for provider, return_type in return_map.items(): union_models.append(Annotated[return_type, Tag(provider)]) @@ -406,15 +406,17 @@ def inject_return_type( f"OBBject_{model}", __base__=OBBject, results=( - Annotated[ - Union[tuple(union_models)], # type: ignore - Field( - ..., - description="Serializable results.", - discriminator=Discriminator(model_t_discriminator), - ), + Optional[ + Annotated[ + Union[tuple(union_models)], # type: ignore + Field( + None, + description="Serializable results.", + discriminator=Discriminator(model_t_discriminator), + ), + ] ], - Field(..., description="Serializable results."), + Field(None, description="Serializable results."), ), ) diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index ff9b710a3b81..5a2cd8e7c51b 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -113,17 +113,30 @@ def extract_data_model(fetcher: Fetcher, provider_str: str) -> BaseModel: """Extract info (fields and docstring) from fetcher query params or data.""" model: BaseModel = RegistryMap._get_model(fetcher, "data") - class DataModel(model): - model_config = ConfigDict(alias_generator=alias_generators.to_snake) + fields = {} + for field_name, field in model.model_fields.items(): + field.alias_priority = None + fields[field_name] = (field.annotation, field) - provider: Literal[provider_str, "openbb"] = Field( # type: ignore - default=provider_str, - description="The data provider for the data.", - exclude=True, - ) + fields.pop("provider", None) return create_model( - model.__name__, __base__=DataModel, __module__=model.__module__ + model.__name__.replace("Data", ""), + __doc__=model.__doc__, + __config__=ConfigDict( + extra="allow", + alias_generator=alias_generators.to_snake, + populate_by_name=True, + ), + provider=( + Literal[provider_str, "openbb"], # type: ignore + Field( + default=provider_str, + description="The data provider for the data.", + exclude=True, + ), + ), + **fields, ) @staticmethod