Skip to content

Commit

Permalink
Merge pull request #1076 from RasaHQ/ATO-2131-instrument-ActionExecut…
Browse files Browse the repository at this point in the history
…or._create_api_response

[ATO-2131] Instrument ActionExecutor._create_api_response
  • Loading branch information
Tawakalt authored Feb 16, 2024
2 parents 46436a7 + 32c3b23 commit 2f3b8b7
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 8 deletions.
1 change: 1 addition & 0 deletions changelog/1076.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Instrument `ActionExecutor._create_api_response` and extract `slots`, `events`, `utters` and `message_count` attributes.
31 changes: 30 additions & 1 deletion rasa_sdk/tracing/instrumentation/attribute_extractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from typing import Any, Dict, Text
from typing import Any, Dict, Text, List
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.forms import ValidationAction
from rasa_sdk.types import ActionCall, DomainDict
Expand Down Expand Up @@ -56,3 +56,32 @@ def extract_attrs_for_validation_action(
"slots_to_validate": json.dumps(list(slots_to_validate)),
"action_name": self.name(),
}


def extract_attrs_for_action_executor_create_api_response(
events: List[Dict[Text, Any]],
messages: List[Dict[Text, Any]],
) -> Dict[Text, Any]:
"""Extract the attributes for `ActionExecutor.run`.
:param events: A list of events.
:param messsages: A list of bot responses.
:return: A dictionary containing the attributes.
"""
event_names = []
slot_names = []

for event in events:
event_names.append(event.get("event"))
if event.get("event") == "slot" and event.get("name") != "requested_slot":
slot_names.append(event.get("name"))
utters = [
message.get("response") for message in messages if message.get("response")
]

return {
"events": json.dumps(list(dict.fromkeys(event_names))),
"slots": json.dumps(list(dict.fromkeys(slot_names))),
"utters": json.dumps(utters),
"message_count": len(messages),
}
30 changes: 25 additions & 5 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,28 @@ def traceable(

@functools.wraps(fn)
def wrapper(self: T, *args: Any, **kwargs: Any) -> S:
attrs = (
attr_extractor(self, *args, **kwargs)
if attr_extractor and should_extract_args
else {}
)
# the conditional statement is needed because
# _create_api_response is a static method
if isinstance(self, ActionExecutor) and fn.__name__ == "_create_api_response":
attrs = (
attr_extractor(*args, **kwargs)
if attr_extractor and should_extract_args
else {}
)
else:
attrs = (
attr_extractor(self, *args, **kwargs)
if attr_extractor and should_extract_args
else {}
)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs
):
if (
isinstance(self, ActionExecutor)
and fn.__name__ == "_create_api_response"
):
return fn(*args, **kwargs)
return fn(self, *args, **kwargs)

return wrapper
Expand Down Expand Up @@ -140,6 +154,12 @@ def instrument(
"run",
attribute_extractors.extract_attrs_for_action_executor,
)
_instrument_method(
tracer,
action_executor_class,
"_create_api_response",
attribute_extractors.extract_attrs_for_action_executor_create_api_response,
)
mark_class_as_instrumented(action_executor_class)
ActionExecutorTracerRegister().register_tracer(tracer)

Expand Down
8 changes: 7 additions & 1 deletion tests/tracing/instrumentation/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from typing import Text
from typing import Any, Dict, Text, List

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
Expand Down Expand Up @@ -47,6 +47,12 @@ def fail_if_undefined(self, method_name: Text) -> None:
async def run(self, action_call: ActionCall) -> None:
pass

@staticmethod
def _create_api_response(
events: List[Dict[Text, Any]], messages: List[Dict[Text, Any]]
) -> None:
pass


class MockValidationAction(ValidationAction):
def __init__(self) -> None:
Expand Down
97 changes: 96 additions & 1 deletion tests/tracing/instrumentation/test_action_executor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,43 @@
import pytest

from typing import Any, Dict, Sequence, Text, Optional
from typing import Any, Dict, Sequence, Text, Optional, List, Callable
from unittest.mock import Mock
from pytest import MonkeyPatch
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry import trace
from rasa_sdk.events import ActionExecuted, SlotSet

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import MockActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk import Tracker
from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister
from rasa_sdk.executor import CollectingDispatcher


def get_dispatcher0():
dispatcher = CollectingDispatcher()
return dispatcher


def get_dispatcher1():
dispatcher = CollectingDispatcher()
dispatcher.utter_message(template="utter_greet")
return dispatcher


def get_dispatcher2():
dispatcher = CollectingDispatcher()
dispatcher.utter_message("Hello")
return dispatcher


def get_dispatcher3():
dispatcher = CollectingDispatcher()
dispatcher.utter_message("Hello")
dispatcher.utter_message(template="utter_greet")
return dispatcher


@pytest.mark.parametrize(
Expand Down Expand Up @@ -88,3 +114,72 @@ def test_instrument_action_executor_run_registers_tracer(

assert tracer is not None
assert tracer == mock_tracer


@pytest.mark.parametrize(
"events, get_dispatcher, expected",
[
(
[],
get_dispatcher0,
{"events": "[]", "slots": "[]", "utters": "[]", "message_count": 0},
),
(
[ActionExecuted("my_form")],
get_dispatcher2,
{"events": '["action"]', "slots": "[]", "utters": "[]", "message_count": 1},
),
(
[ActionExecuted("my_form"), SlotSet("my_slot", "some_value")],
get_dispatcher1,
{
"events": '["action", "slot"]',
"slots": '["my_slot"]',
"utters": '["utter_greet"]',
"message_count": 1,
},
),
(
[SlotSet("my_slot", "some_value")],
get_dispatcher3,
{
"events": '["slot"]',
"slots": '["my_slot"]',
"utters": '["utter_greet"]',
"message_count": 2,
},
),
],
)
def test_tracing_action_executor_create_api_response(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
events: Optional[List],
get_dispatcher: Callable,
expected: Dict[Text, Any],
) -> None:
component_class = MockActionExecutor

instrumentation.instrument(
tracer_provider,
action_executor_class=component_class,
)

mock_action_executor = component_class()

dispatcher = get_dispatcher()
mock_action_executor._create_api_response(events, dispatcher.messages)

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
assert num_captured_spans == 1

captured_span = captured_spans[-1]

assert captured_span.name == "MockActionExecutor._create_api_response"

assert captured_span.attributes == expected

0 comments on commit 2f3b8b7

Please sign in to comment.