Skip to content

Commit

Permalink
feat: refine Serialization implementation
Browse files Browse the repository at this point in the history
-  add StateSerializer abstract class that owns a StateFactory
-  refactor tests, samples and how-tos accordly

work on #29
  • Loading branch information
bsorrentino committed Oct 11, 2024
1 parent 72d0e33 commit 199ae8d
Show file tree
Hide file tree
Showing 17 changed files with 370 additions and 342 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.bsc.langgraph4j.*;
import org.bsc.langgraph4j.langchain4j.serializer.std.ToolExecutionResultMessageSerializer;
import org.bsc.langgraph4j.serializer.Serializer;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.serializer.std.ObjectStreamStateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppenderChannel;
Expand All @@ -32,7 +33,7 @@ public class AgentExecutor {
public class GraphBuilder {
private ChatLanguageModel chatLanguageModel;
private List<Object> objectsWithTools;
private Serializer<Map<String,Object>> stateSerializer;
private StateSerializer<State> stateSerializer;

public GraphBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
this.chatLanguageModel = chatLanguageModel;
Expand All @@ -43,7 +44,7 @@ public GraphBuilder objectsWithTools(List<Object> objectsWithTools) {
return this;
}

public GraphBuilder stateSerializer( Serializer<Map<String,Object>> stateSerializer) {
public GraphBuilder stateSerializer( StateSerializer<State> stateSerializer) {
this.stateSerializer = stateSerializer;
return this;
}
Expand All @@ -62,16 +63,18 @@ public StateGraph<State> build() throws GraphStateException {
.build();

if( stateSerializer == null ) {
var stateSerializer = new ObjectStreamStateSerializer();
stateSerializer.mapper()
var serializer = new ObjectStreamStateSerializer<>(State::new);
serializer.mapper()
.register(IntermediateStep.class, new IntermediateStepSerializer())
.register(AgentAction.class, new AgentActionSerializer())
.register(AgentFinish.class, new AgentFinishSerializer())
.register(AgentOutcome.class, new AgentOutcomeSerializer())
.register(ToolExecutionResultMessage.class, new ToolExecutionResultMessageSerializer());

stateSerializer = serializer;
}

return new StateGraph<>(State.SCHEMA,State::new, stateSerializer)
return new StateGraph<>(State.SCHEMA, stateSerializer)
.addEdge(START,"agent")
.addNode( "agent", node_async( state ->
callAgent(agentRunnable, state))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import lombok.NonNull;
import org.bsc.langgraph4j.agentexecutor.*;
import org.bsc.langgraph4j.serializer.plain_text.PlainTextStateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;

import java.io.*;
import java.util.*;
Expand Down Expand Up @@ -140,16 +143,16 @@ public AgentExecutor.State deserialize(JsonParser parser, DeserializationContext
}
}

public class JSONStateSerializer extends PlainTextStateSerializer {
public class JSONStateSerializer extends PlainTextStateSerializer<AgentExecutor.State> {

final ObjectMapper objectMapper;

public static JSONStateSerializer of( ObjectMapper objectMapper ) {
return new JSONStateSerializer(objectMapper);
}

private JSONStateSerializer(ObjectMapper objectMapper) {
Objects.requireNonNull(objectMapper, "objectMapper cannot be null");
private JSONStateSerializer( @NonNull ObjectMapper objectMapper) {
super( AgentExecutor.State::new );
this.objectMapper = objectMapper;

var module = new SimpleModule();
Expand All @@ -169,16 +172,15 @@ public String mimeType() {
}

@Override
public void write(Map<String,Object> object, ObjectOutput out) throws IOException {
var state = new AgentExecutor.State( object );
var json = objectMapper.writeValueAsString(state);
public void write(AgentExecutor.State object, ObjectOutput out) throws IOException {
var json = objectMapper.writeValueAsString(object);
out.writeUTF(json);
}

@Override
public Map<String,Object> read(ObjectInput in) throws IOException, ClassNotFoundException {
public AgentExecutor.State read(ObjectInput in) throws IOException, ClassNotFoundException {
var json = in.readUTF();
return objectMapper.readValue(json, AgentExecutor.State.class).data();
return objectMapper.readValue(json, AgentExecutor.State.class);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ public void jsonSerializeTest() throws Exception {
var state = serializer.read(data);

assertNotNull(state);
assertEquals("perform test twice", state.get("input") );
assertNotNull(state.get("intermediate_steps") );
assertInstanceOf( List.class, state.get("intermediate_steps") );
var intermediateSteps = (List<IntermediateStep>)state.get("intermediate_steps");
assertTrue(state.input().isPresent());
assertEquals("perform test twice", state.input().get() );
assertNotNull(state.intermediateSteps());
assertInstanceOf( List.class, state.intermediateSteps() );
var intermediateSteps = state.intermediateSteps();
assertTrue(intermediateSteps.isEmpty());
assertInstanceOf( AgentOutcome.class, state.get("agent_outcome") );
var agentOutcome = (AgentOutcome)state.get("agent_outcome");
assertTrue( state.agentOutcome().isPresent());
assertInstanceOf( AgentOutcome.class, state.agentOutcome().get() );
var agentOutcome = state.agentOutcome().get();
assertNotNull(agentOutcome);
var action = agentOutcome.action();
assertNotNull(action);
Expand Down Expand Up @@ -89,16 +91,18 @@ public void jsonSerializeTest2() throws Exception {
var state = serializer.read(data);

assertNotNull(state);
assertEquals("perform test another time", state.get("input") );
assertNotNull(state.get("intermediate_steps") );
assertInstanceOf( List.class, state.get("intermediate_steps") );
var intermediateSteps = (List<IntermediateStep>)state.get("intermediate_steps");
assertTrue(state.input().isPresent());
assertEquals("perform test another time", state.input().get() );
assertNotNull(state.intermediateSteps() );
assertInstanceOf( List.class, state.intermediateSteps() );
var intermediateSteps =state.intermediateSteps();
assertEquals(1,intermediateSteps.size());
var intermediateStep = intermediateSteps.get(0);
assertNotNull(intermediateStep);
assertEquals("test tool executed: perform test once", intermediateStep.observation() );
assertInstanceOf( AgentOutcome.class, state.get("agent_outcome") );
var agentOutcome = (AgentOutcome)state.get("agent_outcome");
assertTrue(state.agentOutcome().isPresent());
assertInstanceOf( AgentOutcome.class, state.agentOutcome().get() );
var agentOutcome = state.agentOutcome().get();
assertNotNull(agentOutcome);
var action = agentOutcome.action();
assertNotNull(action);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,7 @@ Map<String,Object> getInitialState(Map<String,Object> inputs, RunnableConfig con
}

State cloneState( Map<String,Object> data ) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {

Map<String,Object> newData = stateGraph.getStateSerializer().cloneObject(data);

return stateGraph.getStateFactory().apply(newData);
return stateGraph.getStateSerializer().cloneObject(data);
}


Expand Down
29 changes: 14 additions & 15 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@


import lombok.Getter;
import lombok.NonNull;
import org.bsc.langgraph4j.action.AsyncEdgeAction;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.serializer.Serializer;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.serializer.std.ObjectStreamStateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;
Expand Down Expand Up @@ -93,32 +95,26 @@ GraphRunnerException exception(String... args) {
private final Map<String, Channel<?>> channels;

@Getter
private final AgentStateFactory<State> stateFactory;

@Getter
private final Serializer<Map<String,Object>> stateSerializer;
private final StateSerializer<State> stateSerializer;

/**
*
* @param channels the state's schema of the graph
* @param stateFactory the factory to create agent states
* @param stateSerializer the serializer to serialize the state
*/
public StateGraph(Map<String, Channel<?>> channels,
AgentStateFactory<State> stateFactory,
Serializer<Map<String,Object>> stateSerializer) {
StateSerializer<State> stateSerializer) {
this.channels = channels;
this.stateFactory = stateFactory;
this.stateSerializer = ( stateSerializer == null ) ? new ObjectStreamStateSerializer() : stateSerializer;
this.stateSerializer = stateSerializer;
}

/**
* Constructs a new StateGraph with the specified state factory.
* Constructs a new StateGraph with the specified serializer.
*
* @param stateFactory the factory to create agent states
* @param stateSerializer the serializer to serialize the state
*/
public StateGraph(AgentStateFactory<State> stateFactory, Serializer<Map<String,Object>> stateSerializer) {
this( mapOf(), stateFactory, stateSerializer );
public StateGraph(@NonNull StateSerializer<State> stateSerializer) {
this( mapOf(), stateSerializer );

}

Expand All @@ -128,7 +124,7 @@ public StateGraph(AgentStateFactory<State> stateFactory, Serializer<Map<String,O
* @param stateFactory the factory to create agent states
*/
public StateGraph(AgentStateFactory<State> stateFactory) {
this( mapOf(), stateFactory, null );
this( mapOf(), stateFactory);

}

Expand All @@ -138,9 +134,12 @@ public StateGraph(AgentStateFactory<State> stateFactory) {
* @param stateFactory the factory to create agent states
*/
public StateGraph(Map<String, Channel<?>> channels, AgentStateFactory<State> stateFactory) {
this( channels, stateFactory, null );
this( channels, new ObjectStreamStateSerializer<>(stateFactory) );
}

public final AgentStateFactory<State> getStateFactory() {
return stateSerializer.stateFactory();
}

public Map<String, Channel<?>> getChannels() {
return unmodifiableMap(channels);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ default String mimeType() {

default byte[] writeObject(T object) throws IOException {
Objects.requireNonNull( object, "object cannot be null" );
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
ObjectOutputStream oas = new ObjectOutputStream(baos);
try( ByteArrayOutputStream stream = new ByteArrayOutputStream() ) {
ObjectOutputStream oas = new ObjectOutputStream(stream);
write(object, oas);
oas.flush();
return baos.toByteArray();
return stream.toByteArray();
}
}

Expand All @@ -27,8 +27,8 @@ default T readObject(byte[] bytes) throws IOException, ClassNotFoundException {
if( bytes.length == 0 ) {
throw new IllegalArgumentException("bytes cannot be empty");
}
try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
ObjectInputStream ois = new ObjectInputStream(bais);
try( ByteArrayInputStream stream = new ByteArrayInputStream( bytes ) ) {
ObjectInputStream ois = new ObjectInputStream(stream);
return read(ois);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.bsc.langgraph4j.serializer;

import lombok.NonNull;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public abstract class StateSerializer<State extends AgentState> implements Serializer<State> {

private final AgentStateFactory<State> stateFactory;

protected StateSerializer( @NonNull AgentStateFactory<State> stateFactory) {
this.stateFactory = stateFactory;
}

public final AgentStateFactory<State> stateFactory() {
return stateFactory;
}

public final State stateOf( @NonNull Map<String,Object> data) {
return stateFactory.apply(data);
}

public final State cloneObject( @NonNull Map<String,Object> data) throws IOException, ClassNotFoundException {
return cloneObject( stateFactory().apply(data) );
}

}
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
package org.bsc.langgraph4j.serializer.plain_text;

import org.bsc.langgraph4j.serializer.Serializer;
import lombok.NonNull;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;

import java.io.*;
import java.util.Map;

public abstract class PlainTextStateSerializer implements Serializer<Map<String,Object>> {
public abstract class PlainTextStateSerializer<State extends AgentState> extends StateSerializer<State> {

protected PlainTextStateSerializer(@NonNull AgentStateFactory<State> stateFactory) {
super(stateFactory);
}

@Override
public String mimeType() {
return "plain/text";
}

public Map<String,Object> read( String data ) throws IOException, ClassNotFoundException {
public State read( String data ) throws IOException, ClassNotFoundException {
ByteArrayOutputStream bytesStream = new ByteArrayOutputStream();

try(ObjectOutputStream out = new ObjectOutputStream( bytesStream )) {
Expand All @@ -25,7 +32,7 @@ public Map<String,Object> read( String data ) throws IOException, ClassNotFoundE

}

public Map<String,Object> read( Reader reader ) throws IOException, ClassNotFoundException {
public State read( Reader reader ) throws IOException, ClassNotFoundException {
StringBuilder sb = new StringBuilder();
try (BufferedReader bufferedReader = new BufferedReader(reader)) {
String line;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import lombok.extern.slf4j.Slf4j;
import org.bsc.langgraph4j.serializer.Serializer;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.*;

@Slf4j
public class ObjectStreamStateSerializer implements Serializer<Map<String,Object>> {
public class ObjectStreamStateSerializer<State extends AgentState> extends StateSerializer<State> {

static class ListSerializer implements Serializer<List<Object>> {

Expand Down Expand Up @@ -92,8 +95,8 @@ public Map<String, Object> read(ObjectInput in) throws IOException, ClassNotFoun
private final SerializerMapper mapper = new SerializerMapper();
private final MapSerializer mapSerializer = new MapSerializer();

public ObjectStreamStateSerializer() {
super();
public ObjectStreamStateSerializer( AgentStateFactory<State> stateFactory ) {
super(stateFactory);
mapper.register( Collection.class, new ListSerializer() );
mapper.register( Map.class, new MapSerializer() );
}
Expand All @@ -103,12 +106,12 @@ public SerializerMapper mapper() {
}

@Override
public void write(Map<String, Object> object, ObjectOutput out) throws IOException {
mapSerializer.write(object, mapper.objectOutputWithMapper(out));
public void write(State object, ObjectOutput out) throws IOException {
mapSerializer.write(object.data(), mapper.objectOutputWithMapper(out));
}

@Override
public final Map<String, Object> read(ObjectInput in) throws IOException, ClassNotFoundException {
return Collections.unmodifiableMap(mapSerializer.read( mapper.objectOutputWithMapper(in) ));
public final State read(ObjectInput in) throws IOException, ClassNotFoundException {
return stateFactory().apply(mapSerializer.read( mapper.objectOutputWithMapper(in) ));
}
}
Loading

0 comments on commit 199ae8d

Please sign in to comment.