Skip to content

Commit

Permalink
Allow user to create Part from function_call
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Haley committed Jan 12, 2024
1 parent 1fbf049 commit 1215adb
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,50 @@ def test_chat_function_calling(self):
),
)
assert response2.text == "The weather in Boston is super nice!"

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="stream_generate_content",
new=mock_stream_generate_content,
)
def test_generate_with_function_calling(self):
get_current_weather_func = generative_models.FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters=_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT,
)
weather_tool = generative_models.Tool(
function_declarations=[get_current_weather_func],
)

model = generative_models.GenerativeModel(
"gemini-pro",
# Specifying the tools once to avoid specifying them in every request
tools=[weather_tool],
)

messages = [
generative_models.Content(
role="user",
parts=[generative_models.Part.from_text("What is the weather like in Boston?")]
),
generative_models.Content(
role="model",
parts=[generative_models.Part.from_function_call(
function_name="get_current_weather",
arguments={
"location": "Boston"
}
)]
),
generative_models.Part.from_function_response(
name="get_current_weather",
response={
"content": {"weather_there": "super nice"},
},
)
]

response = model.generate_content(contents=messages)

assert response.text == "The weather in Boston is super nice!"
9 changes: 9 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,15 @@ def from_function_response(name: str, response: Dict[str, Any]) -> "Part":
)
)

@staticmethod
def from_function_call(function_name: str, arguments: Dict):
return Part._from_gapic(
raw_part=gapic_tool_types.FunctionCall(
name=function_name,
argumnts=arguments
)
)

def to_dict(self) -> Dict[str, Any]:
return self._raw_part.to_dict()

Expand Down

0 comments on commit 1215adb

Please sign in to comment.