Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Custom choices #6169

Merged
merged 10 commits into from
Mar 7, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from dataclasses import dataclass
from typing import Dict, Optional

from typing_extensions import LiteralString

# `slots` is available on Python >= 3.10
if sys.version_info >= (3, 10):
slots_true = {"slots": True}
Expand All @@ -24,3 +26,10 @@ class OpenBBCustomParameter(BaseMetadata):
"""Custom parameter for OpenBB."""

description: Optional[str] = None


@dataclass(frozen=True, **slots_true)
class OpenBBCustomChoices(BaseMetadata):
"""Custom choices for OpenBB."""

choices: Optional[LiteralString] = None
48 changes: 32 additions & 16 deletions openbb_platform/core/openbb_core/app/static/package_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
from typing_extensions import Annotated, _AnnotatedAlias

from openbb_core.app.extension_loader import ExtensionLoader, OpenBBGroups
from openbb_core.app.model.custom_parameter import OpenBBCustomParameter
from openbb_core.app.model.custom_parameter import (
OpenBBCustomChoices,
OpenBBCustomParameter,
)
from openbb_core.app.provider_interface import ProviderInterface
from openbb_core.app.router import CommandMap, RouterLoader
from openbb_core.app.static.utils.console import Console
Expand Down Expand Up @@ -330,9 +333,7 @@ def build(cls, path: str) -> str:
hint_type_list = cls.get_path_hint_type_list(path=path)
code = "from openbb_core.app.static.container import Container"
code += "\nfrom openbb_core.app.model.obbject import OBBject"
code += (
"\nfrom openbb_core.app.model.custom_parameter import OpenBBCustomParameter"
)
code += "\nfrom openbb_core.app.model.custom_parameter import OpenBBCustomParameter, OpenBBCustomChoices"

# These imports were not detected before build, so we add them manually and
# ruff --fix the resulting code to remove unused imports.
Expand Down Expand Up @@ -500,7 +501,10 @@ def get_extra(field: FieldInfo) -> dict:
"""Get json schema extra."""
field_default = getattr(field, "default", None)
if field_default:
return getattr(field_default, "json_schema_extra", {})
# Getting json_schema_extra without changing the original dict
json_schema_extra = getattr(field_default, "json_schema_extra", {}).copy()
json_schema_extra.pop("choices", None)
return json_schema_extra
hjoaquim marked this conversation as resolved.
Show resolved Hide resolved
return {}

@staticmethod
Expand Down Expand Up @@ -604,10 +608,10 @@ def format_params(
return MethodDefinition.reorder_params(params=formatted)

@staticmethod
def add_field_descriptions(
def add_field_custom_annotations(
od: OrderedDict[str, Parameter], model_name: Optional[str] = None
):
"""Add the field description to the param signature."""
"""Add the field custom description and choices to the param signature as annotations."""
if model_name:
available_fields: Dict[str, Field] = (
ProviderInterface().params[model_name]["standard"].__dataclass_fields__
Expand All @@ -617,16 +621,28 @@ def add_field_descriptions(
if param not in available_fields:
continue

field = available_fields[param]
field_default = available_fields[param].default

new_value = value.replace(
annotation=Annotated[
value.annotation,
OpenBBCustomParameter(
description=getattr(field.default, "description", "")
),
],
choices = getattr(field_default, "json_schema_extra", {}).get(
"choices", []
)
description = getattr(field_default, "description", "")

if choices:
new_value = value.replace(
annotation=Annotated[
value.annotation,
OpenBBCustomParameter(description=description),
OpenBBCustomChoices(choices=choices),
],
)
else:
new_value = value.replace(
annotation=Annotated[
value.annotation,
OpenBBCustomParameter(description=description),
],
)

od[param] = new_value

Expand Down Expand Up @@ -667,7 +683,7 @@ def build_command_method_signature(
model_name: Optional[str] = None,
) -> str:
"""Build the command method signature."""
MethodDefinition.add_field_descriptions(
MethodDefinition.add_field_custom_annotations(
od=formatted_params, model_name=model_name
) # this modified `od` in place
func_params = MethodDefinition.build_func_params(formatted_params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@
class ConsumerPriceIndexQueryParams(QueryParams):
"""CPI Query."""

country: str = Field(description=QUERY_DESCRIPTIONS.get("country"))
country: str = Field(
description=QUERY_DESCRIPTIONS.get("country"),
choices=CPI_COUNTRIES, # type: ignore
)
units: CPI_UNITS = Field(
default="growth_same",
description=QUERY_DESCRIPTIONS.get("units", "")
Expand Down
58 changes: 57 additions & 1 deletion openbb_platform/openbb/package/economy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import datetime
from typing import List, Literal, Optional, Union

from openbb_core.app.model.custom_parameter import OpenBBCustomParameter
from openbb_core.app.model.custom_parameter import (
OpenBBCustomChoices,
OpenBBCustomParameter,
)
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.static.container import Container
from openbb_core.app.static.utils.decorators import exception_handler, validate
Expand Down Expand Up @@ -240,6 +243,59 @@ def cpi(
OpenBBCustomParameter(
description="The country to get data. Multiple items allowed for provider(s): fred."
),
OpenBBCustomChoices(
choices=[
"australia",
"austria",
"belgium",
"brazil",
"bulgaria",
"canada",
"chile",
"china",
"croatia",
"cyprus",
"czech_republic",
"denmark",
"estonia",
"euro_area",
"finland",
"france",
"germany",
"greece",
"hungary",
"iceland",
"india",
"indonesia",
"ireland",
"israel",
"italy",
"japan",
"korea",
"latvia",
"lithuania",
"luxembourg",
"malta",
"mexico",
"netherlands",
"new_zealand",
"norway",
"poland",
"portugal",
"romania",
"russian_federation",
"slovak_republic",
"slovakia",
"slovenia",
"south_africa",
"spain",
"sweden",
"switzerland",
"turkey",
"united_kingdom",
"united_states",
]
),
],
units: Annotated[
Literal["growth_previous", "growth_same", "index_2015"],
Expand Down
Loading