From 6e6b257318061ff15a017dcae9646a3287295824 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Fri, 8 Dec 2023 19:02:33 +0000 Subject: [PATCH] fix: avoid leaking memory when Client.with_options is used (#316) Fixes https://github.com/openai/openai-python/issues/865. --- pyproject.toml | 2 - src/modern_treasury/_base_client.py | 28 ++++--- src/modern_treasury/_client.py | 4 +- tests/test_client.py | 124 ++++++++++++++++++++++++++++ 4 files changed, 141 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 96b3e4ac..6a19633e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,8 +149,6 @@ select = [ "T203", ] ignore = [ - # lru_cache in methods, will be fixed separately - "B019", # mutable defaults "B006", ] diff --git a/src/modern_treasury/_base_client.py b/src/modern_treasury/_base_client.py index 2e5678e8..bbbb8a54 100644 --- a/src/modern_treasury/_base_client.py +++ b/src/modern_treasury/_base_client.py @@ -403,14 +403,12 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers: headers_dict = _merge_mappings(self.default_headers, custom_headers) self._validate_headers(headers_dict, custom_headers) + # headers are case-insensitive while dictionaries are not. headers = httpx.Headers(headers_dict) idempotency_header = self._idempotency_header if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: - if not options.idempotency_key: - options.idempotency_key = self._idempotency_key() - - headers[idempotency_header] = options.idempotency_key + headers[idempotency_header] = options.idempotency_key or self._idempotency_key() return headers @@ -594,16 +592,8 @@ def base_url(self) -> URL: def base_url(self, url: URL | str) -> None: self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url)) - @lru_cache(maxsize=None) def platform_headers(self) -> Dict[str, str]: - return { - "X-Stainless-Lang": "python", - "X-Stainless-Package-Version": self._version, - "X-Stainless-OS": str(get_platform()), - "X-Stainless-Arch": str(get_architecture()), - "X-Stainless-Runtime": platform.python_implementation(), - "X-Stainless-Runtime-Version": platform.python_version(), - } + return platform_headers(self._version) def _calculate_retry_timeout( self, @@ -1691,6 +1681,18 @@ def get_platform() -> Platform: return "Unknown" +@lru_cache(maxsize=None) +def platform_headers(version: str) -> Dict[str, str]: + return { + "X-Stainless-Lang": "python", + "X-Stainless-Package-Version": version, + "X-Stainless-OS": str(get_platform()), + "X-Stainless-Arch": str(get_architecture()), + "X-Stainless-Runtime": platform.python_implementation(), + "X-Stainless-Runtime-Version": platform.python_version(), + } + + class OtherArch: def __init__(self, name: str) -> None: self.name = name diff --git a/src/modern_treasury/_client.py b/src/modern_treasury/_client.py index ae160982..4eac303f 100644 --- a/src/modern_treasury/_client.py +++ b/src/modern_treasury/_client.py @@ -291,7 +291,7 @@ def copy( api_key=api_key or self.api_key, organization_id=organization_id or self.organization_id, webhook_key=webhook_key or self.webhook_key, - base_url=base_url or str(self.base_url), + base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, connection_pool_limits=connection_pool_limits, @@ -609,7 +609,7 @@ def copy( api_key=api_key or self.api_key, organization_id=organization_id or self.organization_id, webhook_key=webhook_key or self.webhook_key, - base_url=base_url or str(self.base_url), + base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, connection_pool_limits=connection_pool_limits, diff --git a/tests/test_client.py b/tests/test_client.py index c3cc7b51..a6c126ae 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,10 +2,12 @@ from __future__ import annotations +import gc import os import json import asyncio import inspect +import tracemalloc from typing import Any, Union, cast from unittest import mock @@ -213,6 +215,67 @@ def test_copy_signature(self) -> None: copy_param = copy_signature.parameters.get(name) assert copy_param is not None, f"copy() signature is missing the {name} param" + def test_copy_build_request(self) -> None: + options = FinalRequestOptions(method="get", url="/foo") + + def build_request(options: FinalRequestOptions) -> None: + client = self.client.copy() + client._build_request(options) + + # ensure that the machinery is warmed up before tracing starts. + build_request(options) + gc.collect() + + tracemalloc.start(1000) + + snapshot_before = tracemalloc.take_snapshot() + + ITERATIONS = 10 + for _ in range(ITERATIONS): + build_request(options) + gc.collect() + + snapshot_after = tracemalloc.take_snapshot() + + tracemalloc.stop() + + def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None: + if diff.count == 0: + # Avoid false positives by considering only leaks (i.e. allocations that persist). + return + + if diff.count % ITERATIONS != 0: + # Avoid false positives by considering only leaks that appear per iteration. + return + + for frame in diff.traceback: + if any( + frame.filename.endswith(fragment) + for fragment in [ + # to_raw_response_wrapper leaks through the @functools.wraps() decorator. + # + # removing the decorator fixes the leak for reasons we don't understand. + "modern_treasury/_response.py", + # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. + "modern_treasury/_compat.py", + # Standard library leaks we don't care about. + "/logging/__init__.py", + ] + ): + return + + leaks.append(diff) + + leaks: list[tracemalloc.StatisticDiff] = [] + for diff in snapshot_after.compare_to(snapshot_before, "traceback"): + add_leak(leaks, diff) + if leaks: + for leak in leaks: + print("MEMORY LEAK:", leak) + for frame in leak.traceback: + print(frame) + raise AssertionError() + def test_request_timeout(self) -> None: request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore @@ -1061,6 +1124,67 @@ def test_copy_signature(self) -> None: copy_param = copy_signature.parameters.get(name) assert copy_param is not None, f"copy() signature is missing the {name} param" + def test_copy_build_request(self) -> None: + options = FinalRequestOptions(method="get", url="/foo") + + def build_request(options: FinalRequestOptions) -> None: + client = self.client.copy() + client._build_request(options) + + # ensure that the machinery is warmed up before tracing starts. + build_request(options) + gc.collect() + + tracemalloc.start(1000) + + snapshot_before = tracemalloc.take_snapshot() + + ITERATIONS = 10 + for _ in range(ITERATIONS): + build_request(options) + gc.collect() + + snapshot_after = tracemalloc.take_snapshot() + + tracemalloc.stop() + + def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None: + if diff.count == 0: + # Avoid false positives by considering only leaks (i.e. allocations that persist). + return + + if diff.count % ITERATIONS != 0: + # Avoid false positives by considering only leaks that appear per iteration. + return + + for frame in diff.traceback: + if any( + frame.filename.endswith(fragment) + for fragment in [ + # to_raw_response_wrapper leaks through the @functools.wraps() decorator. + # + # removing the decorator fixes the leak for reasons we don't understand. + "modern_treasury/_response.py", + # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. + "modern_treasury/_compat.py", + # Standard library leaks we don't care about. + "/logging/__init__.py", + ] + ): + return + + leaks.append(diff) + + leaks: list[tracemalloc.StatisticDiff] = [] + for diff in snapshot_after.compare_to(snapshot_before, "traceback"): + add_leak(leaks, diff) + if leaks: + for leak in leaks: + print("MEMORY LEAK:", leak) + for frame in leak.traceback: + print(frame) + raise AssertionError() + async def test_request_timeout(self) -> None: request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore