-
Notifications
You must be signed in to change notification settings - Fork 15
/
visualization.py
88 lines (73 loc) · 3.2 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import pandas as pd
from matplotlib import pyplot as plt
import json
import networkx as nx
# ---------------------------------------------------------------------------
# Input selection: which log file to read and which slice of it to visualise.
# ---------------------------------------------------------------------------
file_path_latest = 'log/multiagent_data_20241030.csv'
task_id = 3
meta_task_id = 'master_agent_task'
data_latest = pd.read_csv(file_path_latest)

# Keep only the rows that belong to the selected task AND meta-task.
task_mask = data_latest['task_id'] == task_id
meta_task_mask = data_latest['meta_task_id'] == meta_task_id
task_data = data_latest[task_mask & meta_task_mask]

# ---------------------------------------------------------------------------
# Graph skeleton: the user request feeds the main agent as "step 0".
# A MultiDiGraph is used so the same pair of nodes can be connected by
# several edges, one per step, each identified by its step number as key.
# ---------------------------------------------------------------------------
G = nx.MultiDiGraph()
for root_node in ('user_request', 'main_agent'):
    G.add_node(root_node)
G.add_edge('user_request', 'main_agent', key=0, label='step 0: user prompting main_agent')

# Remembers, for every agent seen so far, which node invoked it, so that a
# "return to parent" edge can be drawn when that agent stops calling tools.
parent_mapping = {'main_agent': 'user_request'}

# ---------------------------------------------------------------------------
# Walk the filtered log rows in order and add one edge per interaction.
# `step` is a running counter used both as the edge key and in the label.
# ---------------------------------------------------------------------------
step = 1
for _, row in task_data.iterrows():
    agent = row['agent']
    depth = row['depth']  # read from the row but not used below
    response = json.loads(row['response'])
    tool_calls = response.get('tool_calls', None)

    if not tool_calls:
        # No tool calls in this response: draw an edge from the agent back
        # to whichever node was recorded as its parent.
        G.add_edge(agent, parent_mapping[agent], key=step, label=f'step {step}: sub agent {agent} going back to parent agent {parent_mapping[agent]}')
        step += 1
        continue

    for tool_call in tool_calls:
        tool = tool_call['function']['name']
        # A callee is treated as a sub agent when its name contains 'agent'.
        is_agent = 'agent' in tool
        action = f'calling {"sub agent" if is_agent else "tool"}: {tool}'
        if is_agent:
            # Record the caller so the sub agent can later "go back" to it.
            parent_mapping[tool] = agent
        G.add_edge(agent, tool, key=step, label=f'step {step}: {action}')
        step += 1

print(G.edges)
def plot_hierarchical_multi_edge_graph(G, output_file=f'log/agent_interaction_graph{task_id}.png'):
    """Draw the agent-interaction multigraph and save it as a PNG image.

    Nodes are positioned with the Graphviz ``dot`` layout. Every edge is
    drawn as its own curved arrow, and the edge's ``label`` attribute is
    written in a small white box along the straight line between its two
    end nodes. Labels of parallel edges (same source and target) are fanned
    out evenly around a common anchor so they do not sit on top of each other.

    Args:
        G: A networkx multigraph. Every edge must carry a ``label``
            attribute, and edge keys must be integers because they are used
            in arithmetic to pick each arrow's curvature.
        output_file: Path of the PNG file to write. The default is built
            once, when the function is defined, from the module-level
            ``task_id``.

    Returns:
        None. The figure is written to ``output_file`` and then closed.
    """
    plt.figure(figsize=(12, 8))
    pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
    print(pos)

    # Draw nodes: the user request and the main agent get their own colours,
    # every other node (tools and sub agents) is light blue.
    colors = ['lightgreen' if n == 'user_request'
              else 'lightyellow' if n == 'main_agent'
              else 'lightblue' for n in G.nodes()]
    nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=3000)
    nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold')

    # Collect the label of every (source, target, key) edge.
    edge_labels = {(u, v, k): d['label'] for u, v, k, d in G.edges(keys=True, data=True)}

    # Count the parallel edges of each ordered (source, target) pair once,
    # instead of rescanning all edges for every single edge.
    parallel_total = {}
    for u, v, _ in edge_labels:
        parallel_total[(u, v)] = parallel_total.get((u, v), 0) + 1
    # How many edges of each pair have already been labelled; this yields the
    # 0-based position of an edge among its parallel siblings.
    parallel_seen = {}

    for (u, v, key), label in edge_labels.items():
        print(u, v, key)
        # Vary the curvature per edge (alternating sign by key parity) so
        # that parallel arrows do not overlap.
        rad = 0.15 + 0.05 * key * (-1 if key % 2 == 0 else 1)
        nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], connectionstyle=f'arc3,rad={rad}',
                               edge_color='gray', arrows=True, arrowsize=15, arrowstyle='->', width=1)

        # Anchor the label on the straight u-v segment, slightly nearer to u.
        x, y = pos[u][0] * 0.6 + pos[v][0] * 0.4, pos[u][1] * 0.6 + pos[v][1] * 0.4

        num_edges = parallel_total[(u, v)]
        if num_edges > 1:
            # Spread the labels symmetrically around the anchor using the
            # edge's index among its siblings. The raw key is a global step
            # number and would push labels ever further along the edge.
            index = parallel_seen.get((u, v), 0)
            parallel_seen[(u, v)] = index + 1
            offset = (index - (num_edges - 1) / 2) * 0.05
            x += offset * (pos[v][0] - pos[u][0])
            y += offset * (pos[v][1] - pos[u][1])

        plt.text(x, y, label, fontsize=6, ha='center', va='center', bbox=dict(facecolor='white', alpha=0.7))

    plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_file, format='PNG', dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Graph saved to {output_file}")
# Usage: render the graph built above, writing to the default output path.
plot_hierarchical_multi_edge_graph(G=G)
# TODO: pass hierarchies
# TODO: arrow text