Skip to content

Commit

Permalink
Make Pro Fetchers async by replacing ThreadPoolExecutor with `async…
Browse files Browse the repository at this point in the history
…io.gather` (OpenBB-finance#5822)

* init

* Update helpers.py

* filter URL on raise

* polygon, fmp done, update cassettes

* updates

* cleanup

* revert tests for now

* Update discovery_filings.py

* Update helpers.py

* add `async` overload to `Fetcher.extract_data`

* update merged fred models

* tests, improve fred series

* Update insider_trading.py

* insider

* tests

* intrinio

* Update equity_historical.py

* implement abstract async extract on Fetcher, with simple switch

* cleanup

* Update helpers.py

* Update helpers.py

* fix fmp historical deprecated urls

* Revert "fix fmp historical deprecated urls"

This reverts commit 1a760a2.

* fix fmp `index.market` 1day url missing query params

* Update test_fmp_market_indices_fetcher.yaml

* field fixes

* `extract_data_async` -> `aextract_data`, `async_request` -> `amake_request`

* `async_requests` -> `amake_requests`

* Update company_news.py

* add sort to index `integration` params

* add `client.py` tests

* Update equity_search.py

* Add lost v in platform api version string

* Resolve linter errors here and there

---------

Co-authored-by: Theodore Aptekarev <[email protected]>
  • Loading branch information
tehcoderer and piiq authored Dec 8, 2023
1 parent 40c0c9c commit da7886d
Show file tree
Hide file tree
Showing 119 changed files with 4,912 additions and 3,354 deletions.
51 changes: 28 additions & 23 deletions openbb_platform/core/openbb_core/api/rest_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""REST API for the OpenBB Platform."""
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -17,6 +18,31 @@

system = SystemService().system_settings


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Log the platform banner, then hand control to the application.

    The banner reports the platform version and whether API authentication
    is enabled. Nothing runs after the ``yield``.
    """
    if Env().API_AUTH:
        auth = "ENABLED"
    else:
        auth = "DISABLED"
    # The banner text is a raw f-string; its lines are kept exactly as written.
    banner = rf"""
███╗
█████████████████╔══█████████████████╗ OpenBB Platform v{system.version}
███╔══════════███║ ███╔══════════███║
█████████████████║ █████████████████║ Authentication: {auth}
╚═════════════███║ ███╔═════════════╝
██████████████║ ██████████████╗
███╔═══════███║ ███╔═══════███║
██████████████║ ██████████████║
╚═════════════╝ ╚═════════════╝
Investment research for everyone, anywhere.
https://my.openbb.co/app/platform
"""
    logger.info(banner)
    yield


app = FastAPI(
title=system.api_settings.title,
description=system.api_settings.description,
Expand All @@ -38,7 +64,9 @@
}
for s in system.api_settings.servers
],
lifespan=lifespan,
)

app.add_middleware(
CORSMiddleware,
allow_origins=system.api_settings.cors.allow_origins,
Expand All @@ -54,29 +82,6 @@
)


# NOTE(review): FastAPI's documentation marks ``on_event`` as deprecated in
# favour of lifespan handlers — confirm before relying on this hook.
@app.on_event("startup")
async def startup():
    """Log the platform banner when the application starts.

    The banner reports the platform version and whether API authentication
    is enabled.
    """
    # Index the label pair with the truthiness of the auth setting.
    auth = ("DISABLED", "ENABLED")[bool(Env().API_AUTH)]
    banner = rf"""
███╗
█████████████████╔══█████████████████╗ OpenBB Platform v{system.version}
███╔══════════███║ ███╔══════════███║
█████████████████║ █████████████████║ Authentication: {auth}
╚═════════════███║ ███╔═════════════╝
██████████████║ ██████████████╗
███╔═══════███║ ███╔═══════███║
██████████████║ ██████████████║
╚═════════════╝ ╚═════════════╝
Investment research for everyone, anywhere.
https://my.openbb.co/app/platform
"""
    logger.info(banner)


@app.exception_handler(Exception)
async def api_exception_handler(_: Request, exc: Exception):
"""Exception handler for all other exceptions."""
Expand Down
18 changes: 17 additions & 1 deletion openbb_platform/core/openbb_core/provider/abstract/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,32 @@ def transform_query(params: Dict[str, Any]) -> Q:
"""Transform the params to the provider-specific query."""
raise NotImplementedError

@staticmethod
async def aextract_data(query: Q, credentials: Optional[Dict[str, str]]) -> Any:
    """Asynchronously extract the data from the provider.

    This base implementation has no body, so awaiting it yields ``None``.
    ``__init_subclass__`` compares a subclass's ``aextract_data`` against this
    one to detect whether the subclass overrides it.
    """

@staticmethod
def extract_data(query: Q, credentials: Optional[Dict[str, str]]) -> Any:
    """Extract the data from the provider.

    This base implementation always raises ``NotImplementedError``.
    ``__init_subclass__`` compares a subclass's ``extract_data`` against this
    one to detect whether the subclass overrides it.
    """
    raise NotImplementedError

