From ed5bfb413da1756bfd89bf412d5d2161bee70465 Mon Sep 17 00:00:00 2001 From: valq7711 Date: Sun, 7 Jul 2024 14:00:44 +0300 Subject: [PATCH] Fix `BytesIOProxy`: seek-whence & seekable - `BytesIOProxy.seek` - add `whence` argument - implement `BytesIOProxy.seekable` method --- ombott/request_pkg/multipart.py | 21 ++++++++++++++++----- tests/request/test_multipart.py | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/ombott/request_pkg/multipart.py b/ombott/request_pkg/multipart.py index fc5b7e6..065ce86 100644 --- a/ombott/request_pkg/multipart.py +++ b/ombott/request_pkg/multipart.py @@ -1,7 +1,7 @@ from typing import Optional, Iterable, Union, Tuple, Dict from types import SimpleNamespace from collections import defaultdict -from io import BytesIO +from io import BytesIO, SEEK_CUR, SEEK_SET, SEEK_END import re from ombott.request_pkg.errors import BodyParsingError, BodySizeError @@ -345,10 +345,21 @@ def __init__(self, src: BytesIO, start: int, end: int) -> None: def tell(self) -> int: return self._pos - self._st - def seek(self, pos: int): - if pos < 0: - pos = 0 - self._pos = min(self._st + pos, self._end) + def seekable(self) -> bool: + return True + + def seek(self, pos: int, whence=SEEK_SET) -> int: + if whence == SEEK_SET: + if pos < 0: + pos = 0 + self._pos = min(self._st + pos, self._end) + elif whence == SEEK_CUR: + self.seek(self.tell() + pos) + elif whence == SEEK_END: + self.seek(self._end + pos - self._st) + else: + raise ValueError(f'Unexpected whence: {whence}') + return self.tell() def read(self, sz: int = None) -> bytes: max_sz = self._end - self._pos diff --git a/tests/request/test_multipart.py b/tests/request/test_multipart.py index 06cc609..2785552 100644 --- a/tests/request/test_multipart.py +++ b/tests/request/test_multipart.py @@ -1,7 +1,7 @@ from typing import Iterable, List import pytest -from io import BytesIO -from ombott.request_pkg.multipart import MultipartMarkup, FieldStorage +from io import BytesIO, SEEK_CUR, SEEK_END +from ombott.request_pkg.multipart import MultipartMarkup, FieldStorage, BytesIOProxy class Field: @@ -104,3 +104,31 @@ def test_read_multipart(field_store: Iterable[FieldStorage], form: List[Field]): else: assert src.value == parsed.value assert fields_num == i + + +def test_bytes_io_proxy(): + proxied_bytes_1 = b'some ' + proxied_bytes_2 = b'bytes' + proxied_bytes = proxied_bytes_1 + proxied_bytes_2 + start = 3 + end = start + len(proxied_bytes) + src_body = (b' ' * start) + proxied_bytes + (b' ' * 5) + bytes_src = BytesIO(src_body) + bytes_proxy = BytesIOProxy(bytes_src, start, end) + + assert bytes_proxy.read() == proxied_bytes + bytes_proxy.seek(0) + assert bytes_proxy.read(100) == proxied_bytes + + bytes_proxy.seek(len(proxied_bytes_1)) + assert bytes_proxy.read(len(proxied_bytes_2)) == proxied_bytes_2 + + bytes_proxy.seek(0) + bytes_proxy.read(len(proxied_bytes_1) - 1) + bytes_proxy.seek(1, SEEK_CUR) + assert bytes_proxy.read() == proxied_bytes_2 + + bytes_proxy.seek(100) + assert bytes_proxy.tell() == len(proxied_bytes) + bytes_proxy.seek(-len(proxied_bytes_2), SEEK_END) + assert bytes_proxy.read() == proxied_bytes_2