From 459ba86396ab9260fd7b28a1524c051b7ad300a5 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 29 Jun 2023 18:00:10 -0700 Subject: [PATCH] fix: LLM - Fixed the chat models failing due to safetyAttributes format PiperOrigin-RevId: 544512632 --- tests/unit/aiplatform/test_language_models.py | 24 +++++++++++-------- vertexai/language_models/_language_models.py | 3 ++- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 45c94d6203..f7bf9a9df6 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -164,11 +164,13 @@ } _TEST_CHAT_GENERATION_PREDICTION1 = { - "safetyAttributes": { - "scores": [], - "blocked": False, - "categories": [], - }, + "safetyAttributes": [ + { + "scores": [], + "blocked": False, + "categories": [], + } + ], "candidates": [ { "author": "1", @@ -177,11 +179,13 @@ ], } _TEST_CHAT_GENERATION_PREDICTION2 = { - "safetyAttributes": { - "scores": [], - "blocked": False, - "categories": [], - }, + "safetyAttributes": [ + { + "scores": [], + "blocked": False, + "categories": [], + } + ], "candidates": [ { "author": "1", diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 660fa1c3c4..e6dd7b63b5 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -799,7 +799,8 @@ def send_message( ) prediction = prediction_response.predictions[0] - safety_attributes = prediction["safetyAttributes"] + # ! Note: For chat models, the safetyAttributes is a list. + safety_attributes = prediction["safetyAttributes"][0] response_obj = TextGenerationResponse( text=prediction["candidates"][0]["content"] if prediction.get("candidates")