From 271968f361ca8452c2d2b30404df9d21c45c7d40 Mon Sep 17 00:00:00 2001 From: Stainless Bot Date: Wed, 25 Sep 2024 12:45:37 +0000 Subject: [PATCH] feat(client): allow overriding retry count header (#1745) --- src/openai/_base_client.py | 6 +- tests/test_client.py | 150 +++++++++++++++++++++++++++++++++---- 2 files changed, 138 insertions(+), 18 deletions(-) diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 77e82026ef..c4c9803e74 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -413,8 +413,10 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0 if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: headers[idempotency_header] = options.idempotency_key or self._idempotency_key() - if retries_taken > 0: - headers.setdefault("x-stainless-retry-count", str(retries_taken)) + # Don't set the retry count header if it was already set or removed by the caller. We check + # `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case. + if "x-stainless-retry-count" not in (header.lower() for header in custom_headers): + headers["x-stainless-retry-count"] = str(retries_taken) return headers diff --git a/tests/test_client.py b/tests/test_client.py index 567a6ec59f..463174465c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -788,10 +788,71 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: ) assert response.retries_taken == failures_before_success - if failures_before_success == 0: - assert "x-stainless-retry-count" not in response.http_request.headers - else: - assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_omit_retry_count_header( + self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/chat/completions").mock(side_effect=retry_handler) + + response = client.chat.completions.with_raw_response.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="gpt-4o", + extra_headers={"x-stainless-retry-count": Omit()}, + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_overwrite_retry_count_header( + self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/chat/completions").mock(side_effect=retry_handler) + + response = client.chat.completions.with_raw_response.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="gpt-4o", + extra_headers={"x-stainless-retry-count": "42"}, + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -822,10 +883,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: model="gpt-4o", ) as response: assert response.retries_taken == failures_before_success - if failures_before_success == 0: - assert "x-stainless-retry-count" not in response.http_request.headers - else: - assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success class TestAsyncOpenAI: @@ -1590,10 +1648,73 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: ) assert response.retries_taken == failures_before_success - if failures_before_success == 0: - assert "x-stainless-retry-count" not in response.http_request.headers - else: - assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_omit_retry_count_header( + self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/chat/completions").mock(side_effect=retry_handler) + + response = await client.chat.completions.with_raw_response.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="gpt-4o", + extra_headers={"x-stainless-retry-count": Omit()}, + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_overwrite_retry_count_header( + self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/chat/completions").mock(side_effect=retry_handler) + + response = await client.chat.completions.with_raw_response.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="gpt-4o", + extra_headers={"x-stainless-retry-count": "42"}, + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -1625,7 +1746,4 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: model="gpt-4o", ) as response: assert response.retries_taken == failures_before_success - if failures_before_success == 0: - assert "x-stainless-retry-count" not in response.http_request.headers - else: - assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success