diff --git a/changelog.d/17239.misc b/changelog.d/17239.misc new file mode 100644 index 00000000000..9fca36bb294 --- /dev/null +++ b/changelog.d/17239.misc @@ -0,0 +1 @@ +Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module. diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 0e875132f6f..9da84959501 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -650,7 +650,7 @@ async def _download_remote_file( file_info = FileInfo(server_name=server_name, file_id=file_id) - with self.media_storage.store_into_file(file_info) as (f, fname, finish): + async with self.media_storage.store_into_file(file_info) as (f, fname): try: length, headers = await self.client.download_media( server_name, @@ -693,8 +693,6 @@ async def _download_remote_file( ) raise SynapseError(502, "Failed to fetch remote media") - await finish() - if b"Content-Type" in headers: media_type = headers[b"Content-Type"][0].decode("ascii") else: @@ -1045,14 +1043,9 @@ async def _generate_thumbnails( ), ) - with self.media_storage.store_into_file(file_info) as ( - f, - fname, - finish, - ): + async with self.media_storage.store_into_file(file_info) as (f, fname): try: await self.media_storage.write_to_file(t_byte_source, f) - await finish() finally: t_byte_source.close() diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index b45b319f5c9..9979c48eaca 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -27,10 +27,9 @@ IO, TYPE_CHECKING, Any, - Awaitable, + AsyncIterator, BinaryIO, Callable, - Generator, Optional, Sequence, Tuple, @@ -97,11 +96,9 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str: the file path written to in the primary media store """ - with self.store_into_file(file_info) as (f, fname, finish_cb): + async with self.store_into_file(file_info) as (f, fname): # Write to the main media repository await self.write_to_file(source, f) - # Write to the other storage providers - await finish_cb() return fname @@ -111,32 +108,27 @@ async def write_to_file(self, source: IO, output: IO) -> None: await defer_to_thread(self.reactor, _write_file_synchronously, source, output) @trace_with_opname("MediaStorage.store_into_file") - @contextlib.contextmanager - def store_into_file( + @contextlib.asynccontextmanager + async def store_into_file( self, file_info: FileInfo - ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: - """Context manager used to get a file like object to write into, as + ) -> AsyncIterator[Tuple[BinaryIO, str]]: + """Async Context manager used to get a file like object to write into, as described by file_info. - Actually yields a 3-tuple (file, fname, finish_cb), where file is a file - like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns an awaitable. + Actually yields a 2-tuple (file, fname,), where file is a file + like object that can be written to and fname is the absolute path of file + on disk. fname can be used to read the contents from after upload, e.g. to generate thumbnails. - finish_cb must be called and waited on after the file has been successfully been - written to. Should not be called if there was an error. Checks for spam and - stores the file into the configured storage providers. - Args: file_info: Info about the file to store Example: - with media_storage.store_into_file(info) as (f, fname, finish_cb): + async with media_storage.store_into_file(info) as (f, fname,): # .. write into f ... - await finish_cb() """ path = self._file_info_to_path(file_info) @@ -145,62 +137,42 @@ def store_into_file( dirname = os.path.dirname(fname) os.makedirs(dirname, exist_ok=True) - finished_called = [False] - main_media_repo_write_trace_scope = start_active_span( "writing to main media repo" ) main_media_repo_write_trace_scope.__enter__() - try: - with open(fname, "wb") as f: - - async def finish() -> None: - # When someone calls finish, we assume they are done writing to the main media repo - main_media_repo_write_trace_scope.__exit__(None, None, None) - - with start_active_span("writing to other storage providers"): - # Ensure that all writes have been flushed and close the - # file. - f.flush() - f.close() - - spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info - ) - if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: - logger.info("Blocking media due to spam checker") - # Note that we'll delete the stored media, due to the - # try/except below. The media also won't be stored in - # the DB. - # We currently ignore any additional field returned by - # the spam-check API. - raise SpamMediaException(errcode=spam_check[0]) - - for provider in self.storage_providers: - with start_active_span(str(provider)): - await provider.store_file(path, file_info) - - finished_called[0] = True - - yield f, fname, finish - except Exception as e: + with main_media_repo_write_trace_scope: try: - main_media_repo_write_trace_scope.__exit__( - type(e), None, e.__traceback__ - ) - os.remove(fname) - except Exception: - pass + with open(fname, "wb") as f: + yield f, fname - raise e from None + except Exception as e: + try: + os.remove(fname) + except Exception: + pass - if not finished_called: - exc = Exception("Finished callback not called") - main_media_repo_write_trace_scope.__exit__( - type(exc), None, exc.__traceback__ + raise e from None + + with start_active_span("writing to other storage providers"): + spam_check = ( + await self._spam_checker_module_callbacks.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) ) - raise exc + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + # We currently ignore any additional field returned by + # the spam-check API. + raise SpamMediaException(errcode=spam_check[0]) + + for provider in self.storage_providers: + with start_active_span(str(provider)): + await provider.store_file(path, file_info) async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 3897823b35e..2e65a047897 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -592,7 +592,7 @@ async def _handle_url( file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) - with self.media_storage.store_into_file(file_info) as (f, fname, finish): + async with self.media_storage.store_into_file(file_info) as (f, fname): if url.startswith("data:"): if not allow_data_urls: raise SynapseError( @@ -603,8 +603,6 @@ async def _handle_url( else: download_result = await self._download_url(url, f) - await finish() - try: time_now_ms = self.clock.time_msec() diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 600cbf89630..be4a289ec1b 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -93,13 +93,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # from a regular 404. file_id = "abcdefg12345" file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) - with hs.get_media_repository().media_storage.store_into_file(file_info) as ( - f, - fname, - finish, - ): - f.write(SMALL_PNG) - self.get_success(finish()) + + media_storage = hs.get_media_repository().media_storage + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) self.get_success( self.store.store_cached_remote_media( diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py index 88988f3a223..72205c6bb3b 100644 --- a/tests/rest/media/test_domain_blocking.py +++ b/tests/rest/media/test_domain_blocking.py @@ -44,13 +44,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # from a regular 404. file_id = "abcdefg12345" file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) - with hs.get_media_repository().media_storage.store_into_file(file_info) as ( - f, - fname, - finish, - ): - f.write(SMALL_PNG) - self.get_success(finish()) + + media_storage = hs.get_media_repository().media_storage + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) self.get_success( self.store.store_cached_remote_media(