Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: integrate with rai_state_logs #314

Merged
merged 13 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/rosbot-xl-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import rclpy
import rclpy.executors
import rclpy.logging
from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool

from rai.node import RaiStateBasedLlmNode
Expand Down
41 changes: 18 additions & 23 deletions src/rai/rai/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import time
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type

import rcl_interfaces.msg
import rclpy
import rclpy.callback_groups
import rclpy.executors
Expand Down Expand Up @@ -52,7 +51,8 @@
from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str
from rai.tools.utils import wait_for_message
from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks
from rai.utils.ros import NodeDiscovery, RosoutBuffer
from rai.utils.ros import NodeDiscovery
from rai.utils.ros_logs import create_logs_parser
from rai_interfaces.action import Task as TaskAction

WHOAMI_SYSTEM_PROMPT_TEMPLATE = """
Expand Down Expand Up @@ -277,7 +277,6 @@ def __init__(
def spin(self):
executor = rclpy.executors.MultiThreadedExecutor()
executor.add_node(self)
executor.add_node(self._async_tool_node)
executor.spin()
rclpy.shutdown()

Expand Down Expand Up @@ -346,6 +345,7 @@ def __init__(
observe_postprocessors: Optional[Dict[str, Callable[[Any], Any]]] = None,
allowlist: Optional[List[str]] = None,
tools: Optional[List[Type[BaseTool]]] = None,
logs_parser_type: Literal["llm", "rai_state_logs"] = "rai_state_logs",
boczekbartek marked this conversation as resolved.
Show resolved Hide resolved
*args,
**kwargs,
):
Expand All @@ -358,13 +358,6 @@ def __init__(

# ---------- ROS configuration ----------
self.callback_group = rclpy.callback_groups.ReentrantCallbackGroup()
self.rosout_sub = self.create_subscription(
rcl_interfaces.msg.Log,
"/rosout",
callback=self.rosout_callback,
callback_group=self.callback_group,
qos_profile=self.qos_profile,
)

# ---------- Robot State ----------
self.robot_state = dict()
Expand Down Expand Up @@ -400,8 +393,18 @@ def __init__(
state_retriever=self.get_robot_state,
logger=self.get_logger(),
)

# We have to use a separate node that we can spin manually for the ros-service-based
# parser, while this node itself is used for the ros-subscriber-based parser
logs_parser_node = self if logs_parser_type == "llm" else self._async_tool_node
self.logs_parser = create_logs_parser(
logs_parser_type, logs_parser_node, callback_group=self.callback_group
)
boczekbartek marked this conversation as resolved.
Show resolved Hide resolved
boczekbartek marked this conversation as resolved.
Show resolved Hide resolved
self.simple_llm = get_llm_model(model_type="simple_model")

def summarize_logs(self) -> str:
    """Return a text summary of recent ROS logs from the configured logs parser."""
    return self.logs_parser.summarize()

boczekbartek marked this conversation as resolved.
Show resolved Hide resolved
def _initialize_tools(self, tools: List[Type[BaseTool]]):
initialized_tools: List[BaseTool] = list()
for tool_cls in tools:
Expand Down Expand Up @@ -429,8 +432,6 @@ def _initialize_system_prompt(self, prompt: str):
return system_prompt

def _initialize_robot_state_interfaces(self, topics: List[str]):
self.rosout_buffer = RosoutBuffer(get_llm_model(model_type="simple_model"))

for topic in topics:
msg_type = self.get_msg_type(topic)
topic_callback = functools.partial(
Expand Down Expand Up @@ -551,7 +552,6 @@ async def agent_loop(self, goal_handle: ServerGoalHandle):
result.report = report.outcome

self.get_logger().info(f"Finished task:\n{result}")
self.clear_state()

return result
finally:
Expand All @@ -577,7 +577,11 @@ def state_update_callback(self):
state_dict[t] = msg

ts = time.perf_counter()
state_dict["logs_summary"] = self.rosout_buffer.summarize()
try:
state_dict["logs_summary"] = self.summarize_logs()
except Exception as e:
self.get_logger().error(f"Error summarizing logs: {e}")
state_dict["logs_summary"] = ""
te = time.perf_counter() - ts
self.get_logger().info(f"Logs summary retrieved in: {te:.2f}")
self.get_logger().debug(f"{state_dict=}")
Expand All @@ -586,15 +590,6 @@ def state_update_callback(self):
def get_robot_state(self) -> Dict[str, str]:
return self.state_dict

def clear_state(self):
self.rosout_buffer.clear()

def rosout_callback(self, msg: rcl_interfaces.msg.Log):
self.get_logger().debug(f"Received rosout: {msg}")
if "rai_node" in msg.name:
return
self.rosout_buffer.append(f"[{msg.stamp.sec}][{msg.name}]:{msg.msg}")


def describe_ros_image(
msg: sensor_msgs.msg.Image,
Expand Down
42 changes: 1 addition & 41 deletions src/rai/rai/utils/ros.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,8 @@
# limitations under the License.


from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Optional, Tuple

from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate


class RosoutBuffer:
def __init__(self, llm: BaseChatModel, bufsize: int = 100) -> None:
self.bufsize = bufsize
self._buffer: Deque[str] = deque()
self.template = ChatPromptTemplate.from_messages(
[
(
"system",
"Shorten the following log keeping its format - for example merge similar or repeating lines",
),
("human", "{rosout}"),
]
)
llm = llm
self.llm = self.template | llm

def clear(self):
self._buffer.clear()

def append(self, line: str):
self._buffer.append(line)
if len(self._buffer) > self.bufsize:
self._buffer.popleft()

def get_raw_logs(self, last_n: int = 30) -> str:
return "\n".join(list(self._buffer)[-last_n:])

def summarize(self):
if len(self._buffer) == 0:
return "No logs"
buffer = self.get_raw_logs()
self.clear()
response = self.llm.invoke({"rosout": buffer})
return str(response.content)
from typing import Dict, List, Optional, Tuple


@dataclass
Expand Down
155 changes: 155 additions & 0 deletions src/rai/rai/utils/ros_logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import deque
from typing import Deque, Literal, Optional

import rcl_interfaces.msg
import rclpy.callback_groups
import rclpy.node
import rclpy.qos
import rclpy.subscription
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate

import rai_interfaces.srv


class BaseLogsParser:
    """Common interface for log parsers.

    Subclasses collect ROS logs from some source and condense them
    into a single human-readable string via ``summarize``.
    """

    def summarize(self) -> str:
        """Return a condensed textual digest of the gathered logs."""
        raise NotImplementedError


def create_logs_parser(
    parser_type: Literal["rai_state_logs", "llm"],
    node: rclpy.node.Node,
    llm: Optional[BaseChatModel] = None,
    callback_group: Optional[rclpy.callback_groups.ReentrantCallbackGroup] = None,
    bufsize: Optional[int] = 100,
) -> BaseLogsParser:
    """Factory for log parsers.

    Args:
        parser_type: "rai_state_logs" uses the rai_state_logs service node;
            "llm" buffers /rosout and summarizes it with an LLM.
        node: ROS node used for service clients / subscriptions.
        llm: chat model; required only for the "llm" parser.
        callback_group: callback group for the /rosout subscription;
            required only for the "llm" parser.
        bufsize: /rosout buffer size; used only by the "llm" parser.

    Raises:
        ValueError: if parser_type is unknown, or required arguments for
            the "llm" parser are missing.
    """
    if parser_type == "rai_state_logs":
        # Service-based parser needs only the node; the other args are ignored.
        return RaiStateLogsParser(node)
    if parser_type == "llm":
        missing = [
            name
            for name, value in (
                ("llm", llm),
                ("callback_group", callback_group),
                ("bufsize", bufsize),
            )
            if value is None
        ]
        if missing:
            # Name the missing arguments explicitly to ease debugging.
            raise ValueError(
                f"Must provide llm, callback_group, and bufsize (missing: {', '.join(missing)})"
            )
        return LlmRosoutParser(llm, node, callback_group, bufsize)
    # Keep terminology consistent with the rest of this module ("parser").
    raise ValueError(f"Unknown parser type: {parser_type}")


class RaiStateLogsParser(BaseLogsParser):
    """Use rai_state_logs node to get logs via the `/get_log_digest` service."""

    SERVICE_NAME = "/get_log_digest"

    def __init__(self, node: rclpy.node.Node) -> None:
        """Create the service client and block until the service is available.

        Args:
            node: node used to create the client and to spin while waiting
                for service responses.
        """
        self.node = node

        self.rai_state_logs_client = node.create_client(
            rai_interfaces.srv.StringList, self.SERVICE_NAME
        )
        # Without the digest service, summarize() can never succeed, so wait
        # for it here (logging once per second while waiting).
        while not self.rai_state_logs_client.wait_for_service(timeout_sec=1.0):
            node.get_logger().info(
                f"'{self.SERVICE_NAME}' service is not available, waiting again..."
            )

    def summarize(self) -> str:
        """Fetch the current log digest from the rai_state_logs service.

        Returns:
            The digest lines joined with newlines, or an empty string if the
            service call fails, raises, or reports failure.
        """
        request = rai_interfaces.srv.StringList.Request()
        future = self.rai_state_logs_client.call_async(request)
        # future.result() re-raises any exception set on the future (e.g. the
        # service call failing mid-flight); don't let that crash the caller.
        try:
            rclpy.spin_until_future_complete(self.node, future)
            response: Optional[rai_interfaces.srv.StringList.Response] = future.result()
        except Exception as e:
            self.node.get_logger().error(f"Service call failed with exception: {e}")
            return ""
        if response is None or not response.success:
            self.node.get_logger().error(f"'{self.SERVICE_NAME}' service call failed")
            return ""

        self.node.get_logger().info(
            f"'{self.SERVICE_NAME}' service call done. Response: {response.success=}, {response.string_list=}"
        )
        return "\n".join(response.string_list)


class LlmRosoutParser(BaseLogsParser):
    """Buffer `/rosout` messages and summarize them with an LLM."""

    def __init__(
        self,
        llm: BaseChatModel,
        node: rclpy.node.Node,
        callback_group: rclpy.callback_groups.CallbackGroup,
        bufsize: int = 100,
    ):
        """Set up the bounded log buffer, the LLM chain, and the /rosout subscription.

        Args:
            llm: chat model used to condense the buffered log lines.
            node: node on which the /rosout subscription is created.
            callback_group: callback group for the subscription.
            bufsize: maximum number of buffered log lines.
        """
        self.node = node
        self.bufsize = bufsize
        self._buffer: Deque[str] = deque()

        self.template = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "Shorten the following log keeping its format - for example merge similar or repeating lines",
                ),
                ("human", "{rosout}"),
            ]
        )
        self.llm = self.template | llm

        # /rosout publishers use best-effort, volatile QoS.
        qos = rclpy.qos.QoSProfile(
            reliability=rclpy.qos.ReliabilityPolicy.BEST_EFFORT,
            durability=rclpy.qos.DurabilityPolicy.VOLATILE,
            depth=10,
        )
        self.rosout_subscription = self.init_rosout_subscription(
            self.node, callback_group, qos
        )

    def init_rosout_subscription(
        self,
        node: rclpy.node.Node,
        callback_group: rclpy.callback_groups.CallbackGroup,
        qos_profile: rclpy.qos.QoSProfile,
    ) -> rclpy.subscription.Subscription:
        """Subscribe to /rosout and route messages into the buffer."""
        return node.create_subscription(
            rcl_interfaces.msg.Log,
            "/rosout",
            callback=self.rosout_callback,
            callback_group=callback_group,
            qos_profile=qos_profile,
        )

    def rosout_callback(self, msg: rcl_interfaces.msg.Log):
        """Append one /rosout message, skipping this agent's own log lines."""
        self.node.get_logger().debug(f"Received rosout: {msg}")
        if "rai_node" in msg.name:
            return
        self.append(f"[{msg.stamp.sec}][{msg.name}]:{msg.msg}")

    def clear(self):
        """Drop all buffered log lines."""
        self._buffer.clear()

    def append(self, line: str):
        """Add a line, evicting the oldest entries beyond the buffer size."""
        self._buffer.append(line)
        while len(self._buffer) > self.bufsize:
            self._buffer.popleft()

    def get_raw_logs(self, last_n: int = 30) -> str:
        """Return the newest `last_n` buffered lines joined with newlines."""
        return "\n".join(list(self._buffer)[-last_n:])

    def summarize(self):
        """Summarize and drain the buffer; returns "No logs" when empty."""
        if not self._buffer:
            return "No logs"
        raw = self.get_raw_logs()
        self.clear()
        return str(self.llm.invoke({"rosout": raw}).content)