@staticmethod
def transform_data(query: Q, data: Any, **kwargs) -> R:
    """Transform the provider-specific data.

    This base implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError

def __init_subclass__(cls, **kwargs):
    """Initialize the subclass.

    If the subclass overrides ``aextract_data``, that implementation is also
    installed as ``extract_data``. Otherwise the subclass must override
    ``extract_data`` itself.

    Raises
    ------
    NotImplementedError
        If the subclass overrides neither ``extract_data`` nor
        ``aextract_data``.
    """
    super().__init_subclass__(**kwargs)

    if cls.aextract_data != Fetcher.aextract_data:
        # Looking a staticmethod up on the class yields the plain function.
        # Re-wrap it so the alias stays static and an instance-level call does
        # not bind the instance as the first argument.
        cls.extract_data = staticmethod(cls.aextract_data)
    elif cls.extract_data == Fetcher.extract_data:
        raise NotImplementedError(
            "Fetcher subclass must implement either extract_data or aextract_data"
            " method. If both are implemented, aextract_data will be used as the"
            " default."
        )

@classmethod
async def fetch_data(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def upper_symbol(cls, v: Union[str, List[str], Set[str]]):
return v.upper()
return ",".join([symbol.upper() for symbol in list(v)]) if v else None

@field_validator("date", mode="before", check_fields=False)
@field_validator("date", "filing_date", mode="before", check_fields=False)
@classmethod
def convert_date(cls, v: str):
"""Convert date to date type."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class IncomeStatementGrowthQueryParams(QueryParams):
)

@field_validator("symbol", mode="before", check_fields=False)
@classmethod
def upper_symbol(cls, v: Union[str, List[str], Set[str]]):
"""Convert symbol to uppercase."""
if isinstance(v, str):
Expand Down Expand Up @@ -111,9 +112,9 @@ class IncomeStatementGrowthData(Data):
)

@field_validator("symbol", mode="before", check_fields=False)
@classmethod
def upper_symbol(cls, v: Union[str, List[str], Set[str]]):
    """Convert symbol to uppercase.

    A string is upper-cased as is. Any other non-empty collection has each
    item upper-cased and joined with commas. An empty or ``None`` value
    returns ``None``.
    """
    if isinstance(v, str):
        return v.upper()
    # Single return: the duplicated, unreachable second return was removed.
    return ",".join(symbol.upper() for symbol in v) if v else None
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class InsiderTradingData(Data):
default=None,
description="Acquisition or disposition of the insider trading.",
)
security_type: str = Field(description="Security type of the insider trading.")
security_type: Optional[str] = Field(
default=None, description="Security type of the insider trading."
)
securities_owned: Optional[float] = Field(
default=None, description="Number of securities owned in the insider trading."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class KeyExecutivesQueryParams(QueryParams):
symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))

