Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: loadbalance stream based on response #6122

Merged
merged 11 commits into from
Dec 6, 2023
53 changes: 27 additions & 26 deletions jina/serve/runtimes/gateway/request_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import itertools
import json
from typing import TYPE_CHECKING, AsyncIterator, Dict

from aiohttp.client import _RequestContextManager

from jina.enums import ProtocolType
from jina.helper import get_full_version
from jina.proto import jina_pb2
Expand Down Expand Up @@ -157,18 +160,23 @@ async def _load_balance(self, request):
try:
async with aiohttp.ClientSession() as session:

if request.method == 'GET':
request_kwargs = {}
try:
payload = await request.json()
if payload:
request_kwargs['json'] = payload
except Exception:
self.logger.debug('No JSON payload found in request')
Comment on lines -164 to -167
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the original implementation, i have this idea:
it looks like the original logic was only to write a debug log which is not useful at all for production application. Can we just act as a pure proxy here for performance consideration? Something like:

async with session.request(request.method, data=request.iter_any(), **request_kwargs) as response:
    ....

@NarekA @JoanFM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason, if I try to pass the content in any other way besides the json field, I get an error here. I've tried everything at this point, if you can get this to work, I am interested.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what error do you see?


async with session.get(
url=target_url, **request_kwargs
) as response:
request_kwargs = {
'headers': request.headers,
'params': request.query,
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code did not forward the headers & query, which makes me wonder if it's intentional?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was not intentional

try:
payload = await request.read()
if payload:
request_kwargs['json'] = json.loads(payload.decode())
except Exception:
self.logger.debug('No JSON payload found in request')

async with session.request(
request.method, target_url, **request_kwargs
) as response:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NarekA why not directly use the request method from the session?, it's already wrapping the context manager for you.

    def request(
        self, method: str, url: StrOrURL, **kwargs: Any
    ) -> "_RequestContextManager":
        """Perform HTTP request."""
        return _RequestContextManager(self._request(method, url, **kwargs))
        ```

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somehow I missed this, will fix.

# Looking for application/octet-stream, text/event-stream, text/stream
if response.content_type.endswith('stream'):

# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
Expand All @@ -185,20 +193,13 @@ async def _load_balance(self, request):
# Close the stream response once all chunks are sent
await stream_response.write_eof()
return stream_response

elif request.method == 'POST':
d = await request.read()
import json

async with session.post(
url=target_url, json=json.loads(d.decode())
) as response:
content = await response.read()
return web.Response(
body=content,
status=response.status,
content_type=response.content_type,
)
content = await response.read()
return web.Response(
body=content,
status=response.status,
content_type=response.content_type,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the header be added here as well?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also by looking at the original GET implementation, I wonder if we just need to use StreamResponse all the way. Basically the load balancer just stream whatever it receives out.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes this would be a good idea

headers=response.headers,
)
except aiohttp.ClientError as e:
return web.Response(text=f'Error: {str(e)}', status=500)

Expand Down
6 changes: 2 additions & 4 deletions tests/integration/docarray_v2/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,17 @@ async def test_issue_6090_get_params(streaming_deployment):

docs = []
url = (
f"htto://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text"
f"http://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text"
)
async with aiohttp.ClientSession() as session:

async with session.get(url) as resp:
async for chunk in resp.content.iter_any():
print(chunk)
events = chunk.split(b'event: ')[1:]
for event in events:
if event.startswith(b'update'):
parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX:].decode()
parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX :].decode()
parsed = SimpleInput.parse_raw(parsed)
print(parsed)
docs.append(parsed)
elif event.startswith(b'end'):
pass
Expand Down