Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the OECD requests #6152

Merged
merged 18 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,24 @@ def extract_data(
query: OECDCLIQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> Dict:
) -> List[Dict]:
"""Return the raw data from the OECD endpoint."""
url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..LI...."
country = "" if query.country == "all" else COUNTRY_TO_CODE_CLI[query.country]

# Note this is only available monthly from OECD
url = f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/{country}.M.LI...."

query_dict = {
k: v
for k, v in query.__dict__.items()
if k not in ["start_date", "end_date"]
}
data = helpers.get_possibly_cached_data(
url, function="economy_composite_leading_indicator"
url, function="economy_composite_leading_indicator", query_dict=query_dict
)

if query.country != "all":
data = data.query(f"REF_AREA == '{COUNTRY_TO_CODE_CLI[query.country]}'")
data = data.query(f"REF_AREA == '{country}'")

# Filter down
data = data.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]].rename(
Expand All @@ -109,7 +118,7 @@ def extract_data(
@staticmethod
def transform_data(
query: OECDCLIQueryParams,
data: Dict,
data: List[Dict],
**kwargs: Any,
) -> List[OECDCLIData]:
"""Transform the data from the OECD endpoint."""
Expand Down
13 changes: 11 additions & 2 deletions openbb_platform/providers/oecd/openbb_oecd/models/gdp_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
gdp_countries = tuple(constants.COUNTRY_TO_CODE_GDP_FORECAST.keys())
GDPCountriesLiteral = Literal[gdp_countries] # type: ignore

# pylint: disable=unused-argument


class OECDGdpForecastQueryParams(GdpForecastQueryParams):
"""OECD GDP Forecast Query."""
Expand Down Expand Up @@ -67,7 +69,7 @@ def extract_data(
query: OECDGdpForecastQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> Dict:
) -> List[Dict]:
"""Return the raw data from the OECD endpoint."""
units = query.period[0].upper()
_type = "REAL" if query.type == "real" else "NOM"
Expand All @@ -94,11 +96,18 @@ def extract_data(
)
data_df = data_df[data_df["country"] == query.country]
data_df = data_df[["country", "date", "value"]]
data_df["date"] = data_df["date"].apply(helpers.oecd_date_to_python_date)
data_df = data_df[
(data_df["date"] <= query.end_date) & (data_df["date"] >= query.start_date)
]
data_df["date"] = data_df["date"].apply(
lambda x: x.year
) # Validator won't accept datetime.date?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is because the validator expects a string representation. However, since you have already conformed the date field in the extract step, the validator is now redundant.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that too, but it accepts `str | date`, so it was unclear what was actually going wrong.

return data_df.to_dict(orient="records")

@staticmethod
def transform_data(
query: OECDGdpForecastQueryParams, data: Dict, **kwargs: Any
query: OECDGdpForecastQueryParams, data: List[Dict], **kwargs: Any
) -> List[OECDGdpForecastData]:
"""Transform the data from the OECD endpoint."""
return [OECDGdpForecastData.model_validate(d) for d in data]
18 changes: 14 additions & 4 deletions openbb_platform/providers/oecd/openbb_oecd/models/gdp_nominal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from openbb_oecd.utils import constants, helpers
from pydantic import Field, field_validator

gdp_countries = tuple(constants.COUNTRY_TO_CODE_GDP.keys())
gdp_countries = tuple(constants.COUNTRY_TO_CODE_GDP.keys()) + ("all",)
GDPCountriesLiteral = Literal[gdp_countries] # type: ignore

# pylint: disable=unused-argument


class OECDGdpNominalQueryParams(GdpNominalQueryParams):
"""OECD Nominal GDP Query."""
Expand Down Expand Up @@ -56,7 +58,7 @@ def extract_data(
query: OECDGdpNominalQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> Dict:
) -> List[Dict]:
"""Return the raw data from the OECD endpoint."""
unit = "MLN_USD" if query.units == "usd" else "USD_CAP"
url = (
Expand All @@ -76,13 +78,21 @@ def extract_data(
}
)
data_df["country"] = data_df["country"].map(constants.CODE_TO_COUNTRY_GDP)
data_df = data_df[data_df["country"] == query.country]
if query.country != "all":
data_df = data_df[data_df["country"] == query.country]
data_df = data_df[["country", "date", "value"]]
data_df["date"] = data_df["date"].apply(helpers.oecd_date_to_python_date)
data_df = data_df[
(data_df["date"] <= query.end_date) & (data_df["date"] >= query.start_date)
]
data_df["date"] = data_df["date"].apply(
lambda x: x.year
) # Validator won't accept datetime.date?
return data_df.to_dict(orient="records")

@staticmethod
def transform_data(
query: OECDGdpNominalQueryParams, data: Dict, **kwargs: Any
query: OECDGdpNominalQueryParams, data: List[Dict], **kwargs: Any
) -> List[OECDGdpNominalData]:
"""Transform the data from the OECD endpoint."""
return [OECDGdpNominalData.model_validate(d) for d in data]
11 changes: 9 additions & 2 deletions openbb_platform/providers/oecd/openbb_oecd/models/gdp_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from openbb_oecd.utils import constants, helpers
from pydantic import Field, field_validator

rgdp_countries = tuple(constants.COUNTRY_TO_CODE_RGDP.keys())
rgdp_countries = tuple(constants.COUNTRY_TO_CODE_RGDP.keys()) + ("all",)
RGDPCountriesLiteral = Literal[rgdp_countries] # type: ignore


# pylint: disable=unused-argument
class OECDGdpRealQueryParams(GdpRealQueryParams):
"""OECD Real GDP Query."""

Expand Down Expand Up @@ -83,8 +84,14 @@ def extract_data(
}
)
data_df["country"] = data_df["country"].map(constants.CODE_TO_COUNTRY_RGDP)
data_df = data_df[data_df["country"] == query.country]
if query.country != "all":
data_df = data_df[data_df["country"] == query.country]
data_df = data_df[["country", "date", "value"]]

data_df["date"] = data_df["date"].apply(helpers.oecd_date_to_python_date)
data_df = data_df[
(data_df["date"] <= query.end_date) & (data_df["date"] >= query.start_date)
]
return data_df.to_dict(orient="records")

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,18 @@ def extract_data(
query: OECDLTIRQueryParams, # pylint: disable=W0613
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> Dict:
) -> List[Dict]:
"""Return the raw data from the OECD endpoint."""
frequency = query.frequency[0].upper()
country = "" if query.country == "all" else COUNTRY_TO_CODE_IR[query.country]
url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IRLT...."
query_dict = {
k: v
for k, v in query.__dict__.items()
if k not in ["start_date", "end_date"]
}
url = f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/{country}.{frequency}.IRLT...."
data = helpers.get_possibly_cached_data(
url, function="economy_long_term_interest_rate"
url, function="economy_long_term_interest_rate", query_dict=query_dict
)
url_query = f"FREQ=='{frequency}'"
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
Expand All @@ -107,17 +112,15 @@ def extract_data(
)
data["country"] = data["country"].map(CODE_TO_COUNTRY_IR)
data = data.fillna("N/A").replace("N/A", None)
data = data.to_dict(orient="records")

start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))

