diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index 29e717d065..d25d6318f5 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -259,3 +259,50 @@ def test_chat_function_calling(self): ), ) assert response2.text == "The weather in Boston is super nice!" + + @mock.patch.object( + target=prediction_service.PredictionServiceClient, + attribute="stream_generate_content", + new=mock_stream_generate_content, + ) + def test_generate_with_function_calling(self): + get_current_weather_func = generative_models.FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters=_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT, + ) + weather_tool = generative_models.Tool( + function_declarations=[get_current_weather_func], + ) + + model = generative_models.GenerativeModel( + "gemini-pro", + # Specifying the tools once to avoid specifying them in every request + tools=[weather_tool], + ) + + messages = [ + generative_models.Content( + role="user", + parts=[generative_models.Part.from_text("What is the weather like in Boston?")] + ), + generative_models.Content( + role="model", + parts=[generative_models.Part.from_function_call( + function_name="get_current_weather", + arguments={ + "location": "Boston" + } + )] + ), + generative_models.Part.from_function_response( + name="get_current_weather", + response={ + "content": {"weather_there": "super nice"}, + }, + ) + ] + + response = model.generate_content(contents=messages) + + assert response.text == "The weather in Boston is super nice!" \ No newline at end of file diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index 4f6dddf2ca..7e1aa7f285 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -1502,6 +1502,15 @@ def from_function_response(name: str, response: Dict[str, Any]) -> "Part": ) ) + @staticmethod + def from_function_call(function_name: str, arguments: Dict): + return Part._from_gapic( + raw_part=gapic_tool_types.FunctionCall( + name=function_name, + argumnts=arguments + ) + ) + def to_dict(self) -> Dict[str, Any]: return self._raw_part.to_dict()