Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a couple of bugs in the base64 file_encoding_strategy #398

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions replicate/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def encode_json(
return encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
if file_encoding_strategy == "base64":
return base64.b64encode(obj.read()).decode("utf-8")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lolwut. how did this ever work? did it just... not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It did not.

return base64_encode_file(obj)
else:
return client.files.create(obj).urls["get"]
if HAS_NUMPY:
Expand Down Expand Up @@ -77,9 +77,13 @@ async def async_encode_json(
]
if isinstance(obj, Path):
with obj.open("rb") as file:
return encode_json(file, client, file_encoding_strategy)
return await async_encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
return (await client.files.async_create(obj)).urls["get"]
if file_encoding_strategy == "base64":
# TODO: This should ideally use an async based file reader path.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably isn't too hard if we're prepared to take a dep on aiofile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed.

return base64_encode_file(obj)
else:
return (await client.files.async_create(obj)).urls["get"]
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand Down
129 changes: 129 additions & 0 deletions tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import asyncio
import io
import json
import sys
from email.message import EmailMessage
from email.parser import BytesParser
from email.policy import HTTP
from typing import AsyncIterator, Iterator, Optional, cast

import httpx
Expand Down Expand Up @@ -581,6 +586,130 @@ async def test_run_with_model_error(mock_replicate_api_token):
assert excinfo.value.prediction.status == "failed"


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_with_file_input_files_api(async_flag, mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
mock_predictions_create = router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=_prediction_with_status("processing"),
)
)
router.route(
method="GET",
path="/models/test/example/versions/v1",
).mock(
return_value=httpx.Response(
200,
json=_version_with_schema(),
)
)
mock_files_create = router.route(
method="POST",
path="/files",
).mock(
return_value=httpx.Response(
200,
json={
"id": "file1",
"name": "file.png",
"content_type": "image/png",
"size": 10,
"etag": "123",
"checksums": {},
"metadata": {},
"created_at": "",
"expires_at": "",
"urls": {"get": "https://api.replicate.com/files/file.txt"},
},
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)
if async_flag:
await client.async_run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
)
else:
client.run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
)

assert mock_predictions_create.called
prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
assert (
prediction_payload.get("input", {}).get("file")
== "https://api.replicate.com/files/file.txt"
)

# Validate the Files API request
req = mock_files_create.calls[0].request
body = req.content
content_type = req.headers["Content-Type"]

# Parse the multipart data
parser = BytesParser(EmailMessage, policy=HTTP)
headers = f"Content-Type: {content_type}\n\n".encode()
parsed_message_generator = parser.parsebytes(headers + body).walk()
next(parsed_message_generator) # wrapper
input_file = next(parsed_message_generator)
assert mock_files_create.called
assert input_file.get_content() == b"hello world"
assert input_file.get_content_type() == "application/octet-stream"


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_with_file_input_data_url(async_flag, mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
mock_predictions_create = router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=_prediction_with_status("processing"),
)
)
router.route(
method="GET",
path="/models/test/example/versions/v1",
).mock(
return_value=httpx.Response(
200,
json=_version_with_schema(),
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)

if async_flag:
await client.async_run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
file_encoding_strategy="base64",
)
else:
client.run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
file_encoding_strategy="base64",
)

assert mock_predictions_create.called
prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
assert (
prediction_payload.get("input", {}).get("file")
== "data:application/octet-stream;base64,aGVsbG8gd29ybGQ="
)


@pytest.mark.asyncio
async def test_run_with_file_output(mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
Expand Down