Skip to content

Commit

Permalink
feat(client): improve file upload types
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Oct 24, 2023
1 parent 26c8928 commit fa70817
Show file tree
Hide file tree
Showing 10 changed files with 295 additions and 19 deletions.
6 changes: 5 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
[mypy]
pretty = True
show_error_codes = True
exclude = _dev

# Exclude _files.py because mypy isn't smart enough to apply
# the correct type narrowing and as this is an internal module
# it's fine to just use Pyright.
exclude = ^(src/finch/_files\.py|_dev/.*\.py)$

strict_equality = True
implicit_reexport = True
Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dev-dependencies = [
"isort==5.10.1",
"time-machine==2.9.0",
"nox==2023.4.22",
"dirty-equals>=0.6.0",

]

Expand All @@ -53,6 +54,15 @@ format = { chain = [
"format:ruff" = "ruff --fix ."
"format:isort" = "isort ."

typecheck = { chain = [
"typecheck:pyright",
"typecheck:verify-types",
"typecheck:mypy"
]}
"typecheck:pyright" = "pyright"
"typecheck:verify-types" = "pyright --verifytypes finch --ignoreexternal"
"typecheck:mypy" = "mypy --enable-incomplete-feature=Unpack ."

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ black==23.3.0
certifi==2023.7.22
click==8.1.7
colorlog==6.7.0
dirty-equals==0.6.0
distlib==0.3.7
distro==1.8.0
exceptiongroup==1.1.3
Expand All @@ -40,6 +41,7 @@ pyright==1.1.332
pytest==7.1.1
pytest-asyncio==0.21.1
python-dateutil==2.8.2
pytz==2023.3.post1
respx==0.19.2
rfc3986==1.5.0
ruff==0.0.282
Expand Down
17 changes: 13 additions & 4 deletions src/finch/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from . import _exceptions
from ._qs import Querystring
from ._files import to_httpx_files, async_to_httpx_files
from ._types import (
NOT_GIVEN,
Body,
Expand Down Expand Up @@ -1088,7 +1089,9 @@ def post(
stream: bool = False,
stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options)
opts = FinalRequestOptions.construct(
method="post", url=path, json_data=body, files=to_httpx_files(files), **options
)
return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))

def patch(
Expand All @@ -1111,7 +1114,9 @@ def put(
files: RequestFiles | None = None,
options: RequestOptions = {},
) -> ResponseT:
opts = FinalRequestOptions.construct(method="put", url=path, json_data=body, files=files, **options)
opts = FinalRequestOptions.construct(
method="put", url=path, json_data=body, files=to_httpx_files(files), **options
)
return self.request(cast_to, opts)

def delete(
Expand Down Expand Up @@ -1491,7 +1496,9 @@ async def post(
stream: bool = False,
stream_cls: type[_AsyncStreamT] | None = None,
) -> ResponseT | _AsyncStreamT:
opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options)
opts = FinalRequestOptions.construct(
method="post", url=path, json_data=body, files=await async_to_httpx_files(files), **options
)
return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)

async def patch(
Expand All @@ -1514,7 +1521,9 @@ async def put(
files: RequestFiles | None = None,
options: RequestOptions = {},
) -> ResponseT:
opts = FinalRequestOptions.construct(method="put", url=path, json_data=body, files=files, **options)
opts = FinalRequestOptions.construct(
method="put", url=path, json_data=body, files=await async_to_httpx_files(files), **options
)
return await self.request(cast_to, opts)

async def delete(
Expand Down
122 changes: 122 additions & 0 deletions src/finch/_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

import io
import os
import pathlib
from typing import overload
from typing_extensions import TypeGuard

import anyio

from ._types import (
FileTypes,
FileContent,
RequestFiles,
HttpxFileTypes,
HttpxFileContent,
HttpxRequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t


def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
)


def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead."
) from None


@overload
def to_httpx_files(files: None) -> None:
...


@overload
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
...


def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None

if is_mapping_t(files):
files = {key: _transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")

return files


def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = pathlib.Path(file)
return (path.name, path.read_bytes())

return file

if is_tuple_t(file):
return (file[0], _read_file_content(file[1]), *file[2:])

raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")


def _read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
return file


@overload
async def async_to_httpx_files(files: None) -> None:
...


@overload
async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
...


async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None

if is_mapping_t(files):
files = {key: await _async_transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")

return files


async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = anyio.Path(file)
return (path.name, await path.read_bytes())

return file

if is_tuple_t(file):
return (file[0], await _async_read_file_content(file[1]), *file[2:])

raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")


async def _async_read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return await anyio.Path(file).read_bytes()

return file
34 changes: 29 additions & 5 deletions src/finch/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
import inspect
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, ClassVar, Protocol, final, runtime_checkable
from typing_extensions import (
Unpack,
Literal,
ClassVar,
Protocol,
Required,
TypedDict,
final,
runtime_checkable,
)

import pydantic
import pydantic.generics
Expand All @@ -18,7 +27,7 @@
Timeout,
NotGiven,
AnyMapping,
RequestFiles,
HttpxRequestFiles,
)
from ._utils import is_list, is_mapping, parse_date, parse_datetime, strip_not_given
from ._compat import PYDANTIC_V2, ConfigDict
Expand Down Expand Up @@ -363,6 +372,19 @@ def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
return RootModel[type_] # type: ignore


class FinalRequestOptionsInput(TypedDict, total=False):
method: Required[str]
url: Required[str]
params: Query
headers: Headers
max_retries: int
timeout: float | Timeout | None
files: HttpxRequestFiles | None
idempotency_key: str
json_data: Body
extra_json: AnyMapping


@final
class FinalRequestOptions(pydantic.BaseModel):
method: str
Expand All @@ -371,7 +393,7 @@ class FinalRequestOptions(pydantic.BaseModel):
headers: Union[Headers, NotGiven] = NotGiven()
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
files: Union[RequestFiles, None] = None
files: Union[HttpxRequestFiles, None] = None
idempotency_key: Union[str, None] = None

# It should be noted that we cannot use `json` here as that would override
Expand All @@ -395,11 +417,13 @@ def get_max_retries(self, max_retries: int) -> int:
# this is necessary as we don't want to do any actual runtime type checking
# (which means we can't use validators) but we do want to ensure that `NotGiven`
# values are not present
#
# type ignore required because we're adding explicit types to `**values`
@classmethod
def construct(
def construct( # type: ignore
cls,
_fields_set: set[str] | None = None,
**values: Any,
**values: Unpack[FinalRequestOptionsInput],
) -> FinalRequestOptions:
kwargs: dict[str, Any] = {
# we unconditionally call `strip_not_given` on any value
Expand Down
17 changes: 16 additions & 1 deletion src/finch/_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from os import PathLike
from typing import (
IO,
TYPE_CHECKING,
Expand Down Expand Up @@ -32,9 +33,10 @@
_T = TypeVar("_T")

# Approximates httpx internal ProxiesTypes and RequestFiles types
# while adding support for `PathLike` instances
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
ProxiesTypes = Union[str, Proxy, ProxiesDict]
FileContent = Union[IO[bytes], bytes]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
FileTypes = Union[
# file (or bytes)
FileContent,
Expand All @@ -47,6 +49,19 @@
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]

# duplicate of the above but without our custom file support
HttpxFileContent = Union[IO[bytes], bytes]
HttpxFileTypes = Union[
# file (or bytes)
HttpxFileContent,
# (filename, file (or bytes))
Tuple[Optional[str], HttpxFileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], HttpxFileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]

# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
# where ResponseT includes `None`. In order to support directly
Expand Down
5 changes: 5 additions & 0 deletions src/finch/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
from ._utils import is_dict as is_dict
from ._utils import is_list as is_list
from ._utils import is_given as is_given
from ._utils import is_tuple as is_tuple
from ._utils import is_mapping as is_mapping
from ._utils import is_tuple_t as is_tuple_t
from ._utils import parse_date as parse_date
from ._utils import is_sequence as is_sequence
from ._utils import coerce_float as coerce_float
from ._utils import is_list_type as is_list_type
from ._utils import is_mapping_t as is_mapping_t
from ._utils import removeprefix as removeprefix
from ._utils import removesuffix as removesuffix
from ._utils import extract_files as extract_files
from ._utils import is_sequence_t as is_sequence_t
from ._utils import is_union_type as is_union_type
from ._utils import required_args as required_args
from ._utils import coerce_boolean as coerce_boolean
Expand Down
Loading

0 comments on commit fa70817

Please sign in to comment.