Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Mar 4, 2024
1 parent 3360040 commit 1656808
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
13 changes: 13 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,16 @@ def run(
domain: DomainDict,
) -> List[Dict[Text, Any]]:
raise Exception("test exception")


class CustomActionWithDialogueStack(Action):
def name(cls) -> Text:
return "custom_action_with_dialogue_stack"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("stack", tracker.stack)]
39 changes: 38 additions & 1 deletion tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Dict, List, Text
import json
import logging
import zlib
Expand All @@ -14,6 +15,19 @@
logger = logging.getLogger(__name__)


def get_stack():
dialogue_stack = [
{
"frame_id": "CP6JP9GQ",
"flow_id": "check_balance",
"step_id": "0_check_balance",
"frame_type": "regular",
"type": "flow",
}
]
return dialogue_stack


def test_server_health_returns_200():
request, response = app.test_client.get("/health")
assert response.status == 200
Expand All @@ -23,14 +37,15 @@ def test_server_health_returns_200():
def test_server_list_actions_returns_200():
request, response = app.test_client.get("/actions")
assert response.status == 200
assert len(response.json) == 5
assert len(response.json) == 6

# ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS
expected = [
# defined in tests/test_actions.py
{"name": "custom_async_action"},
{"name": "custom_action"},
{"name": "custom_action_exception"},
{"name": "custom_action_with_dialogue_stack"},
# defined in tests/tracing/instrumentation/conftest.py
{"name": "mock_validation_action"},
{"name": "mock_form_validation_action"},
Expand Down Expand Up @@ -119,6 +134,28 @@ def test_server_webhook_custom_action_encoded_data_returns_200():
assert response.status == 200


@pytest.mark.parametrize(
"stack_state, dialogue_stack",
[
({}, []),
({"stack": get_stack()}, get_stack()),
],
)
def test_server_webhook_custom_action_with_dialogue_stack_returns_200(
stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]]
):
data = {
"next_action": "custom_action_with_dialogue_stack",
"tracker": {"sender_id": "1", "conversation_id": "default", **stack_state},
}
_, response = app.test_client.post("/webhook", data=json.dumps(data))
print("*********** ", response.json)
events = response.json.get("events")

assert events == [SlotSet("stack", dialogue_stack)]
assert response.status == 200


# ENSURE THIS IS ALWAYS THE LAST TEST FOR OTHER TESTS TO RUN
# because the call to sys.exit() terminates pytest process
def test_endpoint_exit_for_unknown_actions_package():
Expand Down
6 changes: 4 additions & 2 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Any, Dict, List, Text

import pytest
from rasa_sdk.events import SlotSet
Expand Down Expand Up @@ -83,7 +83,9 @@ def test_tracker_with_slots():
({"stack": get_stack()}, get_stack()),
],
)
def test_stack_in_tracker_state(stack_state, dialogue_stack):
def test_stack_in_tracker_state(
stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]]
):

state = {"events": [], "sender_id": "old", "active_loop": {}, **stack_state}
tracker = Tracker.from_dict(state)
Expand Down

0 comments on commit 1656808

Please sign in to comment.