Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] - Fix economic calendar country #6059

Merged
merged 3 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class ConsumerPriceIndexQueryParams(QueryParams):
def validate_country(cls, c: str): # pylint: disable=E0213
"""Validate country."""
result = []
values = c.split(",")
values = c.replace(" ", "_").split(",")
for v in values:
check_item(v.lower(), CPI_COUNTRIES)
result.append(v.lower())
Expand Down
6 changes: 3 additions & 3 deletions openbb_platform/openbb/package/economy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def calendar(
The provider to use for the query, by default None.
If None, the provider specified in defaults is selected or 'fmp' if there is
no default.
country : Optional[Union[str, List[str]]]
Country of the event (provider: tradingeconomics)
country : Optional[str]
Country of the event. (provider: tradingeconomics)
importance : Optional[Literal['Low', 'Medium', 'High']]
Importance of the event. (provider: tradingeconomics)
group : Optional[Literal['interest rate', 'inflation', 'bonds', 'consumer', 'gdp', 'government', 'housing', 'labour', 'markets', 'money', 'prices', 'trade', 'business']]
Expand All @@ -69,7 +69,7 @@ def calendar(
Returns
-------
OBBject
results : List[EconomicCalendar]
results : EconomicCalendar
Serializable results.
provider : Optional[Literal['fmp', 'tradingeconomics']]
Provider name.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Trading Economics Economic Calendar Model."""

from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Set, Union
from typing import Any, Dict, List, Literal, Optional, Union

from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.economic_calendar import (
EconomicCalendarData,
EconomicCalendarQueryParams,
)
from openbb_core.provider.utils.helpers import ClientResponse, amake_request
from openbb_core.provider.utils.helpers import ClientResponse, amake_request, check_item
from openbb_tradingeconomics.utils import url_generator
from openbb_tradingeconomics.utils.countries import country_list
from openbb_tradingeconomics.utils.countries import COUNTRIES
from pandas import to_datetime
from pydantic import Field, field_validator

Expand Down Expand Up @@ -40,21 +40,22 @@ class TEEconomicCalendarQueryParams(EconomicCalendarQueryParams):
"""

# TODO: Probably want to figure out the list we can use.
country: Optional[Union[str, List[str]]] = Field(
default=None, description="Country of the event"
)
country: Optional[str] = Field(default=None, description="Country of the event.")
importance: Optional[IMPORTANCE] = Field(
default=None, description="Importance of the event."
)
group: Optional[GROUPS] = Field(default=None, description="Grouping of events")

@field_validator("country", mode="before", check_fields=False)
@classmethod
def validate_country(cls, v: Union[str, List[str], Set[str]]):
"""Validate the country input."""
if isinstance(v, str):
return v.lower().replace(" ", "_")
return ",".join([country.lower().replace(" ", "_") for country in list(v)])
def validate_country(cls, c: str): # pylint: disable=E0213
"""Validate country."""
result = []
values = c.replace(" ", "_").split(",")
for v in values:
check_item(v.lower(), COUNTRIES)
result.append(v.lower())
return ",".join(result)

@field_validator("importance")
@classmethod
Expand Down Expand Up @@ -111,29 +112,18 @@ async def aextract_data(
query: TEEconomicCalendarQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> List[Dict]:
) -> Union[dict, List[dict]]:
"""Return the raw data from the TE endpoint."""
api_key = credentials.get("tradingeconomics_api_key") if credentials else ""

if query.country is not None:
country = (
query.country.split(",") if "," in query.country else query.country
)
country = [country] if isinstance(country, str) else country

for c in country:
if c.replace("_", " ").lower() not in country_list:
raise ValueError(f"{c} is not a valid country")
query.country = country

url = url_generator.generate_url(query)
if not url:
raise RuntimeError(
"No url generated. Check combination of input parameters."
)
url = f"{url}{api_key}"

async def callback(response: ClientResponse, _: Any) -> List[Dict]:
async def callback(response: ClientResponse, _: Any) -> Union[dict, List[dict]]:
"""Return the response."""
if response.status != 200:
raise RuntimeError(f"Error in TE request -> {await response.text()}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@
],
}

country_list = list(
set([item.lower() for sublist in country_dict.values() for item in sublist])
COUNTRIES = list(
{
item.lower().replace(" ", "_")
for sublist in country_dict.values()
for item in sublist
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def generate_url(in_query):

# Handle the formatting for the api
if "country" in query:
country = quote(",".join(query["country"]).replace("_", " "))
country = quote(query["country"].replace("_", " "))
if "group" in query:
group = quote(query["group"])

Expand Down
Loading