Skip to content

Commit

Permalink
Fix BytesIOProxy: seek-whence & seekable
Browse files Browse the repository at this point in the history
- `BytesIOProxy.seek` - add `whence` argument
- implement `BytesIOProxy.seekable` method
  • Loading branch information
valq7711 committed Aug 5, 2024
1 parent d3074e9 commit ed5bfb4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
21 changes: 16 additions & 5 deletions ombott/request_pkg/multipart.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Iterable, Union, Tuple, Dict
from types import SimpleNamespace
from collections import defaultdict
from io import BytesIO
from io import BytesIO, SEEK_CUR, SEEK_SET, SEEK_END
import re

from ombott.request_pkg.errors import BodyParsingError, BodySizeError
Expand Down Expand Up @@ -345,10 +345,21 @@ def __init__(self, src: BytesIO, start: int, end: int) -> None:
def tell(self) -> int:
return self._pos - self._st

def seek(self, pos: int):
if pos < 0:
pos = 0
self._pos = min(self._st + pos, self._end)
def seekable(self) -> bool:
return True

def seek(self, pos: int, whence=SEEK_SET) -> int:
if whence == SEEK_SET:
if pos < 0:
pos = 0
self._pos = min(self._st + pos, self._end)
elif whence == SEEK_CUR:
self.seek(self.tell() + pos)
elif whence == SEEK_END:
self.seek(self._end + pos - self._st)
else:
raise ValueError(f'Unexpected whence: {whence}')
return self.tell()

def read(self, sz: int = None) -> bytes:
max_sz = self._end - self._pos
Expand Down
32 changes: 30 additions & 2 deletions tests/request/test_multipart.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Iterable, List
import pytest
from io import BytesIO
from ombott.request_pkg.multipart import MultipartMarkup, FieldStorage
from io import BytesIO, SEEK_CUR, SEEK_END
from ombott.request_pkg.multipart import MultipartMarkup, FieldStorage, BytesIOProxy


class Field:
Expand Down Expand Up @@ -104,3 +104,31 @@ def test_read_multipart(field_store: Iterable[FieldStorage], form: List[Field]):
else:
assert src.value == parsed.value
assert fields_num == i


def test_bytes_io_proxy():
proxied_bytes_1 = b'some '
proxied_bytes_2 = b'bytes'
proxied_bytes = proxied_bytes_1 + proxied_bytes_2
start = 3
end = start + len(proxied_bytes)
src_body = (b' ' * start) + proxied_bytes + (b' ' * 5)
bytes_src = BytesIO(src_body)
bytes_proxy = BytesIOProxy(bytes_src, start, end)

assert bytes_proxy.read() == proxied_bytes
bytes_proxy.seek(0)
assert bytes_proxy.read(100) == proxied_bytes

bytes_proxy.seek(len(proxied_bytes_1))
assert bytes_proxy.read(len(proxied_bytes_2)) == proxied_bytes_2

bytes_proxy.seek(0)
bytes_proxy.read(len(proxied_bytes_1) - 1)
bytes_proxy.seek(1, SEEK_CUR)
assert bytes_proxy.read() == proxied_bytes_2

bytes_proxy.seek(100)
assert bytes_proxy.tell() == len(proxied_bytes)
bytes_proxy.seek(-len(proxied_bytes_2), SEEK_END)
assert bytes_proxy.read() == proxied_bytes_2

0 comments on commit ed5bfb4

Please sign in to comment.