Skip to content

Commit

Permalink
feat(agents-api,integrations): Working integrations for tool-call step
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 25, 2024
1 parent d62a98e commit f13f8dd
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 18 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ LITELLM_REDIS_PASSWORD=<your_litellm_redis_password>
# AGENTS_API_DEBUG=false
# EMBEDDING_MODEL_ID=Alibaba-NLP/gte-large-en-v1.5
# NUM_GPUS=1
# INTEGRATION_SERVICE_URL=http://integrations:8000

# Temporal
# --------
Expand Down
25 changes: 12 additions & 13 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from temporalio import activity

from ..autogen.openapi_model import IntegrationDef
from ..clients import integrations
from ..common.protocol.tasks import StepContext
from ..env import testing
from ..models.tools import get_tool_args_from_metadata
Expand All @@ -24,31 +25,29 @@ async def execute_integration(
developer_id=developer_id, agent_id=agent_id, task_id=task_id
)

arguments = merged_tool_args.get(tool_name, {}) | arguments
arguments = (
merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments
)

try:
if integration.provider == "dummy":
return arguments

else:
raise NotImplementedError(
f"Unknown integration provider: {integration.provider}"
)
return await integrations.run_integration_service(
provider=integration.provider,
setup=integration.setup,
method=integration.method,
arguments=arguments,
)

except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in execute_integration: {e}")

raise


async def mock_execute_integration(
context: StepContext,
tool_name: str,
integration: IntegrationDef,
arguments: dict[str, Any],
) -> Any:
return arguments

mock_execute_integration = execute_integration

execute_integration = activity.defn(name="execute_integration")(
execute_integration if not testing else mock_execute_integration
Expand Down
31 changes: 31 additions & 0 deletions agents-api/agents_api/clients/integrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Any, List

from beartype import beartype
from httpx import AsyncClient

from ..env import integration_service_url

__all__: List[str] = ["run_integration_service"]


@beartype
async def run_integration_service(
*,
provider: str,
arguments: dict,
setup: dict | None = None,
method: str | None = None,
) -> Any:
slug = f"{provider}/{method}" if method else provider
url = f"{integration_service_url}/execute/{slug}"

setup = setup or {}

async with AsyncClient() as client:
response = await client.post(
url,
json={"arguments": arguments, "setup": setup},
)
response.raise_for_status()

return response.json()
7 changes: 7 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@
embedding_dimensions: int = env.int("EMBEDDING_DIMENSIONS", default=1024)


# Integration service
# -------------------
integration_service_url: str = env.str(
"INTEGRATION_SERVICE_URL", default="http://0.0.0.0:8000"
)


# Temporal
# --------
temporal_worker_url: str = env.str("TEMPORAL_WORKER_URL", default="localhost:7233")
Expand Down
14 changes: 11 additions & 3 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ForeachStep,
GetStep,
IfElseWorkflowStep,
IntegrationDef,
LogStep,
MapReduceStep,
ParallelStep,
Expand Down Expand Up @@ -60,7 +61,7 @@

# WorkflowStep = (
# EvaluateStep # ✅
# | ToolCallStep # ❌ <--- high priority
# | ToolCallStep #
# | PromptStep # 🟡 <--- high priority
# | GetStep # ✅
# | SetStep # ✅
Expand Down Expand Up @@ -482,13 +483,20 @@ async def run(
call = tool_call["integration"]
tool_name = call["name"]
arguments = call["arguments"]
integration = next(
integration_spec = next(
(t for t in context.tools if t.name == tool_name), None
)

if integration is None:
if integration_spec is None:
raise ApplicationError(f"Integration {tool_name} not found")

integration = IntegrationDef(
provider=integration_spec.spec["provider"],
setup=integration_spec.spec["setup"],
method=integration_spec.spec["method"],
arguments=arguments,
)

tool_call_response = await workflow.execute_activity(
execute_integration,
args=[context, tool_name, integration, arguments],
Expand Down
1 change: 1 addition & 0 deletions agents-api/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ x--shared-environment: &shared-environment
COZO_HOST: ${COZO_HOST:-http://memory-store:9070}
DEBUG: ${AGENTS_API_DEBUG:-False}
EMBEDDING_MODEL_ID: ${EMBEDDING_MODEL_ID:-Alibaba-NLP/gte-large-en-v1.5}
INTEGRATION_SERVICE_URL: ${INTEGRATION_SERVICE_URL:-http://integrations:8000}
LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY}
LITELLM_URL: ${LITELLM_URL:-http://litellm:4000}
SUMMARIZATION_MODEL_NAME: ${SUMMARIZATION_MODEL_NAME:-gpt-4-turbo}
Expand Down
64 changes: 62 additions & 2 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from agents_api.routers.tasks.create_task_execution import start_execution

from .fixtures import cozo_client, test_agent, test_developer_id
from .utils import patch_testing_temporal
from .utils import patch_integration_service, patch_testing_temporal

EMBEDDING_SIZE: int = 1024

Expand Down Expand Up @@ -441,7 +441,7 @@ async def _(
assert result["hello"] == data.input["test"]


@test("workflow: tool call integration type step")
@test("workflow: tool call integration dummy")
async def _(
client=cozo_client,
developer_id=test_developer_id,
Expand Down Expand Up @@ -494,6 +494,65 @@ async def _(
assert result["test"] == data.input["test"]


@test("workflow: tool call integration mocked weather")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"tools": [
{
"type": "integration",
"name": "get_weather",
"integration": {
"provider": "weather",
"setup": {"openweathermap_api_key": "test"},
"arguments": {"test": "fake"},
},
}
],
"main": [
{
"tool": "get_weather",
"arguments": {"location": "_.test"},
},
],
}
),
client=client,
)

expected_output = {"temperature": 20, "humidity": 60}

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
with patch_integration_service(expected_output) as mock_integration_service:
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input
mock_run_task_execution_workflow.assert_called_once()
mock_integration_service.assert_called_once()

result = await handle.result()
assert result == expected_output


# FIXME: This test is not working. It gets stuck
# @test("workflow: wait for input step start")
async def _(
Expand Down Expand Up @@ -1026,3 +1085,4 @@ async def _(
mock_run_task_execution_workflow.assert_called_once()

await handle.result()

10 changes: 10 additions & 0 deletions agents-api/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,13 @@ def patch_embed_acompletion(output={"role": "assistant", "content": "Hello, worl
acompletion.return_value = mock_model_response

yield embed, acompletion


@contextmanager
def patch_integration_service(output: dict = {"result": "ok"}):
with patch(
"agents_api.clients.integrations.run_integration_service"
) as run_integration_service:
run_integration_service.return_value = output

yield run_integration_service

0 comments on commit f13f8dd

Please sign in to comment.