From 2fba5549da1d8931e4a92b77f5a6b10a7dac59f2 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 9 Jan 2025 17:36:51 +0100 Subject: [PATCH] Remove Pydantic support --- haystack/tools/component_tool.py | 36 +-------------- test/tools/test_component_tool.py | 75 ------------------------------- 2 files changed, 2 insertions(+), 109 deletions(-) diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py index 05ced1c6b6..2c6d43b038 100644 --- a/haystack/tools/component_tool.py +++ b/haystack/tools/component_tool.py @@ -26,16 +26,6 @@ logger = logging.getLogger(__name__) -def is_pydantic_v2_model(instance: Any) -> bool: - """ - Checks if the instance is a Pydantic v2 model. - - :param instance: The instance to check. - :returns: True if the instance is a Pydantic v2 model, False otherwise. - """ - return hasattr(instance, "model_validate") - - class ComponentTool(Tool): """ A Tool that wraps Haystack components, allowing them to be used as tools by LLMs. @@ -282,28 +272,6 @@ def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[s schema["properties"][field.name] = self._create_property_schema(field.type, field_description) return schema - def _create_pydantic_schema(self, python_type: Any, description: str) -> Dict[str, Any]: - """ - Creates a schema for a Pydantic model. - - :param python_type: The Pydantic model type. - :param description: The description of the model. - :returns: A dictionary representing the Pydantic model schema. - """ - schema = {"type": "object", "description": description, "properties": {}} - required_fields = [] - - for m_name, m_field in python_type.model_fields.items(): - field_description = f"Field '{m_name}' of '{python_type.__name__}'." - if isinstance(schema["properties"], dict): - schema["properties"][m_name] = self._create_property_schema(m_field.annotation, field_description) - if m_field.is_required(): - required_fields.append(m_name) - - if required_fields: - schema["required"] = required_fields - return schema - def _create_basic_type_schema(self, python_type: Any, description: str) -> Dict[str, Any]: """ Creates a schema for a basic Python type. @@ -334,8 +302,8 @@ def _create_property_schema(self, python_type: Any, description: str, default: A schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description) elif is_dataclass(python_type): schema = self._create_dataclass_schema(python_type, description) - elif is_pydantic_v2_model(python_type): - schema = self._create_pydantic_schema(python_type, description) + elif hasattr(python_type, "model_validate"): + raise ValueError("Pydantic v2 models are not supported as input types for ComponentTool") else: schema = self._create_basic_type_schema(python_type, description) diff --git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py index ac0ee6e793..30f5751730 100644 --- a/test/tools/test_component_tool.py +++ b/test/tools/test_component_tool.py @@ -37,13 +37,6 @@ def run(self, text: str) -> Dict[str, str]: return {"reply": f"Hello, {text}!"} -class Product(BaseModel): - """A product model.""" - - name: str - price: float - - @dataclass class User: """A simple user dataclass.""" @@ -82,21 +75,6 @@ def run(self, texts: List[str]) -> Dict[str, str]: return {"concatenated": " ".join(texts)} -@component -class ProductProcessor: - """A component that processes a Product.""" - - @component.output_types(description=str) - def run(self, product: Product) -> Dict[str, str]: - """ - Creates a description for the product. - - :param product: The Product to process. - :return: A dictionary with the product description. - """ - return {"description": f"The product {product.name} costs ${product.price:.2f}."} - - @dataclass class Address: """A dataclass representing a physical address.""" @@ -218,33 +196,6 @@ def test_from_component_with_list_input(self): assert "concatenated" in result assert result["concatenated"] == "hello world" - def test_from_component_with_pydantic_model(self): - component = ProductProcessor() - - tool = ComponentTool(component=component, name="product_tool", description="A tool that processes products") - - assert tool.parameters == { - "type": "object", - "properties": { - "product": { - "type": "object", - "description": "The Product to process.", - "properties": { - "name": {"type": "string", "description": "Field 'name' of 'Product'."}, - "price": {"type": "number", "description": "Field 'price' of 'Product'."}, - }, - "required": ["name", "price"], - } - }, - "required": ["product"], - } - - # Test tool invocation - result = tool.invoke(product={"name": "Widget", "price": 19.99}) - assert isinstance(result, dict) - assert "description" in result - assert result["description"] == "The product Widget costs $19.99." - def test_from_component_with_nested_dataclass(self): component = PersonProcessor() @@ -438,32 +389,6 @@ def test_list_processor_in_pipeline(self): assert tool_message.tool_call_result.result == str({"concatenated": "hello beautiful world"}) assert not tool_message.tool_call_result.error - @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") - @pytest.mark.integration - def test_product_processor_in_pipeline(self): - component = ProductProcessor() - tool = ComponentTool( - component=component, - name="product_processor", - description="A tool that creates a description for a product with its name and price", - ) - - pipeline = Pipeline() - pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) - pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) - pipeline.connect("llm.replies", "tool_invoker.messages") - - message = ChatMessage.from_user(text="Can you describe a product called Widget that costs $19.99?") - - result = pipeline.run({"llm": {"messages": [message]}}) - tool_messages = result["tool_invoker"]["tool_messages"] - assert len(tool_messages) == 1 - - tool_message = tool_messages[0] - assert tool_message.is_from(ChatRole.TOOL) - assert tool_message.tool_call_result.result == str({"description": "The product Widget costs $19.99."}) - assert not tool_message.tool_call_result.error - @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration def test_person_processor_in_pipeline(self):