Skip to content

Commit

Permalink
feat(agentexecutor): add JSON serialization support
Browse files Browse the repository at this point in the history
work on #34
  • Loading branch information
bsorrentino committed Oct 6, 2024
1 parent 3fa5ec0 commit 16d0179
Show file tree
Hide file tree
Showing 6 changed files with 339 additions and 3 deletions.
9 changes: 8 additions & 1 deletion agent-executor/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<jackson.version>2.17.2</jackson.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -56,6 +57,12 @@
<groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
Expand All @@ -79,7 +86,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>true</skipTests>
<skipTests>false</skipTests>
</configuration>
</plugin>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.bsc.langgraph4j.agentexecutor;

import dev.langchain4j.agent.tool.ToolSpecification;
import lombok.extern.slf4j.Slf4j;
import org.bsc.langgraph4j.agentexecutor.serializer.AgentActionSerializer;
import org.bsc.langgraph4j.agentexecutor.serializer.AgentFinishSerializer;
import org.bsc.langgraph4j.agentexecutor.serializer.AgentOutcomeSerializer;
Expand All @@ -23,6 +24,7 @@
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import org.bsc.langgraph4j.langchain4j.tool.ToolNode;

@Slf4j
public class AgentExecutor {

public class GraphBuilder {
Expand Down Expand Up @@ -105,7 +107,7 @@ List<IntermediateStep> intermediateSteps() {
}

Map<String,Object> callAgent(Agent agentRunnable, State state ) {

log.trace( "callAgent" );
var input = state.input()
.orElseThrow(() -> new IllegalArgumentException("no input provided!"));

Expand All @@ -131,6 +133,7 @@ Map<String,Object> callAgent(Agent agentRunnable, State state ) {
}

Map<String,Object> executeTools( ToolNode toolNode, State state ) {
log.trace( "executeTools" );

var agentOutcome = state.agentOutcome().orElseThrow(() -> new IllegalArgumentException("no agentOutcome provided!"));

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,174 @@
package org.bsc.langgraph4j.agentexecutor.serializer;

public class JSONStateSerializer {
import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import org.bsc.langgraph4j.agentexecutor.*;
import org.bsc.langgraph4j.serializer.Serializer;

import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;


class IntermediateStepDeserializer extends JsonDeserializer<IntermediateStep> {

@Override
public IntermediateStep deserialize(JsonParser parser, DeserializationContext ctx) throws IOException, JacksonException {
JsonNode node = parser.getCodec().readTree(parser);
var actionNode = node.get("action");
var action = ( actionNode != null && !actionNode.isNull()) ?
ctx.readValue(actionNode.traverse(parser.getCodec()), AgentAction.class) :
null;

return new IntermediateStep(action, node.get("observation").asText());
}
}

class ToolExecutionRequestDeserializer extends JsonDeserializer<ToolExecutionRequest> {

@Override
public ToolExecutionRequest deserialize(JsonParser parser, DeserializationContext ctx) throws IOException, JacksonException {
JsonNode node = parser.getCodec().readTree(parser);
return ToolExecutionRequest.builder()
.id(node.get("id").asText())
.name(node.get("name").asText())
.arguments(node.get("arguments").asText())
.build();
}
}

class AgentActionDeserializer extends JsonDeserializer<AgentAction> {

@Override
public AgentAction deserialize(JsonParser parser, DeserializationContext ctx) throws IOException, JacksonException {
JsonNode node = parser.getCodec().readTree(parser);

var toolExecutionRequestNode = node.get("toolExecutionRequest");
var toolExecutionRequest = ctx.readValue(toolExecutionRequestNode.traverse(parser.getCodec()), ToolExecutionRequest.class);

return new AgentAction(
toolExecutionRequest,
node.get("log").asText()
);
}
}

class AgentFinishDeserializer extends JsonDeserializer<AgentFinish> {

@Override
public AgentFinish deserialize(JsonParser parser, DeserializationContext ctx) throws IOException, JacksonException {
JsonNode node = parser.getCodec().readTree(parser);
var returnValuesNode = node.get("returnValues");
var returnValues = ctx.readValue(returnValuesNode.traverse(parser.getCodec()), Map.class);
var log = node.get("log").asText();
return new AgentFinish(returnValues, log);
}
}

class AgentOutcomeDeserializer extends JsonDeserializer<AgentOutcome> {

@Override
public AgentOutcome deserialize(JsonParser parser, DeserializationContext ctx) throws IOException, JacksonException {
JsonNode node = parser.getCodec().readTree(parser);

var actionNode = node.get("action");
var action = ( actionNode != null && !actionNode.isNull()) ?
ctx.readValue(actionNode.traverse(parser.getCodec()), AgentAction.class) :
null;

var finishNode = node.get("finish");
var finish = ( finishNode != null && !finishNode.isNull()) ?
ctx.readValue(finishNode.traverse(parser.getCodec()), AgentFinish.class) :
null;

return new AgentOutcome( action, finish );
}
}

class StateDeserializer extends JsonDeserializer<AgentExecutor.State> {

@Override
public AgentExecutor.State deserialize(JsonParser parser, DeserializationContext ctx) throws IOException, JsonProcessingException {
JsonNode node = parser.getCodec().readTree(parser);

Map<String,Object> data = new HashMap<>();

data.put( "input", node.get("input").asText() );

var intermediateStepsNode = node.get("intermediate_steps");

if( intermediateStepsNode == null || intermediateStepsNode.isNull() ) { // GUARD
throw new IOException("intermediate_steps must not be null!");
}
if( !intermediateStepsNode.isArray()) { // GUARD
throw new IOException("intermediate_steps must be an array!");
}
var intermediateStepList = new ArrayList<IntermediateStep>();
for (JsonNode intermediateStepNode : intermediateStepsNode) {

var intermediateStep = ctx.readValue(intermediateStepNode.traverse(parser.getCodec()), IntermediateStep.class);
intermediateStepList.add(intermediateStep); // intermediateStepList
}
data.put("intermediate_steps", intermediateStepList);

var agentOutcomeNode = node.get("agent_outcome");
var agentOutcome = ctx.readValue(agentOutcomeNode.traverse(parser.getCodec()), AgentOutcome.class);

data.put("agent_outcome", agentOutcome);

return new AgentExecutor.State( data );
}
}

public class JSONStateSerializer implements Serializer<Map<String,Object>> {

final ObjectMapper objectMapper;

public static JSONStateSerializer of( ObjectMapper objectMapper ) {
return new JSONStateSerializer(objectMapper);
}

private JSONStateSerializer(ObjectMapper objectMapper) {
Objects.requireNonNull(objectMapper, "objectMapper cannot be null");
this.objectMapper = objectMapper;

var module = new SimpleModule();
module.addDeserializer(AgentExecutor.State.class, new StateDeserializer());
module.addDeserializer(AgentOutcome.class, new AgentOutcomeDeserializer());
module.addDeserializer(AgentAction.class, new AgentActionDeserializer());
module.addDeserializer(AgentFinish.class, new AgentFinishDeserializer());
module.addDeserializer(ToolExecutionRequest.class, new ToolExecutionRequestDeserializer());
module.addDeserializer(IntermediateStep.class, new IntermediateStepDeserializer());

objectMapper.registerModule(module);
}

@Override
public String mimeType() {
return "application/json";
}

@Override
public void write(Map<String,Object> object, ObjectOutput out) throws IOException {
var state = new AgentExecutor.State( object );
var json = objectMapper.writeValueAsString(state);
out.writeUTF(json);
}

@Override
public Map<String,Object> read(ObjectInput in) throws IOException, ClassNotFoundException {
var json = in.readUTF();
System.out.println( json );
return objectMapper.readValue(json, AgentExecutor.State.class).data();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.bsc.langgraph4j.checkpoint.MemorySaver;
import org.bsc.langgraph4j.state.AgentState;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.util.List;
Expand All @@ -19,6 +20,7 @@

import static org.junit.jupiter.api.Assertions.*;

@Disabled
public class AgentExecutorTest {

@BeforeAll
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import dev.langchain4j.model.output.FinishReason;
import org.bsc.langgraph4j.langchain4j.tool.ToolNode;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.*;

@Disabled
public class AgentTest {

@BeforeAll
Expand Down
Loading

0 comments on commit 16d0179

Please sign in to comment.