diff --git a/skynet/env.py b/skynet/env.py index e140630..8ecafb1 100644 --- a/skynet/env.py +++ b/skynet/env.py @@ -95,3 +95,8 @@ def tobool(val: str | None): # load balancing enable_haproxy_agent = tobool(os.environ.get('ENABLE_HAPROXY_AGENT')) + +# testing +echo_requests_base_url = os.environ.get('ECHO_REQUESTS_BASE_URL') +echo_requests_percent = int(os.environ.get('ECHO_REQUESTS_PERCENT', 100)) +echo_requests_token = os.environ.get('ECHO_REQUESTS_TOKEN') diff --git a/skynet/http_client.py b/skynet/http_client.py index 6e704f3..4d6bbdf 100644 --- a/skynet/http_client.py +++ b/skynet/http_client.py @@ -27,4 +27,10 @@ async def get(url, type='json'): return await response.text() +async def post(url, **kwargs): + session = _get_session() + async with session.post(url, **kwargs) as response: + return await response.json() + + __all__ = ['get'] diff --git a/skynet/modules/ttt/summaries/app.py b/skynet/modules/ttt/summaries/app.py index 7bf43fc..c50aec0 100644 --- a/skynet/modules/ttt/summaries/app.py +++ b/skynet/modules/ttt/summaries/app.py @@ -1,6 +1,10 @@ +import random +from fastapi import Request from fastapi_versionizer.versionizer import Versionizer +from skynet import http_client from skynet.auth.openai import setup_credentials +from skynet.env import echo_requests_base_url, echo_requests_percent, echo_requests_token from skynet.logs import get_logger from skynet.modules.ttt.openai_api.app import destroy as destroy_openai_api, initialize as initialize_openai_api from skynet.utils import create_app @@ -16,6 +20,23 @@ app = create_app() app.include_router(v1_router) +if echo_requests_base_url: + + @app.middleware("http") + async def echo_requests(request: Request, call_next): + if request.method == 'POST': + counter = random.randrange(1, 101) + + if counter <= echo_requests_percent: + await http_client.post( + f'{echo_requests_base_url}/{request.url.path}', + headers={'Authorization': f'Bearer {echo_requests_token}'}, + json=await request.json(), + ) + + return await call_next(request) + + Versionizer(app=app, prefix_format='/v{major}', sort_routes=True).versionize()