Skip to content

Commit

Permalink
fix: LLM - Fixed the chat models failing due to safetyAttributes format
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544512632
  • Loading branch information
Ark-kun authored and copybara-github committed Jun 30, 2023
1 parent 970970e commit 459ba86
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
24 changes: 14 additions & 10 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,13 @@
}

_TEST_CHAT_GENERATION_PREDICTION1 = {
"safetyAttributes": {
"scores": [],
"blocked": False,
"categories": [],
},
"safetyAttributes": [
{
"scores": [],
"blocked": False,
"categories": [],
}
],
"candidates": [
{
"author": "1",
Expand All @@ -177,11 +179,13 @@
],
}
_TEST_CHAT_GENERATION_PREDICTION2 = {
"safetyAttributes": {
"scores": [],
"blocked": False,
"categories": [],
},
"safetyAttributes": [
{
"scores": [],
"blocked": False,
"categories": [],
}
],
"candidates": [
{
"author": "1",
Expand Down
3 changes: 2 additions & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,8 @@ def send_message(
)

prediction = prediction_response.predictions[0]
safety_attributes = prediction["safetyAttributes"]
# ! Note: For chat models, the safetyAttributes is a list.
safety_attributes = prediction["safetyAttributes"][0]
response_obj = TextGenerationResponse(
text=prediction["candidates"][0]["content"]
if prediction.get("candidates")
Expand Down

0 comments on commit 459ba86

Please sign in to comment.