diff --git a/.travis.yml b/.travis.yml index 8d951e75..116fdabc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,7 @@ matrix: - python: '2.7' env: - SO_DISABLE_MOCKS: "1" + - SO_DISABLE_MOTO_SERVER: "1" - SO_S3_URL: "s3://smart-open-py27-benchmark" - SO_S3_RESULT_URL: "s3://smart-open-py27-benchmark-results" @@ -19,12 +20,14 @@ matrix: - python: '3.6' env: - SO_DISABLE_MOCKS: "1" + - SO_DISABLE_MOTO_SERVER: "1" - SO_S3_URL: "s3://smart-open-py36-benchmark" - SO_S3_RESULT_URL: "s3://smart-open-py36-benchmark-results" - python: '3.7' env: - SO_DISABLE_MOCKS: "1" + - SO_DISABLE_MOTO_SERVER: "1" - SO_S3_URL: "s3://smart-open-py37-benchmark" - SO_S3_RESULT_URL: "s3://smart-open-py37-benchmark-results" - BOTO_CONFIG: "/dev/null" @@ -33,6 +36,7 @@ matrix: install: - pip install --upgrade setuptools + - pip install flask - pip install .[test] - pip install flake8 - pip freeze @@ -45,6 +49,9 @@ script: unset SO_S3_URL; unset SO_S3_RESULT_URL; fi + - if [[ ${SO_DISABLE_MOTO_SERVER} -ne 1 ]]; then + sh -c "moto_server -p5000 2> /dev/null &"; + fi - flake8 --max-line-length=110 - python setup.py test - export SO_S3_URL=$SO_S3_URL/$(python -c 'from uuid import uuid4;print(uuid4())') diff --git a/smart_open/s3.py b/smart_open/s3.py index ba631317..db77e8b6 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -19,6 +19,8 @@ import smart_open.bytebuffer +from botocore.exceptions import IncompleteReadError + logger = logging.getLogger(__name__) # Multiprocessing is unavailable in App Engine (and possibly other sandboxes). @@ -170,27 +172,30 @@ def __init__(self, s3_object, content_length, version_id=None): self._object = s3_object self._content_length = content_length self._version_id = version_id - self.seek(0) + self._position = 0 + self._body = None def seek(self, position): """Seek to the specified position (byte offset) in the S3 key. :param int position: The byte offset from the beginning of the key. """ - self._position = position - range_string = make_range_string(self._position) - logger.debug('content_length: %r range_string: %r', self._content_length, range_string) - # # Close old body explicitly. - # When first seek(), self._body is not exist. Catch the exception and do nothing. + # When first seek() after __init__(), self._body is not exist. # - try: + if self._body is not None: self._body.close() - except AttributeError: - pass + self._body = None + self._position = position + + def _load_body(self): + """Build a continuous connection with the remote peer starts from the current postion. + """ + range_string = make_range_string(self._position) + logger.debug('content_length: %r range_string: %r', self._content_length, range_string) - if position == self._content_length == 0 or position == self._content_length: + if self._position == self._content_length == 0 or self._position == self._content_length: # # When reading, we can't seek to the first byte of an empty file. # Similarly, we can't seek past the last byte. Do nothing here. @@ -199,13 +204,27 @@ def seek(self, position): else: self._body = _get(self._object, self._version_id, Range=range_string)['Body'] - def read(self, size=-1): - if self._position >= self._content_length: - return b'' + def _read_from_body(self, size=-1): if size == -1: binary = self._body.read() else: binary = self._body.read(size) + return binary + + def read(self, size=-1): + """Read from the continuous connection with the remote peer.""" + if self._position >= self._content_length: + return b'' + if self._body is None: + # When the first read() after __init__() or seek(), self._body is not exist. + self._load_body() + + try: + binary = self._read_from_body(size) + except IncompleteReadError: + # The underlying connection of the self._body was closed by the remote peer. + self._load_body() + binary = self._read_from_body(size) self._position += len(binary) return binary diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index c213a7cb..a493af60 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -34,6 +34,7 @@ KEY_NAME = 'test-key' WRITE_KEY_NAME = 'test-write-key' DISABLE_MOCKS = os.environ.get('SO_DISABLE_MOCKS') == "1" +DISABLE_MOTO_SERVER = os.environ.get("SO_DISABLE_MOTO_SERVER") == "1" logger = logging.getLogger(__name__) @@ -104,6 +105,27 @@ def ignore_resource_warnings(): warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*") # noqa +@unittest.skipIf(DISABLE_MOTO_SERVER, 'The test case needs a Moto server running on the local 5000 port.') +class SeekableRawReaderTest(unittest.TestCase): + + def setUp(self): + self._local_resource = boto3.resource('s3', endpoint_url='http://localhost:5000') + self._local_resource.Bucket(BUCKET_NAME).create() + self._local_resource.Object(BUCKET_NAME, KEY_NAME).put(Body=b'123456') + + def tearDown(self): + self._local_resource.Object(BUCKET_NAME, KEY_NAME).delete() + self._local_resource.Bucket(BUCKET_NAME).delete() + + def test_read_from_a_closed_body(self): + obj = self._local_resource.Object(BUCKET_NAME, KEY_NAME) + content_length = obj.content_length + reader = smart_open.s3.SeekableRawReader(obj, content_length) + self.assertEqual(reader.read(1), b'1') + reader._body.close() + self.assertEqual(reader.read(2), b'23') + + @maybe_mock_s3 class SeekableBufferedInputBaseTest(unittest.TestCase): def setUp(self):