diff --git a/.github/workflows/poetry-test.yml b/.github/workflows/poetry-test.yml
index 84e2a770..4709c51c 100644
--- a/.github/workflows/poetry-test.yml
+++ b/.github/workflows/poetry-test.yml
@@ -22,7 +22,7 @@ jobs:
uses: snok/install-poetry@v1
- name: Install python dependencies
- run: poetry install
+ run: poetry install --with gdino
- name: Install ROS 2 dependencies
shell: bash
diff --git a/docs/multimodal_messages.md b/docs/multimodal_messages.md
index f7d14f2b..d80f0ae0 100644
--- a/docs/multimodal_messages.md
+++ b/docs/multimodal_messages.md
@@ -17,7 +17,7 @@ class ToolMultimodalMessage(ToolMessage, MultimodalMessage):
Example:
```python
-from rai.scenario_engine.messages import HumanMultimodalMessage, preprocess_image
+from rai.messages import HumanMultimodalMessage, preprocess_image
from langchain_openai.chat_models import ChatOpenAI
base64_image = preprocess_image('https://raw.githubusercontent.com/RobotecAI/RobotecGPULidar/develop/docs/image/rgl-logo.png')
diff --git a/poetry.lock b/poetry.lock
index 13fd19b1..f50d168e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -760,6 +760,23 @@ files = [
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
]
+[[package]]
+name = "deprecated"
+version = "1.2.14"
+description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
+ {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"},
+]
+
+[package.dependencies]
+wrapt = ">=1.10,<2"
+
+[package.extras]
+dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
+
[[package]]
name = "distlib"
version = "0.3.8"
@@ -2791,9 +2808,9 @@ files = [
[package.dependencies]
numpy = [
+ {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
- {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
@@ -2815,9 +2832,9 @@ files = [
[package.dependencies]
numpy = [
+ {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
- {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
@@ -2965,8 +2982,8 @@ files = [
[package.dependencies]
numpy = [
- {version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
+ {version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
python-dateutil = ">=2.8.2"
@@ -5629,4 +5646,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = "^3.10, <3.13"
-content-hash = "82ef220355b2cae5f09374ed45059868a7bbd25ecafed64c6ca028c3f18b16d7"
+content-hash = "d1bc3be63a03d89358a0b71085cc08bbd303f04a154e5ce4588e35269b9d82fc"
diff --git a/pyproject.toml b/pyproject.toml
index fea22084..ce932693 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ pypdf = "^4.2.0"
langchain-ollama = "^0.1.1"
streamlit = "^1.37.1"
+deprecated = "^1.2.14"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.4"
diff --git a/src/rai/package.xml b/src/rai/package.xml
index 4ae99b43..51a7f1be 100644
--- a/src/rai/package.xml
+++ b/src/rai/package.xml
@@ -13,6 +13,10 @@
ament_pep257
python3-pytest
+ nav2_msgs
+ nav2_simple_commander
+ tf_transformations
+
ament_python
diff --git a/src/rai/rai/agents/state_based.py b/src/rai/rai/agents/state_based.py
index c2e3da1d..85528405 100644
--- a/src/rai/rai/agents/state_based.py
+++ b/src/rai/rai/agents/state_based.py
@@ -51,7 +51,7 @@
from langgraph.utils import RunnableCallable
from rclpy.impl.rcutils_logger import RcutilsLogger
-from rai.scenario_engine.messages import (
+from rai.messages import (
HumanMultimodalMessage,
MultimodalArtifact,
ToolMultimodalMessage,
diff --git a/src/rai/rai/documents/loader.py b/src/rai/rai/apps/document_loader.py
similarity index 100%
rename from src/rai/rai/documents/loader.py
rename to src/rai/rai/apps/document_loader.py
diff --git a/src/rai/rai/apps/talk_to_docs.py b/src/rai/rai/apps/talk_to_docs.py
index 78afae97..b784bf49 100644
--- a/src/rai/rai/apps/talk_to_docs.py
+++ b/src/rai/rai/apps/talk_to_docs.py
@@ -30,7 +30,7 @@
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import StateGraph
-from rai.documents.loader import ingest_documentation
+from rai.apps.document_loader import ingest_documentation
logging.basicConfig(level=logging.WARN)
diff --git a/src/rai/rai/cli/rai_cli.py b/src/rai/rai/cli/rai_cli.py
index f4c8ed61..4489446a 100644
--- a/src/rai/rai/cli/rai_cli.py
+++ b/src/rai/rai/cli/rai_cli.py
@@ -23,7 +23,8 @@
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from rai.apps.talk_to_docs import ingest_documentation
-from rai.scenario_engine.messages import HumanMultimodalMessage, preprocess_image
+from rai.messages import preprocess_image
+from rai.messages.multimodal import HumanMultimodalMessage
def parse_whoami_package():
diff --git a/src/rai/rai/communication/README.md b/src/rai/rai/communication/README.md
deleted file mode 100644
index 9ee57684..00000000
--- a/src/rai/rai/communication/README.md
+++ /dev/null
@@ -1,35 +0,0 @@
-# 📘 README for Communication Modules
-
-## 📨 Communication.py
-
-### Overview
-
-This module is designed to handle general communication tasks like sending emails. It's set up to use SMTP (Simple Mail Transfer Protocol) for email operations, allowing the application to send notifications, alerts, or any communication via email.
-
-### What to Implement
-
-- **Add More Communication Methods**: Beyond email, you might want to integrate other forms of communication like SMS, direct messaging services, or even automated phone calls.
-- **Security Enhancements**: Implement more robust security measures for handling credentials and securing communication channels.
-- **Error Handling**: Enhance error handling to manage network issues or authentication errors more gracefully.
-
-### Currently Implemented
-
-- **Email Sending**: Setup to send emails using SMTP with attachments. It includes basic error handling for SMTP authentication errors and checks for missing email credentials.
-
-## 🤖 ros_communication.py
-
-### Overview
-
-This file is specifically tailored for communication within ROS (Robot Operating System) environments. It deals with subscribing to ROS topics, waiting for messages, and processing those messages. It's crucial for robotic applications where real-time data handling and sensor integration are required.
-
-### What to Implement
-
-- **Expand Topic Handling**: Include subscriptions to more diverse topic types and handle different data formats coming from various sensors or inputs.
-- **Integration with More ROS Versions**: Ensure compatibility with different versions of ROS or other robotic middleware.
-- **Enhanced Data Processing**: Implement more complex data processing functions that can convert incoming data into more usable formats or derive more insights.
-
-### Currently Implemented
-
-- **Message Subscription and Retrieval**: Functions to wait for and retrieve messages from specified ROS topics.
-- **Image Data Handling**: Includes a specialized class for grabbing image data from a ROS topic, converting it from ROS image formats to standard encodings.
-- **Utilities for ROS Entities**: Functions to list available topics, nodes, and services in the ROS environment, used for usage agnostic dynamic LLM systems.
diff --git a/src/rai/rai/communication/communication.py b/src/rai/rai/communication/communication.py
deleted file mode 100644
index e6e2bea9..00000000
--- a/src/rai/rai/communication/communication.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import logging
-import os
-import smtplib
-from email.mime.image import MIMEImage
-from email.mime.multipart import MIMEMultipart
-from email.mime.text import MIMEText
-from typing import Optional
-
-
-class EmailSender:
- def __init__(
- self, smtp_server: str, smtp_port: int, logging_level: int = logging.INFO
- ):
- self.smtp_server = smtp_server
- self.smtp_port = smtp_port
- self.sender_email = os.environ.get("ROBOT_ALERT_EMAIL", None)
- self.sender_password = os.environ.get("ROBOT_ALERT_PASSWORD", None)
- self.logger = logging.getLogger(self.__class__.__name__)
- self.logger.setLevel(logging_level)
-
- if self.sender_email is None or self.sender_password is None:
- self.logger.error(
- "Email and password for the alert system are not set. Message will not be sent."
- )
-
- def send_email(
- self,
- recipient_email: str,
- subject: str,
- message: str,
- image_path: Optional[str] = None,
- ) -> None:
- # Create a multipart message
- msg = MIMEMultipart()
- msg["From"] = self.sender_email
- msg["To"] = recipient_email
- msg["Subject"] = subject
- # Attach the message as plain text
- msg.attach(MIMEText(message, "html"))
- # Attach the image if provided
- if image_path:
- with open(image_path, "rb") as f:
- image_data = f.read()
- image = MIMEImage(image_data, name="image.png")
- msg.attach(image)
-
- # Connect to the SMTP server and send the email
- if self.sender_email is None or self.sender_password is None:
- return
-
- with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
- server.starttls()
- try:
- server.login(self.sender_email, self.sender_password)
- except smtplib.SMTPAuthenticationError:
- self.logger.error("Failed to authenticate with the SMTP server.")
- return
- server.send_message(msg)
- self.logger.info(f"Email sent to {recipient_email}.")
diff --git a/src/rai/rai/documents/__init__.py b/src/rai/rai/extensions/__init__.py
similarity index 100%
rename from src/rai/rai/documents/__init__.py
rename to src/rai/rai/extensions/__init__.py
diff --git a/src/rai_hmi/rai_hmi/custom_mavigator.py b/src/rai/rai/extensions/navigator.py
similarity index 74%
rename from src/rai_hmi/rai_hmi/custom_mavigator.py
rename to src/rai/rai/extensions/navigator.py
index 6df7a0bd..bf2fdcb7 100644
--- a/src/rai_hmi/rai_hmi/custom_mavigator.py
+++ b/src/rai/rai/extensions/navigator.py
@@ -1,3 +1,18 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import rclpy
from builtin_interfaces.msg import Duration
from geometry_msgs.msg import Point
diff --git a/src/rai/rai/communication/__init__.py b/src/rai/rai/messages/__init__.py
similarity index 60%
rename from src/rai/rai/communication/__init__.py
rename to src/rai/rai/messages/__init__.py
index e88f3871..656ca0fe 100644
--- a/src/rai/rai/communication/__init__.py
+++ b/src/rai/rai/messages/__init__.py
@@ -13,7 +13,20 @@
# limitations under the License.
#
-from .communication import EmailSender
-from .ros_communication import SingleImageGrabber
+from .multimodal import (
+ AiMultimodalMessage,
+ HumanMultimodalMessage,
+ MultimodalArtifact,
+ SystemMultimodalMessage,
+ ToolMultimodalMessage,
+)
+from .utils import preprocess_image
-__all__ = ["EmailSender", "SingleImageGrabber"]
+__all__ = [
+ "HumanMultimodalMessage",
+ "AiMultimodalMessage",
+ "SystemMultimodalMessage",
+ "ToolMultimodalMessage",
+ "MultimodalArtifact",
+ "preprocess_image",
+]
diff --git a/src/rai/rai/scenario_engine/messages.py b/src/rai/rai/messages/multimodal.py
similarity index 84%
rename from src/rai/rai/scenario_engine/messages.py
rename to src/rai/rai/messages/multimodal.py
index fb7f8e26..e68b2411 100644
--- a/src/rai/rai/scenario_engine/messages.py
+++ b/src/rai/rai/messages/multimodal.py
@@ -13,11 +13,8 @@
# limitations under the License.
#
-import base64
-from typing import Any, Callable, Dict, List, Literal, Optional, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
-import numpy as np
-import requests
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages.base import BaseMessage, get_msg_title_repr
from langchain_core.tools import BaseTool
@@ -164,28 +161,3 @@ def __init__(
if not any([tool.__class__.__name__ == stop_tool for tool in tools]):
raise ValueError("Stop tool not in tools")
self.tools: List[BaseTool] = tools
-
-
-def preprocess_image(
- image: Union[str, bytes, np.ndarray[Any, np.dtype[np.uint8]]],
- encoding_function: Callable[[Any], str] = lambda x: base64.b64encode(x).decode(
- "utf-8"
- ),
-) -> str:
- if isinstance(image, str) and image.startswith(("http://", "https://")):
- response = requests.get(image)
- response.raise_for_status()
- image_data = response.content
- elif isinstance(image, str):
- with open(image, "rb") as image_file:
- image_data = image_file.read()
- elif isinstance(image, bytes):
- image_data = image
- encoding_function = lambda x: x.decode("utf-8")
- elif isinstance(image, np.ndarray): # type: ignore
- image_data = image.tobytes()
- encoding_function = lambda x: base64.b64encode(x).decode("utf-8")
- else:
- image_data = image
-
- return encoding_function(image_data)
diff --git a/src/rai/rai/messages/utils.py b/src/rai/rai/messages/utils.py
new file mode 100644
index 00000000..db6a74cc
--- /dev/null
+++ b/src/rai/rai/messages/utils.py
@@ -0,0 +1,45 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import base64
+from typing import Any, Callable, Union
+
+import numpy as np
+import requests
+
+
+def preprocess_image(  # Convert an image given as URL, file path, bytes or ndarray into a str via encoding_function.
+ image: Union[str, bytes, np.ndarray[Any, np.dtype[np.uint8]]],
+ encoding_function: Callable[[Any], str] = lambda x: base64.b64encode(x).decode(
+ "utf-8"
+ ),  # default encoder: base64-encode the data and return it as UTF-8 text
+) -> str:
+ if isinstance(image, str) and image.startswith(("http://", "https://")):  # http(s) URL: download the image
+ response = requests.get(image)  # NOTE(review): no timeout is passed to requests.get — confirm this is acceptable
+ response.raise_for_status()  # propagate HTTP error statuses as an exception instead of encoding an error body
+ image_data = response.content
+ elif isinstance(image, str):  # any other str is treated as a path and opened in binary mode
+ with open(image, "rb") as image_file:
+ image_data = image_file.read()
+ elif isinstance(image, bytes):
+ image_data = image
+ encoding_function = lambda x: x.decode("utf-8")  # replaces the caller's encoding_function; bytes are only UTF-8 decoded (presumably already base64 — TODO confirm)
+ elif isinstance(image, np.ndarray): # type: ignore
+ image_data = image.tobytes()  # raw array buffer; no image-format encoding is applied in this function
+ encoding_function = lambda x: base64.b64encode(x).decode("utf-8")  # also replaces any caller-supplied encoding_function
+ else:
+ image_data = image  # any other type is handed to encoding_function unchanged
+
+ return encoding_function(image_data)
diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py
index 71a32c1b..9bd68e7a 100644
--- a/src/rai/rai/node.py
+++ b/src/rai/rai/node.py
@@ -48,9 +48,9 @@
from std_srvs.srv import Trigger
from rai.agents.state_based import State
-from rai.communication.ros_communication import wait_for_message
-from rai.scenario_engine.messages import HumanMultimodalMessage
+from rai.messages.multimodal import HumanMultimodalMessage
from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str
+from rai.tools.utils import wait_for_message
class RosoutBuffer:
diff --git a/src/rai/rai/scenario_engine/README.md b/src/rai/rai/scenario_engine/README.md
deleted file mode 100644
index cb9311bf..00000000
--- a/src/rai/rai/scenario_engine/README.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# 🎬 ScenarioRunner Module
-
-## Overview
-
-The `ScenarioRunner` module is an essential part of our application. It enables running scenarios, which are series of messages and actions that simulate a conversation or a process. This is great for things like chatbots or automated systems that need to respond to user inputs in a consistent and logical way.
-
-## Key Components
-
-### 🏃 ScenarioRunner Class
-
-- **Purpose**: Manages the execution of different scenarios. Each scenario is made up of parts that can be messages, actions, or decisions.
-- **How it works**: Starts with a scenario, and runs through each part. It can send messages, execute actions, or make decisions based on certain conditions.
-- **Caching**: Can remember previous responses to save time and resources. This is useful when responses are predictable and don't need to be recalculated.
-
-## Saving and Logging
-
-- **Saving**: Can save the entire conversation or scenario outcome as HTML. This is useful for keeping records or reviewing how scenarios unfold.
-- **Logging**: Uses `coloredlogs` to make log messages easier to read. This is great for debugging and understanding the flow of scenarios.
-
-## How to Use This Module
-
-1. **Setup**: Make sure you have defined your scenarios using the available messages.
-2. **Execution**: Use the `run()` method of the `ScenarioRunner` to start executing your scenario.
-3. **Monitoring**: Keep an eye on the logs to understand how your scenario is processing and to catch any issues early.
-4. **Review**: After running scenarios, you can review the saved files to analyze the outcomes and make improvements.
-
-## Customizing Scenarios
-
-- **Expand scenarios**: Add more message types or actions to handle new types of interactions.
diff --git a/src/rai/rai/scenario_engine/__init__.py b/src/rai/rai/scenario_engine/__init__.py
deleted file mode 100644
index f138f42a..00000000
--- a/src/rai/rai/scenario_engine/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
diff --git a/src/rai/rai/scenario_engine/tool_runner.py b/src/rai/rai/scenario_engine/tool_runner.py
deleted file mode 100644
index 27028c63..00000000
--- a/src/rai/rai/scenario_engine/tool_runner.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import logging
-from typing import Any, Dict, List, Literal, Sequence
-
-from langchain.tools import BaseTool
-from langchain_core.messages import AIMessage, BaseMessage, ToolCall, ToolMessage
-
-from rai.scenario_engine.messages import ToolMultimodalMessage
-
-
-def images_to_vendor_format(images: List[str], vendor: str) -> List[Dict[str, Any]]:
- if vendor == "openai":
- return [
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{image}",
- },
- }
- for image in images
- ]
- else:
- raise ValueError(f"Vendor {vendor} not supported")
-
-
-def run_tool_call(
- tool_call: ToolCall,
- tools: Sequence[BaseTool],
-) -> Dict[str, Any] | Any:
- logger = logging.getLogger(__name__)
- selected_tool = {k.name: k for k in tools}[tool_call["name"]]
-
- try:
- if selected_tool.args_schema is not None:
- args = selected_tool.args_schema(**tool_call["args"]).dict()
- else:
- args = dict()
- except Exception as e:
- err_msg = f"Error in preparing arguments for {selected_tool.name}: {e}"
- logger.error(err_msg)
- return err_msg
-
- logger.info(f"Running tool: {selected_tool.name} with args: {args}")
-
- try:
- tool_output = selected_tool.run(args)
- except Exception as e:
- err_msg = f"Error in running tool {selected_tool.name}: {e}"
- logger.warning(err_msg)
- return err_msg
-
- logger.info(f"Successfully ran tool: {selected_tool.name}. Output: {tool_output}")
- return tool_output
-
-
-def run_requested_tools(
- ai_msg: AIMessage,
- tools: Sequence[BaseTool],
- messages: List[BaseMessage],
- llm_type: Literal["openai", "bedrock"],
-):
- internal_messages: List[BaseMessage] = []
- for tool_call in ai_msg.tool_calls:
- tool_output = run_tool_call(tool_call, tools)
- assert isinstance(tool_call["id"], str), "Tool output must have an id."
- if isinstance(tool_output, dict):
- tool_message = ToolMultimodalMessage(
- content=tool_output.get("content", "No response from the tool."),
- images=tool_output.get("images"),
- tool_call_id=tool_call["id"],
- )
- tool_message = tool_message.postprocess(format=llm_type)
- else:
- tool_message = [
- ToolMessage(content=str(tool_output), tool_call_id=tool_call["id"])
- ]
- if isinstance(tool_message, list):
- internal_messages.extend(tool_message)
- else:
- internal_messages.append(tool_message)
-
- # because we can't answer an aiMessage with an alternating sequence of tool and human messages
- # we sort the messages by type so that the tool messages are sent first
- # for more information see implementation of ToolMultimodalMessage.postprocess
-
- internal_messages.sort(key=lambda x: x.__class__.__name__, reverse=True)
- messages.extend(internal_messages)
- return messages
diff --git a/src/rai/rai/tools/__init__.py b/src/rai/rai/tools/__init__.py
index 9be117cc..f138f42a 100644
--- a/src/rai/rai/tools/__init__.py
+++ b/src/rai/rai/tools/__init__.py
@@ -12,14 +12,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
-from .hmi_tools import PlayVoiceMessageTool, SendEmailTool, WaitForSecondsTool
-from .planning_tools import AddTaskTool, GetNewTaskTool
-
-__all__ = [
- "PlayVoiceMessageTool",
- "SendEmailTool",
- "WaitForSecondsTool",
- "AddTaskTool",
- "GetNewTaskTool",
-]
diff --git a/src/rai/rai/tools/hmi_tools.py b/src/rai/rai/tools/hmi_tools.py
deleted file mode 100644
index 2eae5983..00000000
--- a/src/rai/rai/tools/hmi_tools.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import time
-from typing import Type
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from langchain_core.tools import BaseTool
-
-from rai.communication.communication import EmailSender
-
-
-class PlayVoiceMessageToolInput(BaseModel):
- """Input for the PlayVoiceMessageTool tool."""
-
- content: str = Field(..., description="The content of the voice message")
-
-
-class PlayVoiceMessageTool(BaseTool):
- """Output a voice message"""
-
- name: str = "PlayVoiceMessageTool"
- description: str = (
- "A tool for sending voice messages. "
- "Useful for sending audio content as messages. "
- "Input should be the content of the voice message."
- )
-
- args_schema: Type[PlayVoiceMessageToolInput] = PlayVoiceMessageToolInput
-
- def _run(self, content: str):
- raise NotImplementedError
-
-
-class WaitForSecondsToolInput(BaseModel):
- """Input for the WaitForSecondsTool tool."""
-
- seconds: int = Field(..., description="The number of seconds to wait")
-
-
-class WaitForSecondsTool(BaseTool):
- """Wait for a specified number of seconds"""
-
- name: str = "WaitForSecondsTool"
- description: str = (
- "A tool for waiting. "
- "Useful for pausing execution for a specified number of seconds. "
- "Input should be the number of seconds to wait."
- )
-
- args_schema: Type[WaitForSecondsToolInput] = WaitForSecondsToolInput
-
- def _run(self, seconds: int):
- """Waits for the specified number of seconds."""
- time.sleep(seconds)
- return f"Waited for {seconds} seconds."
-
-
-class SendEmailToolInput(BaseModel):
- """Input for the SendEmailToAdminTool tool."""
-
- recipient: str = Field(..., description="The email address of the recipient.")
- subject: str = Field(
- ..., description="The subject of the email. Should be very short."
- )
- content: str = Field(
- ..., description="The content of the email. Should be short and concise."
- )
-
-
-class SendEmailTool(BaseTool):
- """Send an email to the admin"""
-
- name: str = "SendEmailToAdminTool"
- description: str = (
- "A tool for sending emails to the admin. "
- "Useful for sending notifications to the admin. "
- "Input should be the subject and content of the email."
- )
-
- args_schema: Type[SendEmailToolInput] = SendEmailToolInput
-
- def _run(self, recipient: str, subject: str, content: str):
- """Sends an email to the admin."""
- email_sender = EmailSender(smtp_server="", smtp_port=587)
- email_sender.send_email(
- recipient_email=recipient, subject=subject, message=content
- )
- return "Email sent to admin."
diff --git a/src/rai/rai/tools/planning_tools.py b/src/rai/rai/tools/planning_tools.py
deleted file mode 100644
index 38df7807..00000000
--- a/src/rai/rai/tools/planning_tools.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from queue import Queue
-from typing import Type
-
-from langchain.pydantic_v1 import BaseModel, Field
-from langchain_core.tools import BaseTool
-
-
-class AddTaskToolInput(BaseModel):
- task: str = Field(..., title="Task to be added into the task list.")
-
-
-class AddTaskTool(BaseTool):
- name: str = "AddTaskTool"
- description: str = "Add a task to the task list for later execution."
- args_schema: Type[AddTaskToolInput] = AddTaskToolInput
-
- queue: Queue[str]
-
- def _run(self, task: str):
- self.queue.put(task)
- return "Task added to the task list."
-
-
-class GetNewTaskToolInput(BaseModel):
- pass
-
-
-class GetNewTaskTool(BaseTool):
- name: str = "GetNewTaskTool"
- description: str = "Get a new task from the task list."
- args_schema: Type[GetNewTaskToolInput] = GetNewTaskToolInput
-
- queue: Queue[str]
-
- def _run(self):
- if self.queue.empty():
- return "Task list is empty."
- else:
- task = self.queue.get()
- return "Retrieved task: " + task
diff --git a/src/rai/rai/tools/ros/__init__.py b/src/rai/rai/tools/ros/__init__.py
index eccbcf45..7caac4a3 100644
--- a/src/rai/rai/tools/ros/__init__.py
+++ b/src/rai/rai/tools/ros/__init__.py
@@ -13,18 +13,7 @@
# limitations under the License.
#
-from .cat_demo_tools import (
- ContinueActionTool,
- ReplanWithoutCurrentPathTool,
- UseHonkTool,
- UseLightsTool,
-)
from .cli import Ros2InterfaceTool, Ros2ServiceTool, Ros2TopicTool
-from .mock_tools import (
- ObserveSurroundingsTool,
- OpenSetSegmentationTool,
- VisualQuestionAnsweringTool,
-)
from .tools import (
AddDescribedWaypointToDatabaseTool,
GetCurrentPositionTool,
@@ -32,13 +21,6 @@
)
__all__ = [
- "ReplanWithoutCurrentPathTool",
- "UseHonkTool",
- "UseLightsTool",
- "ContinueActionTool",
- "OpenSetSegmentationTool",
- "VisualQuestionAnsweringTool",
- "ObserveSurroundingsTool",
"Ros2TopicTool",
"Ros2InterfaceTool",
"Ros2ServiceTool",
diff --git a/src/rai/rai/tools/ros/cat_demo_tools.py b/src/rai/rai/tools/ros/cat_demo_tools.py
deleted file mode 100644
index 4fb2c536..00000000
--- a/src/rai/rai/tools/ros/cat_demo_tools.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import subprocess
-from typing import Type
-
-from langchain.tools import BaseTool
-from langchain_core.pydantic_v1 import BaseModel
-
-
-class UseLightsToolInput(BaseModel):
- """Input for the use lights tool."""
-
-
-class UseLightsTool(BaseTool):
- """Use the lights"""
-
- name: str = "UseLightsTool"
- description: str = "Turns on the lights."
- args_schema: Type[UseLightsToolInput] = UseLightsToolInput
-
- def _run(self):
- """Turns on the lights."""
- result = subprocess.run(
- "echo 'Lights have been turned on'",
- shell=True,
- capture_output=True,
- text=True,
- )
- return result.stdout.strip()
-
-
-class UseHonkToolInput(BaseModel):
- """Input for the use honk tool."""
-
-
-class UseHonkTool(BaseTool):
- """Use the honk"""
-
- name: str = "UseHonkTool"
- description: str = "Activates the honk."
- args_schema: Type[UseHonkToolInput] = UseHonkToolInput
-
- def _run(self):
- """Activates the honk."""
- result = subprocess.run(
- "echo 'Activating honk'", shell=True, capture_output=True, text=True
- )
- return result.stdout.strip()
-
-
-class ReplanWithoutCurrentPathToolInput(BaseModel):
- """Input for the replan without current path tool."""
-
-
-class ReplanWithoutCurrentPathTool(BaseTool):
- """Replan without current path"""
-
- name: str = "ReplanWithoutCurrentPathTool"
- description: str = "Replans without the current path."
- args_schema: Type[ReplanWithoutCurrentPathToolInput] = (
- ReplanWithoutCurrentPathToolInput
- )
-
- def _run(self):
- """Replans without the current path."""
- result = subprocess.run(
- "echo 'Replanning without current path'",
- shell=True,
- capture_output=True,
- text=True,
- )
- return result.stdout.strip()
-
-
-class ContinueActionToolInput(BaseModel):
- """Input for the continue action tool."""
-
-
-class ContinueActionTool(BaseTool):
- """Continue action"""
-
- name: str = "ContinueActionTool"
- description: str = "Continues the current operation."
- args_schema: Type[ContinueActionToolInput] = ContinueActionToolInput
-
- def _run(self):
- """Continues the current operation."""
- result = subprocess.run(
- "echo 'Continuing'", shell=True, capture_output=True, text=True
- )
- return result.stdout.strip()
-
-
-class StopToolInput(BaseModel):
- """Input for the stop tool."""
-
-
-class StopTool(BaseTool):
- """Stop action"""
-
- name: str = "StopTool"
- description: str = "Stops the current operation."
- args_schema: Type[StopToolInput] = StopToolInput
-
- def _run(self):
- """Stops the current operation."""
- result = subprocess.run(
- "echo 'Stopping'", shell=True, capture_output=True, text=True
- )
- return result.stdout.strip()
-
-
-class FinishToolInput(BaseModel):
- """Input for the finish tool."""
-
-
-class FinishTool(BaseTool):
- """Finish the conversation. Does not impact the actual mission."""
-
- name: str = "FinishTool"
- description: str = "Ends the conversation."
- args_schema: Type[FinishToolInput] = FinishToolInput
-
- def _run(self):
- """Ends the conversation."""
- return "Conversation finished."
diff --git a/src/rai/rai/tools/ros/cli.py b/src/rai/rai/tools/ros/cli.py
index 92dcd9f6..87b6f004 100644
--- a/src/rai/rai/tools/ros/cli.py
+++ b/src/rai/rai/tools/ros/cli.py
@@ -131,32 +131,3 @@ def _run(self, command: str):
command = f"ros2 service {command}"
result = subprocess.run(command, shell=True, capture_output=True)
return result
-
-
-class SetGoalPoseToolInput(BaseModel):
- """Input for the set_goal_pose tool."""
-
- topic: str = Field(
- "/goal_pose", description="Ros2 topic to publish the goal pose to"
- )
- x: float = Field(..., description="The x coordinate of the goal pose")
- y: float = Field(..., description="The y coordinate of the goal pose")
-
-
-class SetGoalPoseTool(BaseTool):
- """Set the goal pose for the robot"""
-
- name = "SetGoalPoseTool"
- description: str = "A tool for setting the goal pose for the robot."
-
- args_schema: Type[SetGoalPoseToolInput] = SetGoalPoseToolInput
-
- def _run(self, topic: str, x: float, y: float):
- """Sets the goal pose for the robot."""
-
- cmd = (
- f"ros2 topic pub {topic} geometry_msgs/PoseStamped "
- f'\'{{header: {{stamp: {{sec: 0, nanosec: 0}}, frame_id: "map"}}, '
- f"pose: {{position: {{x: {x}, y: {y}, z: {0.0}}}}}}}' --once"
- )
- return subprocess.run(cmd, shell=True)
diff --git a/src/rai/rai/tools/ros/mock_tools.py b/src/rai/rai/tools/ros/mock_tools.py
deleted file mode 100644
index f6503c8f..00000000
--- a/src/rai/rai/tools/ros/mock_tools.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing import List, Type
-
-from langchain.tools import BaseTool
-from langchain_core.pydantic_v1 import BaseModel, Field
-
-
-class OpenSetSegmentationToolInput(BaseModel):
- """Input for the open set segmentation tool."""
-
- topic: str = Field(..., description="ROS2 image topic to subscribe to")
- classes: List[str] = Field(..., description="Classes to segment")
-
-
-class OpenSetSegmentationTool(BaseTool):
- """Get the segmentation of an image into any list of classes"""
-
- name: str = "OpenSetSegmentationTool"
- description: str = (
- "Segments an image into specified classes from a given ROS2 topic."
- )
- args_schema: Type[OpenSetSegmentationToolInput] = OpenSetSegmentationToolInput
-
- def _run(self, topic: str, classes: List[str]):
- """Implements the segmentation logic for the specified classes on the given topic."""
- return f"Segmentation on topic {topic} for classes {classes} started."
-
-
-class VisualQuestionAnsweringToolInput(BaseModel):
- """Input for the visual question answering tool."""
-
- topic: str = Field(..., description="ROS2 image topic to subscribe to")
- question: str = Field(..., description="Question about the image")
-
-
-class VisualQuestionAnsweringTool(BaseTool):
- """Ask a question about an image"""
-
- name: str = "VisualQuestionAnsweringTool"
- description: str = (
- "Processes an image from a ROS2 topic and answers a specified question."
- )
- args_schema: Type[VisualQuestionAnsweringToolInput] = (
- VisualQuestionAnsweringToolInput
- )
-
- def _run(self, topic: str, question: str):
- """Processes the image from the specified topic and answers the given question."""
- return f"Processing and answering question about {topic}: {question}"
-
-
-class ObserveSurroundingsToolInput(BaseModel):
- """Input for the observe surroundings tool."""
-
- topic: str = Field(..., description="ROS2 image topic to subscribe to")
-
-
-class ObserveSurroundingsTool(BaseTool):
- """Observe the surroundings"""
-
- name: str = "ObserveSurroundingsTool"
- description: str = "Observes and processes data from a given ROS2 topic."
- args_schema: Type[ObserveSurroundingsToolInput] = ObserveSurroundingsToolInput
-
- def _run(self, topic: str):
- """Observes and processes data from the given ROS2 topic."""
- return f"Observing surroundings using topic {topic}"
diff --git a/src/rai/rai/tools/ros/native.py b/src/rai/rai/tools/ros/native.py
index 3ee80d5c..dd0077fc 100644
--- a/src/rai/rai/tools/ros/native.py
+++ b/src/rai/rai/tools/ros/native.py
@@ -68,7 +68,8 @@ class PubRos2MessageToolInput(BaseModel):
# --------------------- Tools ---------------------
class Ros2BaseTool(BaseTool):
- node: rclpy.node.Node = Field(..., exclude=True, include=False, required=True)
+ # TODO: Make the decision between rclpy.node.Node and RaiNode
+ node: rclpy.node.Node = Field(..., exclude=True, required=True)
args_schema: Type[Ros2BaseInput] = Ros2BaseInput
diff --git a/src/rai/rai/tools/ros/tools.py b/src/rai/rai/tools/ros/tools.py
index 7281738a..1ce8ae75 100644
--- a/src/rai/rai/tools/ros/tools.py
+++ b/src/rai/rai/tools/ros/tools.py
@@ -27,10 +27,7 @@
from nav_msgs.msg import OccupancyGrid
from tf_transformations import euler_from_quaternion
-from rai.communication.ros_communication import (
- SingleMessageGrabber,
- TF2TransformFetcher,
-)
+from rai.tools.utils import SingleMessageGrabber, TF2TransformFetcher
from .native import TopicInput
diff --git a/src/rai/rai/tools/time.py b/src/rai/rai/tools/time.py
index c34f699b..3ce63844 100644
--- a/src/rai/rai/tools/time.py
+++ b/src/rai/rai/tools/time.py
@@ -14,8 +14,11 @@
#
import time
+from typing import Type
+from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import tool
+from langchain_core.tools import BaseTool
@tool
@@ -25,3 +28,27 @@ def sleep_max_5s(n: int):
n = 5
time.sleep(n)
+
+
+class WaitForSecondsToolInput(BaseModel):
+ """Input for the WaitForSecondsTool tool."""
+
+ seconds: int = Field(..., description="The number of seconds to wait")
+
+
+class WaitForSecondsTool(BaseTool):
+ """Wait for a specified number of seconds"""
+
+ name: str = "WaitForSecondsTool"
+ description: str = (
+ "A tool for waiting. "
+ "Useful for pausing execution for a specified number of seconds. "
+ "Input should be the number of seconds to wait."
+ )
+
+ args_schema: Type[WaitForSecondsToolInput] = WaitForSecondsToolInput
+
+ def _run(self, seconds: int):
+ """Waits for the specified number of seconds."""
+ time.sleep(seconds)
+ return f"Waited for {seconds} seconds."
diff --git a/src/rai/rai/communication/ros_communication.py b/src/rai/rai/tools/utils.py
similarity index 73%
rename from src/rai/rai/communication/ros_communication.py
rename to src/rai/rai/tools/utils.py
index 4cc78ebd..e8797f25 100644
--- a/src/rai/rai/communication/ros_communication.py
+++ b/src/rai/rai/tools/utils.py
@@ -16,12 +16,15 @@
import base64
import logging
import subprocess
-from typing import Any, Callable, Union, cast
+from typing import Any, Callable, Dict, List, Literal, Sequence, Union, cast
import cv2
import rclpy
import rclpy.qos
from cv_bridge import CvBridge
+from deprecated import deprecated
+from langchain.tools import BaseTool
+from langchain_core.messages import AIMessage, BaseMessage, ToolCall, ToolMessage
from rclpy.duration import Duration
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
from rclpy.node import Node
@@ -37,6 +40,8 @@
from sensor_msgs.msg import Image
from tf2_ros import Buffer, TransformListener
+from rai.messages import ToolMultimodalMessage
+
# Copied from https://github.com/ros2/rclpy/blob/jazzy/rclpy/rclpy/wait_for_message.py, to support humble
def wait_for_message(
@@ -91,6 +96,89 @@ def wait_for_message(
return False, None
+@deprecated(reason="Multimodal images are handled using rai.messages.multimodal")
+def images_to_vendor_format(images: List[str], vendor: str) -> List[Dict[str, Any]]:
+ if vendor == "openai":
+ return [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{image}",
+ },
+ }
+ for image in images
+ ]
+ else:
+ raise ValueError(f"Vendor {vendor} not supported")
+
+
+@deprecated(reason="Running tool is langchain.agent based now")
+def run_tool_call(
+ tool_call: ToolCall,
+ tools: Sequence[BaseTool],
+) -> Dict[str, Any] | Any:
+ logger = logging.getLogger(__name__)
+ selected_tool = {k.name: k for k in tools}[tool_call["name"]]
+
+ try:
+ if selected_tool.args_schema is not None:
+ args = selected_tool.args_schema(**tool_call["args"]).dict()
+ else:
+ args = dict()
+ except Exception as e:
+ err_msg = f"Error in preparing arguments for {selected_tool.name}: {e}"
+ logger.error(err_msg)
+ return err_msg
+
+ logger.info(f"Running tool: {selected_tool.name} with args: {args}")
+
+ try:
+ tool_output = selected_tool.run(args)
+ except Exception as e:
+ err_msg = f"Error in running tool {selected_tool.name}: {e}"
+ logger.warning(err_msg)
+ return err_msg
+
+ logger.info(f"Successfully ran tool: {selected_tool.name}. Output: {tool_output}")
+ return tool_output
+
+
+@deprecated(reason="Running tool is langchain.agent based now")
+def run_requested_tools(
+ ai_msg: AIMessage,
+ tools: Sequence[BaseTool],
+ messages: List[BaseMessage],
+ llm_type: Literal["openai", "bedrock"],
+):
+ internal_messages: List[BaseMessage] = []
+ for tool_call in ai_msg.tool_calls:
+ tool_output = run_tool_call(tool_call, tools)
+ assert isinstance(tool_call["id"], str), "Tool output must have an id."
+ if isinstance(tool_output, dict):
+ tool_message = ToolMultimodalMessage(
+ content=tool_output.get("content", "No response from the tool."),
+ images=tool_output.get("images"),
+ tool_call_id=tool_call["id"],
+ )
+ tool_message = tool_message.postprocess(format=llm_type)
+ else:
+ tool_message = [
+ ToolMessage(content=str(tool_output), tool_call_id=tool_call["id"])
+ ]
+ if isinstance(tool_message, list):
+ internal_messages.extend(tool_message)
+ else:
+ internal_messages.append(tool_message)
+
+    # An AIMessage cannot be answered with an alternating sequence of tool and
+    # human messages, so the messages are sorted by type to send the tool messages
+    # first. For details, see the implementation of ToolMultimodalMessage.postprocess.
+
+ internal_messages.sort(key=lambda x: x.__class__.__name__, reverse=True)
+ messages.extend(internal_messages)
+ return messages
+
+
class SingleMessageGrabber:
def __init__(
self,
diff --git a/src/rai_hmi/rai_hmi/streamlit_hmi_node.py b/src/rai_hmi/rai_hmi/streamlit_hmi_node.py
index 38c85b8b..5b2ed990 100644
--- a/src/rai_hmi/rai_hmi/streamlit_hmi_node.py
+++ b/src/rai_hmi/rai_hmi/streamlit_hmi_node.py
@@ -31,12 +31,12 @@
from std_msgs.msg import String
from std_srvs.srv import Trigger
+from rai.extensions.navigator import RaiNavigator
+from rai.messages import HumanMultimodalMessage, ToolMultimodalMessage
from rai.node import RaiBaseNode
-from rai.scenario_engine.messages import HumanMultimodalMessage, ToolMultimodalMessage
from rai.tools.ros.native import GetCameraImage, Ros2GetTopicsNamesAndTypesTool
from rai_hmi.agent import State as ConversationState
from rai_hmi.agent import create_conversational_agent
-from rai_hmi.custom_mavigator import RaiNavigator
from rai_hmi.task import Task
from rai_interfaces.srv import VectorStoreRetrieval
@@ -147,13 +147,6 @@ def load_documentation(self) -> FAISS:
return rclpy.node.Node("rai_chat_node"), "", None
-llm = ChatOpenAI(
- temperature=0.5,
- model="gpt-4o",
- streaming=True,
-)
-
-
@tool
def add_task_to_queue(task: Task):
"""Use this tool to add a task to the queue. The task will be handled by the executor part of your system."""
@@ -212,60 +205,67 @@ def initialize_genAI(system_prompt: str, _node: Node):
return agent, state
-hmi_node, system_prompt, faiss_index = initialize_ros(package_name)
-agent_executor, state = initialize_genAI(system_prompt=system_prompt, _node=hmi_node)
-
-
-st.subheader("Chat")
-
-if "messages" not in st.session_state:
- st.session_state["messages"] = []
-
-for message in st.session_state["messages"]:
- message_type = message.type
- if isinstance(message, (HumanMultimodalMessage, ToolMessage)):
- message_type = "ai"
- with st.chat_message(message_type):
- if isinstance(message, HumanMultimodalMessage):
- base64_images = [image for image in message.images]
- images = [base64.b64decode(image) for image in base64_images]
- for image in images:
- st.image(image)
- if isinstance(message.content, list):
- content = message.content[0]["text"]
- st.markdown(content)
- elif isinstance(message, (ToolMessage, ToolMultimodalMessage)):
- st.expander(f"Tool: {message.name}").markdown(message.content)
- elif isinstance(message, AIMessage):
- if message.content == "": # tool calling
- for tool_call in message.tool_calls:
- st.markdown(f"Tool: {tool_call['name']}")
- else:
- st.markdown(message.content)
- else:
- st.markdown(message.content)
-
-if prompt := st.chat_input("What is your question?"):
- st.chat_message("user").markdown(prompt)
- st.session_state["messages"].append(HumanMessage(content=prompt))
- state["messages"].append(HumanMessage(content=prompt))
-
- with st.chat_message("assistant"):
- message_placeholder = st.container()
- n_messages = len(state["messages"])
- with message_placeholder.status("Thinking..."):
- response = agent_executor.invoke(state)
- new_messages = state["messages"][n_messages:]
- for message in new_messages:
+if __name__ == "__main__":
+ llm = ChatOpenAI(
+ temperature=0.5,
+ model="gpt-4o",
+ streaming=True,
+ )
+ hmi_node, system_prompt, faiss_index = initialize_ros(package_name)
+ agent_executor, state = initialize_genAI(
+ system_prompt=system_prompt, _node=hmi_node
+ )
+
+ st.subheader("Chat")
+
+ if "messages" not in st.session_state:
+ st.session_state["messages"] = []
+
+ for message in st.session_state["messages"]:
+ message_type = message.type
+ if isinstance(message, (HumanMultimodalMessage, ToolMessage)):
+ message_type = "ai"
+ with st.chat_message(message_type):
if isinstance(message, HumanMultimodalMessage):
base64_images = [image for image in message.images]
- # convert the str to bytes
images = [base64.b64decode(image) for image in base64_images]
for image in images:
- message_placeholder.image(image)
-
- output = response["messages"][-1]
-
- message_placeholder.markdown(output.content)
+ st.image(image)
+ if isinstance(message.content, list):
+ content = message.content[0]["text"]
+ st.markdown(content)
+ elif isinstance(message, (ToolMessage, ToolMultimodalMessage)):
+ st.expander(f"Tool: {message.name}").markdown(message.content)
+ elif isinstance(message, AIMessage):
+ if message.content == "": # tool calling
+ for tool_call in message.tool_calls:
+ st.markdown(f"Tool: {tool_call['name']}")
+ else:
+ st.markdown(message.content)
+ else:
+ st.markdown(message.content)
- st.session_state["messages"].extend(new_messages)
+ if prompt := st.chat_input("What is your question?"):
+ st.chat_message("user").markdown(prompt)
+ st.session_state["messages"].append(HumanMessage(content=prompt))
+ state["messages"].append(HumanMessage(content=prompt))
+
+ with st.chat_message("assistant"):
+ message_placeholder = st.container()
+ n_messages = len(state["messages"])
+ with message_placeholder.status("Thinking..."):
+ response = agent_executor.invoke(state)
+ new_messages = state["messages"][n_messages:]
+ for message in new_messages:
+ if isinstance(message, HumanMultimodalMessage):
+ base64_images = [image for image in message.images]
+ # convert the str to bytes
+ images = [base64.b64decode(image) for image in base64_images]
+ for image in images:
+ message_placeholder.image(image)
+
+ output = response["messages"][-1]
+
+ message_placeholder.markdown(output.content)
+
+ st.session_state["messages"].extend(new_messages)
diff --git a/tests/conftest.py b/tests/conftest.py
index 685deca9..86872ed8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,6 +13,9 @@
# limitations under the License.
#
+import glob
+import importlib
+import os
from collections import defaultdict
from typing import Dict
@@ -23,6 +26,18 @@
from rai.config.models import BEDROCK_CLAUDE_HAIKU, OPENAI_LLM, OPENAI_MULTIMODAL
+@pytest.fixture
+def rai_python_modules():
+ packages = glob.glob("src/*")
+ package_names = [os.path.basename(p) for p in packages]
+ ros2_python_packages = []
+ for package_path, package_name in zip(packages, package_names):
+ if os.path.isdir(f"{package_path}/{package_name}"):
+ ros2_python_packages.append(package_name)
+
+ return [importlib.import_module(p) for p in ros2_python_packages]
+
+
@pytest.fixture
def chat_openai_multimodal():
from langchain_openai.chat_models import ChatOpenAI
diff --git a/tests/smoke/import_test.py b/tests/smoke/import_test.py
index 65095a7f..62c1d018 100644
--- a/tests/smoke/import_test.py
+++ b/tests/smoke/import_test.py
@@ -21,8 +21,7 @@
import pytest
-def test_can_import_all_modules_pathlib() -> None:
- import rai
+def test_can_import_all_modules_pathlib(rai_python_modules) -> None:
def import_submodules(package: ModuleType) -> None:
@@ -44,8 +43,14 @@ def import_submodules(package: ModuleType) -> None:
for full_name in sorted(list(importables)):
try:
+ print(f"Importing {full_name}", end=" ")
importlib.import_module(full_name)
+ print("OK")
+
except ImportError as e:
+ print("FAIL")
pytest.fail(f"Failed to import {full_name}: {str(e)}")
- import_submodules(rai)
+ for module in rai_python_modules:
+ print(f"Checking {module}")
+ import_submodules(module)
diff --git a/tests/tools/test_multimodal.py b/tests/tools/test_multimodal.py
index 80356187..bda70afe 100644
--- a/tests/tools/test_multimodal.py
+++ b/tests/tools/test_multimodal.py
@@ -31,8 +31,8 @@
from langfuse.callback import CallbackHandler
from pytest import FixtureRequest
-from rai.scenario_engine.messages import HumanMultimodalMessage
-from rai.scenario_engine.tool_runner import run_requested_tools
+from rai.messages import HumanMultimodalMessage
+from rai.tools.utils import run_requested_tools
class GetImageToolInput(BaseModel):