return data
data["date"] = data["date"].apply(helpers.oecd_date_to_python_date)
data = data[
(data["date"] <= query.end_date) & (data["date"] >= query.start_date)
]
return data.to_dict(orient="records")

@staticmethod
def transform_data(
query: OECDLTIRQueryParams, data: Dict, **kwargs: Any
query: OECDLTIRQueryParams, data: List[Dict], **kwargs: Any
) -> List[OECDLTIRData]:
"""Transform the data from the OECD endpoint."""
return [OECDLTIRData.model_validate(d) for d in data]
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,19 @@ def extract_data(
query: OECDSTIRQueryParams, # pylint: disable=W0613
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> Dict:
) -> List[Dict]:
"""Return the raw data from the OECD endpoint."""
frequency = query.frequency[0].upper()
country = "" if query.country == "all" else COUNTRY_TO_CODE_IR[query.country]
url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IR3TIB...."
query_dict = {
k: v
for k, v in query.__dict__.items()
if k not in ["start_date", "end_date"]
}

url = f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/{country}.{frequency}.IR3TIB...."
data = helpers.get_possibly_cached_data(
url, function="economy_short_term_interest_rate"
url, function="economy_short_term_interest_rate", query_dict=query_dict
)
url_query = f"FREQ=='{frequency}'"
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
Expand All @@ -107,17 +113,16 @@ def extract_data(
)
data["country"] = data["country"].map(CODE_TO_COUNTRY_IR)
data = data.fillna("N/A").replace("N/A", None)
data = data.to_dict(orient="records")

start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))
data["date"] = data["date"].apply(helpers.oecd_date_to_python_date)
data = data[
(data["date"] <= query.end_date) & (data["date"] >= query.start_date)
]

return data
return data.to_dict(orient="records")

@staticmethod
def transform_data(
query: OECDSTIRQueryParams, data: Dict, **kwargs: Any
query: OECDSTIRQueryParams, data: List[Dict], **kwargs: Any
) -> List[OECDSTIRData]:
"""Transform the data from the OECD endpoint."""
return [OECDSTIRData.model_validate(d) for d in data]
30 changes: 21 additions & 9 deletions openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def extract_data(
query: OECDUnemploymentQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> Dict:
) -> List[Dict]:
"""Return the raw data from the OECD endpoint."""
sex = {"total": "_T", "male": "M", "female": "F"}[query.sex]
frequency = query.frequency[0].upper()
Expand All @@ -121,8 +121,20 @@ def extract_data(
if query.country == "all"
else COUNTRY_TO_CODE_UNEMPLOYMENT[query.country]
)
url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_LFS@DF_IALFS_INDIC,1.0/.UNE_LF........"
data = helpers.get_possibly_cached_data(url, function="economy_unemployment")
# For caching, include this in the key
query_dict = {
k: v
for k, v in query.__dict__.items()
if k not in ["start_date", "end_date"]
}

url = (
f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_LFS@DF_IALFS_INDIC,"
f"1.0/{country}.UNE_LF...{seasonal_adjustment}.{sex}.{age}..."
)
data = helpers.get_possibly_cached_data(
url, function="economy_unemployment", query_dict=query_dict
)
url_query = f"AGE=='{age}' & SEX=='{sex}' & FREQ=='{frequency}' & ADJUSTMENT=='{seasonal_adjustment}'"
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
# Filter down
Expand All @@ -135,17 +147,17 @@ def extract_data(
)
data["country"] = data["country"].map(CODE_TO_COUNTRY_UNEMPLOYMENT)

data = data.to_dict(orient="records")
start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))
data["date"] = data["date"].apply(helpers.oecd_date_to_python_date)
data = data[
(data["date"] <= query.end_date) & (data["date"] >= query.start_date)
]

return data
return data.to_dict(orient="records")

# pylint: disable=unused-argument
@staticmethod
def transform_data(
query: OECDUnemploymentQueryParams, data: Dict, **kwargs: Any
query: OECDUnemploymentQueryParams, data: List[Dict], **kwargs: Any
) -> List[OECDUnemploymentData]:
"""Transform the data from the OECD endpoint."""
return [OECDUnemploymentData.model_validate(d) for d in data]
Loading
Loading