boczekbartek marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 8 additions & 0 deletions src/rai_bringup/launch/sim_whoami_demo.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def generate_launch_description():
default_value="src/examples/turtlebot4/allowlist.txt",
description="A list of ros interfaces that are exposed to rai agent.",
),
IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
FindPackageShare("rai_state_logs"),
"/launch/rai_state_logs.launch.py",
]
),
),
IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[FindPackageShare("rai_whoami"), "/launch/rai_whoami.launch.py"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def generate_launch_description() -> LaunchDescription:
),
DeclareLaunchArgument(
"filters",
default_value=None,
default_value='[""]', # this means that no filters will be applied
description="Filters for logs",
),
Node(
Expand Down
6 changes: 6 additions & 0 deletions src/state_tools/rai_state_logs/src/rai_state_logs_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class LogDigestNode : public rclcpp::Node
max_lines_ = static_cast<uint16_t>(get_parameter("max_lines").as_int());
include_meta_ = get_parameter("include_meta").as_bool();
clear_on_retrieval_ = get_parameter("clear_on_retrieval").as_bool();

// Hack to overcome https://github.com/ros2/rclcpp/issues/1955
if (filters_.size() == 1 && filters_[0].empty()) {
filters_.clear();
}

RCLCPP_INFO(get_logger(), "filters: %s", std::to_string(filters_.size()).c_str());

if (max_lines_ < 1) {
Expand Down