Skip to content

Commit

Permalink
refactor: finalize modules
Browse files Browse the repository at this point in the history
jdk8
jdk17 and above
  • Loading branch information
bsorrentino committed Mar 29, 2024
1 parent 2838089 commit 2a94541
Show file tree
Hide file tree
Showing 26 changed files with 1,234 additions and 0 deletions.
115 changes: 115 additions & 0 deletions agents-jdk17/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.bsc.langgraph4j</groupId>
<artifactId>langgraph4j-parent</artifactId>
<version>1.0-SNAPSHOT</version>
</parent>

<artifactId>langgraph4j-agents-jdk17</artifactId>
<packaging>jar</packaging>

<name>langgraph4j::agents::jdk17</name>

<properties>
</properties>

<dependencies>

<dependency>
<groupId>org.bsc.langgraph4j</groupId>
<artifactId>langgraph4j</artifactId>
<version>${project.version}</version>
<classifier>jdk17</classifier>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${langchai4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>${langchai4j.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<executions>
<execution>
<id>default-package-jdk17</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<classifier>jdk17</classifier>
</configuration>
</execution>
</executions>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-deploy-plugin</artifactId>
<configuration>
<skip>true</skip>
</configuration>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>true</skipTests>
</configuration>
</plugin>

<plugin>
<groupId>org.projectlombok</groupId>
<artifactId>lombok-maven-plugin</artifactId>
<version>1.18.20.0</version>
<configuration>
<sourceDirectory>src/main/java</sourceDirectory>
</configuration>
<!--
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>delombok</goal>
</goals>
</execution>
</executions>
-->
</plugin>
</plugins>
</build>
</project>
48 changes: 48 additions & 0 deletions agents-jdk17/src/main/java/dev/langchain4j/Agent.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package dev.langchain4j;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.Singular;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

@Builder
public class Agent {

private final ChatLanguageModel chatLanguageModel;
@Singular private final List<ToolSpecification> tools;


public Response<AiMessage> execute( String input, List<IntermediateStep> intermediateSteps ) {
var userMessageTemplate = PromptTemplate.from( "{{input}}" )
.apply( Map.of( "input", input));

var messages = new ArrayList<ChatMessage>();

messages.add(new SystemMessage("You are a helpful assistant"));
messages.add(new UserMessage(userMessageTemplate.text()));

if (!intermediateSteps.isEmpty()) {

var toolRequests = intermediateSteps.stream()
.map(IntermediateStep::action)
.map(AgentAction::toolExecutionRequest)
.toList();

messages.add(new AiMessage(toolRequests)); // reply with tool requests

for (IntermediateStep step : intermediateSteps) {
var toolRequest = step.action().toolExecutionRequest();

messages.add(new ToolExecutionResultMessage(toolRequest.id(), toolRequest.name(), step.observation()));
}
}
return chatLanguageModel.generate( messages, tools );
}
}
11 changes: 11 additions & 0 deletions agents-jdk17/src/main/java/dev/langchain4j/AgentAction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package dev.langchain4j;

import dev.langchain4j.agent.tool.ToolExecutionRequest;

import java.util.Objects;

record AgentAction (ToolExecutionRequest toolExecutionRequest, String log ) {
public AgentAction {
Objects.requireNonNull(toolExecutionRequest);
}
}
146 changes: 146 additions & 0 deletions agents-jdk17/src/main/java/dev/langchain4j/AgentExecutor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package dev.langchain4j;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import org.bsc.langgraph4j.GraphState;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.async.AsyncIterator;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppendableValue;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.bsc.langgraph4j.GraphState.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;

public class AgentExecutor {

public static class State implements AgentState {

private final Map<String,Object> data;

public State( Map<String,Object> initData ) {
this.data = new HashMap<>(initData);
this.data.putIfAbsent("intermediate_steps",
new AppendableValue<IntermediateStep>());
}

public Map<String,Object> data() {
return Map.copyOf(data);
}

Optional<String> input() {
return value("input");
}
Optional<AgentOutcome> agentOutcome() {
return value("agent_outcome");
}
Optional<List<IntermediateStep>> intermediateSteps() {
return appendableValue("intermediate_steps");
}

@Override
public String toString() {
return data.toString();
}
}

Map<String,Object> runAgent( Agent agentRunnable, State state ) throws Exception {

var input = state.input()
.orElseThrow(() -> new IllegalArgumentException("no input provided!"));

var intermediateSteps = state.intermediateSteps()
.orElseThrow(() -> new IllegalArgumentException("no intermediateSteps provided!"));

var response = agentRunnable.execute( input, intermediateSteps );

if( response.finishReason() == FinishReason.TOOL_EXECUTION ) {

var toolExecutionRequests = response.content().toolExecutionRequests();
var action = new AgentAction( toolExecutionRequests.get(0), "");

return Map.of("agent_outcome", new AgentOutcome( action, null ) );

}
else {
var result = response.content().text();
var finish = new AgentFinish( Map.of("returnValues", result), result );

return Map.of("agent_outcome", new AgentOutcome( null, finish ) );
}

}
Map<String,Object> executeTools( List<ToolInfo> toolInfoList, State state ) throws Exception {

var agentOutcome = state.agentOutcome().orElseThrow(() -> new IllegalArgumentException("no agentOutcome provided!"));

if (agentOutcome.action() == null) {
throw new IllegalStateException("no action provided!" );
}

var toolExecutionRequest = agentOutcome.action().toolExecutionRequest();

var tool = toolInfoList.stream()
.filter( v -> v.specification().name().equals(toolExecutionRequest.name()))
.findFirst()
.orElseThrow(() -> new IllegalStateException("no tool found for: " + toolExecutionRequest.name()));

var result = tool.executor().execute( toolExecutionRequest, null );

return Map.of("intermediate_steps", new IntermediateStep( agentOutcome.action(), result ) );

}

String shouldContinue(State state) {

if (state.agentOutcome().map(AgentOutcome::finish).isPresent()) {
return "end";
}
return "continue";
}

public AsyncIterator<NodeOutput<State>> execute(ChatLanguageModel chatLanguageModel, Map<String, Object> inputs, List<Object> objectsWithTools) throws Exception {


var toolInfoList = ToolInfo.fromList( objectsWithTools );

final List<ToolSpecification> toolSpecifications = toolInfoList.stream()
.map(ToolInfo::specification)
.toList();

var agentRunnable = Agent.builder()
.chatLanguageModel(chatLanguageModel)
.tools( toolSpecifications )
.build();

var workflow = new GraphState<>(State::new);

workflow.setEntryPoint("agent");

workflow.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
);

workflow.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
);

workflow.addConditionalEdge(
"agent",
edge_async(this::shouldContinue),
Map.of("continue", "action", "end", END)
);

workflow.addEdge("action", "agent");

var app = workflow.compile();

return app.stream( inputs );
}
}
7 changes: 7 additions & 0 deletions agents-jdk17/src/main/java/dev/langchain4j/AgentFinish.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package dev.langchain4j;

