From 72ae798f518ab1f104cbbcbbdf69361b034a5651 Mon Sep 17 00:00:00 2001 From: BJ Hargrave Date: Tue, 21 Jan 2025 17:12:12 -0500 Subject: [PATCH] Retry for replicate completion response of status=processing We use the DEFAULT_REPLICATE_ constants for retry count and initial delay. If the completion response returns status=processing, we loop to retry. Fixes https://github.com/BerriAI/litellm/issues/7900 Signed-off-by: BJ Hargrave --- litellm/llms/replicate/chat/handler.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/litellm/llms/replicate/chat/handler.py b/litellm/llms/replicate/chat/handler.py index 31d55729b754..e7d0d383e2f0 100644 --- a/litellm/llms/replicate/chat/handler.py +++ b/litellm/llms/replicate/chat/handler.py @@ -196,11 +196,16 @@ def completion( ) return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore else: - for _ in range(litellm.DEFAULT_MAX_RETRIES): + for retry in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES): time.sleep( - 1 - ) # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing" + litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS + 2 * retry + ) # wait to allow response to be generated by replicate - else partial output is generated with status=="processing" response = httpx_client.get(url=prediction_url, headers=headers) + if ( + response.status_code == 200 + and response.json().get("status") == "processing" + ): + continue return litellm.ReplicateConfig().transform_response( model=model, raw_response=response, @@ -259,11 +264,16 @@ async def async_completion( ) return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore - for _ in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES): + for retry in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES): await asyncio.sleep( - 
litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS - ) # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing" + litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS + 2 * retry + ) # wait to allow response to be generated by replicate - else partial output is generated with status=="processing" response = await async_handler.get(url=prediction_url, headers=headers) + if ( + response.status_code == 200 + and response.json().get("status") == "processing" + ): + continue return litellm.ReplicateConfig().transform_response( model=model, raw_response=response,