Skip to content

Commit

Permalink
refactor(agent-executor): update actions management
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Nov 26, 2024
1 parent bd1e35f commit ee61b11
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,39 +62,39 @@ public StateSerializer<AgentExecutor.State> object() {

}

class Builder {
class GraphBuilder {
private StreamingChatLanguageModel streamingChatLanguageModel;
private ChatLanguageModel chatLanguageModel;
private final ToolNode.Builder toolNodeBuilder = ToolNode.builder();
private StateSerializer<State> stateSerializer;

public Builder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
public GraphBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
this.chatLanguageModel = chatLanguageModel;
return this;
}
public Builder chatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
public GraphBuilder chatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
this.streamingChatLanguageModel = streamingChatLanguageModel;
return this;
}
@Deprecated
public Builder objectsWithTools(List<Object> objectsWithTools) {
public GraphBuilder objectsWithTools(List<Object> objectsWithTools) {
objectsWithTools.forEach(toolNodeBuilder::specification);
return this;
}
public Builder toolSpecification(Object objectsWithTool) {
public GraphBuilder toolSpecification(Object objectsWithTool) {
toolNodeBuilder.specification( objectsWithTool );
return this;
}
public Builder toolSpecification(ToolSpecification spec, ToolExecutor executor) {
public GraphBuilder toolSpecification(ToolSpecification spec, ToolExecutor executor) {
toolNodeBuilder.specification( spec, executor );
return this;
}
public Builder toolSpecification(ToolNode.Specification toolSpecifications) {
public GraphBuilder toolSpecification(ToolNode.Specification toolSpecifications) {
toolNodeBuilder.specification( toolSpecifications );
return this;
}

public Builder stateSerializer(StateSerializer<State> stateSerializer) {
public GraphBuilder stateSerializer(StateSerializer<State> stateSerializer) {
this.stateSerializer = stateSerializer;
return this;
}
Expand All @@ -120,21 +120,25 @@ public StateGraph<State> build() throws GraphStateException {
stateSerializer = Serializers.STD.object();
}

final var callAgent = new CallAgent( agent );
final var executeTools = new ExecuteTools( agent, toolNode );
final var shouldContinue = new ShouldContinue();

return new StateGraph<>(State.SCHEMA, stateSerializer)
.addNode( "agent", CallAgent.of( agent ) )
.addNode( "action", ExecuteTools.of( agent, toolNode ) )
.addNode( "agent", callAgent )
.addNode( "action", executeTools )
.addEdge(START,"agent")
.addConditionalEdges("agent",
ShouldContinue.of(),
shouldContinue,
Map.of("continue", "action", "end", END)
)
.addEdge("action", "agent")
;
}
}

static Builder builder() {
return new Builder();
static GraphBuilder graphBuilder() {
return new GraphBuilder();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,9 @@
@Slf4j
public class CallAgent implements AsyncNodeAction<AgentExecutor.State> {

public static CallAgent of(Agent agent) {
return new CallAgent(agent);
}

final Agent agent;

private CallAgent( Agent agent ) {
public CallAgent( Agent agent ) {
this.agent = agent;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.bsc.langgraph4j.agentexecutor.actions;

import dev.langchain4j.data.message.ToolExecutionResultMessage;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.agentexecutor.Agent;
Expand All @@ -18,14 +19,10 @@
@Slf4j
public class ExecuteTools implements AsyncNodeAction<AgentExecutor.State> {

public static ExecuteTools of(Agent agent, ToolNode toolNode) {
return new ExecuteTools(agent, toolNode);
}

final Agent agent;
final ToolNode toolNode;

private ExecuteTools(Agent agent, ToolNode toolNode) {
public ExecuteTools(@NonNull Agent agent, @NonNull ToolNode toolNode) {
this.agent = agent;
this.toolNode = toolNode;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,6 @@
import static java.util.concurrent.CompletableFuture.completedFuture;

public class ShouldContinue implements AsyncEdgeAction<AgentExecutor.State> {

public static ShouldContinue of() {
return new ShouldContinue();
}

private ShouldContinue() {}

@Override
public CompletableFuture<String> apply(AgentExecutor.State state) {
var shouldContinue = state.agentOutcome()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private StateGraph<AgentExecutor.State> newGraph() throws Exception {
.maxTokens(2000)
.build();

return AgentExecutor.builder()
return AgentExecutor.graphBuilder()
.chatLanguageModel(chatLanguageModel)
.toolSpecification(new TestTool())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private StateGraph<AgentExecutor.State> newGraph() throws Exception {
.maxTokens(2000)
.build();

return AgentExecutor.builder()
return AgentExecutor.graphBuilder()
.chatLanguageModel(chatLanguageModel)
.toolSpecification(new TestTool())
.build();
Expand Down

This file was deleted.

0 comments on commit ee61b11

Please sign in to comment.