@field_validator("symbol", mode="before", check_fields=False)
@classmethod
def upper_symbol(cls, v: Union[str, List[str], Set[str]]):
"""Convert symbol to uppercase."""
if isinstance(v, str):
Expand All @@ -30,7 +31,9 @@ class KeyExecutivesData(Data):
pay: Optional[ForceInt] = Field(
default=None, description="Pay of the key executive."
)
currency_pay: str = Field(description="Currency of the pay.")
currency_pay: Optional[str] = Field(
default=None, description="Currency of the pay."
)
gender: Optional[str] = Field(
default=None, description="Gender of the key executive."
)
Expand Down
132 changes: 132 additions & 0 deletions openbb_platform/core/openbb_core/provider/utils/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Aiohttp client."""
# pylint: disable=protected-access,invalid-overridden-method
import asyncio
import random
import re
import warnings
import zlib
from typing import Any, Dict, Type, Union

import aiohttp
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict

FILTER_QUERY_REGEX = r".*key.*|.*token.*|.*auth.*|(c$)"


def obfuscate(params: Union[CIMultiDict[str], MultiDict[str]]) -> Dict[str, Any]:
    """Mask the values of parameters whose names look sensitive.

    A name is treated as sensitive when it matches ``FILTER_QUERY_REGEX``
    (case-insensitive); its value is replaced by ``"********"``. All other
    values are returned unchanged.
    """
    masked: Dict[str, Any] = {}
    for name, value in params.items():
        is_sensitive = re.match(FILTER_QUERY_REGEX, name, re.IGNORECASE)
        masked[name] = "********" if is_sensitive else value
    return masked


def get_user_agent() -> str:
    """Pick one user agent string at random from a small fixed pool."""
    firefox_agents = (
        "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.10; rv:86.1) Gecko/20100101 Firefox/86.1",
        "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:86.1) Gecko/20100101 Firefox/86.1",
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.10; rv:82.1) Gecko/20100101 Firefox/82.1",
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:86.0) Gecko/20100101 Firefox/86.0",
        "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:86.0) Gecko/20100101 Firefox/86.0",
        "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.10; rv:83.0) Gecko/20100101 Firefox/83.0",
        "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:84.0) Gecko/20100101 Firefox/84.0",
    )

    # Not security-sensitive, so the non-cryptographic generator is fine here.
    chosen = random.choice(firefox_agents)  # nosec # noqa: S311
    return chosen


class ClientResponse(aiohttp.ClientResponse):
    """Response class that stores an obfuscated copy of the request info."""

    def __init__(self, *args, **kwargs):
        """Swap the ``request_info`` keyword for its obfuscated version."""
        raw_info = kwargs["request_info"]
        kwargs["request_info"] = self.obfuscate_request_info(raw_info)
        super().__init__(*args, **kwargs)

    @classmethod
    def obfuscate_request_info(
        cls, request_info: aiohttp.RequestInfo
    ) -> aiohttp.RequestInfo:
        """Build a request info whose query and headers are obfuscated.

        The obfuscated URL is used for both URL fields of the new
        ``RequestInfo``; the HTTP method is carried over unchanged.
        """
        original_url = request_info.url
        masked_query = obfuscate(original_url.query.copy())
        masked_headers = obfuscate(request_info.headers.copy())
        safe_url = original_url.with_query(masked_query)
        safe_headers = CIMultiDictProxy(CIMultiDict(masked_headers))

        return aiohttp.RequestInfo(
            safe_url, request_info.method, safe_headers, safe_url
        )

    async def json(self, **kwargs) -> Union[dict, list]:
        """Return the decoded json body, delegating to the parent class."""
        decoded = await super().json(**kwargs)
        return decoded


class ClientSession(aiohttp.ClientSession):
    """aiohttp session with project defaults and manual body decompression.

    Unless overridden by keyword arguments, the session uses a
    ``TCPConnector`` with ``ttl_dns_cache=300``, ``ClientResponse`` as the
    response class, and ``auto_decompress=False``. With auto-decompression
    off, ``request`` inflates gzip/deflate bodies itself.
    """

    _response_class: Type[ClientResponse]
    _session: "ClientSession"

    def __init__(self, *args, **kwargs):
        """Create the session, filling in defaults for missing keywords."""
        # Only build the default connector when the caller gave none, so no
        # throwaway connector is created and discarded.
        if "connector" not in kwargs:
            kwargs["connector"] = aiohttp.TCPConnector(ttl_dns_cache=300)
        kwargs.setdefault("response_class", ClientResponse)
        kwargs.setdefault("auto_decompress", False)

        super().__init__(*args, **kwargs)

    # pylint: disable=unused-argument
    def __del__(self, _warnings: Any = warnings) -> None:
        """Close the session."""
        if self.closed:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread, so close() cannot be
            # scheduled. Creating the coroutine anyway would only raise here
            # and leave a never-awaited coroutine behind.
            return
        loop.create_task(self.close())

    async def get(self, url: str, **kwargs) -> ClientResponse:  # type: ignore
        """Send GET request."""
        return await self.request("GET", url, **kwargs)

    async def post(self, url: str, **kwargs) -> ClientResponse:  # type: ignore
        """Send POST request."""
        return await self.request("POST", url, **kwargs)

    async def get_json(self, url: str, **kwargs) -> Union[dict, list]:
        """Send GET request and return json."""
        response = await self.request("GET", url, **kwargs)
        return await response.json()

    async def get_one(self, url: str, **kwargs) -> Dict[str, Any]:
        """Send GET request and return first item in json if list.

        A non-list payload is returned as is. An empty list raises
        ``IndexError``.
        """
        response = await self.request("GET", url, **kwargs)
        data = await response.json()

        if isinstance(data, list):
            return data[0]

        return data

    async def request(  # type: ignore
        self, *args, raise_for_status: bool = False, **kwargs
    ) -> ClientResponse:
        """Send request.

        Default headers are used when none are given (or ``headers=None``),
        and a ``User-Agent`` is added when missing. ``raise_for_status`` is
        handled here and not forwarded to the parent. When the session does
        not auto-decompress, gzip/deflate bodies are inflated and stored on
        the response.
        """
        headers = kwargs.get("headers")
        if headers is None:
            # Default headers, makes sure we accept gzip
            headers = {
                "Accept": "application/json",
                "Accept-Encoding": "gzip, deflate",
                "Connection": "keep-alive",
            }
        else:
            # Work on a copy so the caller's mapping is not mutated when the
            # User-Agent is filled in below.
            headers = headers.copy()

        if headers.get("User-Agent", None) is None:
            headers["User-Agent"] = get_user_agent()
        kwargs["headers"] = headers

        response = await super().request(*args, **kwargs)

        if raise_for_status:
            response.raise_for_status()

        encoding = response.headers.get("Content-Encoding", "")
        if encoding in ("gzip", "deflate") and not self.auto_decompress:
            response_body = await response.read()
            # gzip needs the gzip-header window; deflate is read as a raw stream.
            wbits = 16 + zlib.MAX_WBITS if encoding == "gzip" else -zlib.MAX_WBITS
            response._body = zlib.decompress(response_body, wbits)

        return response  # type: ignore
Loading

0 comments on commit da7886d

Please sign in to comment.