From f2f33c6dca1e5574fe902ccb527d1fc725f8fcfd Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Mon, 22 Apr 2024 18:50:53 +0200
Subject: [PATCH] fixed retries (#173)

---
 .../vertexai/langchain_google_vertexai/chat_models.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index c1f929fd..32dd3f74 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -107,6 +107,7 @@
     _format_to_vertex_tool,
     _format_functions_to_vertex_tool_dict,
 )
+from google.api_core.exceptions import GoogleAPIError
 
 logger = logging.getLogger(__name__)
 
@@ -442,14 +443,14 @@ def _completion_with_retry_inner(generation_method: Callable, **kwargs: Any) ->
             chunks = list(response)
             for chunk in chunks:
                 if not chunk.candidates:
-                    raise ValueError("Got 0 candidates from generations.")
+                    raise GoogleAPIError("Got 0 candidates from generations.")
             return iter(chunks)
         if kwargs.get("stream"):
             return response
         if len(response.candidates):
             return response
         else:
-            raise ValueError("Got 0 candidates from generations.")
+            raise GoogleAPIError("Got 0 candidates from generations.")
 
     return _completion_with_retry_inner(generation_method, **kwargs)
 
@@ -476,14 +477,14 @@ async def _completion_with_retry_inner(
             chunks = list(response)
             for chunk in chunks:
                 if not chunk.candidates:
-                    raise ValueError("Got 0 candidates from generations.")
+                    raise GoogleAPIError("Got 0 candidates from generations.")
             return iter(chunks)
         if kwargs.get("stream"):
             return response
         if len(response.candidates):
             return response
         else:
-            raise ValueError("Got 0 candidates from generations.")
+            raise GoogleAPIError("Got 0 candidates from generations.")
 
     return await _completion_with_retry_inner(generation_method, **kwargs)
 
@@ -737,6 +738,7 @@ def _stream(
             client.generate_content,
             max_retries=self.max_retries,
             contents=contents,
+            stream=True,
             check_stream_response_for_candidates=self.check_stream_response_for_candidates,
             **params,
         )
@@ -790,6 +792,7 @@ async def _astream(
             client.generate_content_async,
             max_retries=self.max_retries,
             contents=contents,
+            stream=True,
             check_stream_response_for_candidates=self.check_stream_response_for_candidates,
             **params,
         ):
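
Note on the change: a retry wrapper only re-invokes the call for the exception types it is configured with, so raising ValueError on a zero-candidate response let the error escape immediately instead of being retried; switching to google.api_core.exceptions.GoogleAPIError presumably places the empty-candidate case among the retried error types. The added stream=True arguments simply pass streaming through to the underlying generate_content / generate_content_async calls in _stream and _astream. The snippet below is a minimal, self-contained sketch of the retry behaviour, not the library's actual retry setup: the tenacity configuration, _FakeClient, and generate_with_retry are illustrative names; only GoogleAPIError and tenacity's public API are real.

# Minimal sketch: only exceptions matched by the retry predicate are retried.
# Assumes `tenacity` and `google-api-core` are installed.
from google.api_core.exceptions import GoogleAPIError
from tenacity import retry, retry_if_exception_type, stop_after_attempt


class _FakeClient:
    """Hypothetical stand-in for a generation client: returns no candidates
    on the first call, then one candidate on the second."""

    def __init__(self) -> None:
        self.calls = 0

    def generate(self) -> list:
        self.calls += 1
        return [] if self.calls == 1 else ["candidate-1"]


client = _FakeClient()


@retry(
    retry=retry_if_exception_type(GoogleAPIError),  # only GoogleAPIError triggers a retry
    stop=stop_after_attempt(3),
    reraise=True,
)
def generate_with_retry() -> str:
    candidates = client.generate()
    if not candidates:
        # GoogleAPIError is in the retried error types; a ValueError raised here
        # would propagate to the caller without any retry.
        raise GoogleAPIError("Got 0 candidates from generations.")
    return candidates[0]


print(generate_with_retry())  # first attempt raises, second attempt prints "candidate-1"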