import java.util.Map;

record AgentFinish (Map<String,Object> returnValues, String log ) {

}
5 changes: 5 additions & 0 deletions agents-jdk17/src/main/java/dev/langchain4j/AgentOutcome.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package dev.langchain4j;

record AgentOutcome(AgentAction action, AgentFinish finish) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package dev.langchain4j;

public record IntermediateStep(AgentAction action, String observation) {
}
43 changes: 43 additions & 0 deletions agents-jdk17/src/main/java/dev/langchain4j/ToolInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package dev.langchain4j;

import dev.langchain4j.agent.tool.DefaultToolExecutor;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;

public record ToolInfo(ToolSpecification specification, ToolExecutor executor ) {

public ToolInfo {
Objects.requireNonNull(specification);
Objects.requireNonNull(executor);
}

public static List<ToolInfo> of( Object ...objectsWithTools) {
return fromArray( (Object[])objectsWithTools );
}
public static List<ToolInfo> fromArray( Object[] objectsWithTools ) {
List<ToolInfo> toolSpecifications = new ArrayList<>();

for (Object objectWithTools : objectsWithTools) {
for (Method method : objectWithTools.getClass().getDeclaredMethods()) {
if (method.isAnnotationPresent(Tool.class)) {
ToolSpecification toolSpecification = toolSpecificationFrom(method);
ToolExecutor executor = new DefaultToolExecutor(objectWithTools, method);
toolSpecifications.add( new ToolInfo( toolSpecification, executor));
}
}
}
return List.copyOf(toolSpecifications);
}
public static List<ToolInfo> fromList(List<Object> objectsWithTools ) {
return fromArray(objectsWithTools.toArray());
}

}
4 changes: 4 additions & 0 deletions agents-jdk17/src/main/java/resources/logging.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
handlers=java.util.logging.ConsoleHandler
.level=FINE
java.util.logging.ConsoleHandler.level=ALL
java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter
Loading

0 comments on commit 2a94541

Please sign in to comment.