Skip to content

Commit

Permalink
feat(CompiledGraph): add streamSnapshots() method
Browse files Browse the repository at this point in the history
work on #24
  • Loading branch information
bsorrentino committed Sep 16, 2024
1 parent bebf5c0 commit 11fc73b
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 20 deletions.
23 changes: 23 additions & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -439,4 +439,27 @@ public Data<Output> next() {
}
}

/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
* @param inputs the input map
* @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> streamSnapshots( Map<String,Object> inputs, RunnableConfig config ) throws Exception {
Objects.requireNonNull(config, "config cannot be null");

RunnableConfig newConfig = new RunnableConfig(config) {

@Override
public StreamMode streamMode() {
return StreamMode.SNAPSHOTS;
}
};

return new AsyncNodeGenerator<>( inputs, newConfig );
}


}
17 changes: 5 additions & 12 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/RunnableConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,21 @@

@ToString
public class RunnableConfig {
private CompiledGraph.StreamMode streamMode = CompiledGraph.StreamMode.VALUES;
private String threadId;
private String checkPointId;
private String nextNode;


public CompiledGraph.StreamMode streamMode() {
return streamMode;
return CompiledGraph.StreamMode.VALUES;
}

public Optional<String> threadId() {
public final Optional<String> threadId() {
return Optional.ofNullable(threadId);
}
public Optional<String> checkPointId() {
public final Optional<String> checkPointId() {
return Optional.ofNullable(checkPointId);
}
public Optional<String> nextNode() {
public final Optional<String> nextNode() {
return Optional.ofNullable(nextNode);
}

Expand All @@ -41,10 +39,6 @@ public static class Builder {
Builder( RunnableConfig config ) {
this.config = new RunnableConfig(config);
}
public Builder streamMode(CompiledGraph.StreamMode streamMode) {
this.config.streamMode = streamMode;
return this;
}
public Builder threadId(String threadId) {
this.config.threadId = threadId;
return this;
Expand All @@ -63,12 +57,11 @@ public RunnableConfig build() {
}
}

private RunnableConfig( RunnableConfig config ) {
protected RunnableConfig( RunnableConfig config ) {
Objects.requireNonNull( config, "config cannot be null" );
this.threadId = config.threadId;
this.checkPointId = config.checkPointId;
this.nextNode = config.nextNode;
this.streamMode = config.streamMode;
}
private RunnableConfig() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,10 @@ public void testViewAndUpdatePastGraphState() throws Exception {
Map<String, Object> inputs = mapOf( "messages", "whether in Naples?" );

var runnableConfig = RunnableConfig.builder()
.streamMode( CompiledGraph.StreamMode.SNAPSHOTS )
.threadId("thread_1")
.build();

var results = app.stream( inputs, runnableConfig ).stream().collect( Collectors.toList() );
var results = app.streamSnapshots( inputs, runnableConfig ).stream().collect( Collectors.toList() );

results.forEach( r -> log.info( "{}: Node: {} - {}", r.getClass().getSimpleName(), r.node(), r.state().messages() ) );

Expand Down
3 changes: 3 additions & 0 deletions server-jetty/logging.properties
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ handlers=java.util.logging.ConsoleHandler
AdaptiveRag.level=FINE
java.util.logging.ConsoleHandler.level=ALL
java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter


org.bsc.langgraph4j.LangGraphStreamingServer.level = FINEST
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.handler.ResourceHandler;
import org.eclipse.jetty.util.resource.Resource;
import org.eclipse.jetty.util.resource.ResourceFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -36,7 +35,8 @@
*/
public interface LangGraphStreamingServer {

Logger log = LoggerFactory.getLogger(LangGraphStreamingServer.class);

static Logger log = LoggerFactory.getLogger(LangGraphStreamingServer.class);

CompletableFuture<Void> start() throws Exception;

Expand Down Expand Up @@ -105,7 +105,7 @@ public <State extends AgentState> LangGraphStreamingServer build(StateGraph<Stat

// context.setContextPath("/");
// Add the streaming servlet
context.addServlet(new ServletHolder(new GraphExecutionServlet<State>(stateGraph, objectMapper)), "/stream");
context.addServlet(new ServletHolder(new GraphStreamServlet<State>(stateGraph, objectMapper)), "/stream");

var handlerList = new Handler.Sequence( resourceHandler, context);

Expand Down Expand Up @@ -136,15 +136,15 @@ record PersistentConfig(String sessionId, String threadId) {

}

class GraphExecutionServlet<State extends AgentState> extends HttpServlet {
class GraphStreamServlet<State extends AgentState> extends HttpServlet {
Logger log = LangGraphStreamingServer.log;

final StateGraph<State> stateGraph;
final ObjectMapper objectMapper;
final MemorySaver saver = new MemorySaver();
final Map<PersistentConfig, CompiledGraph<State>> graphCache = new HashMap<>();

public GraphExecutionServlet(StateGraph<State> stateGraph, ObjectMapper objectMapper) {
public GraphStreamServlet(StateGraph<State> stateGraph, ObjectMapper objectMapper) {
Objects.requireNonNull(stateGraph, "stateGraph cannot be null");
this.stateGraph = stateGraph;
this.objectMapper = objectMapper;
Expand Down Expand Up @@ -191,9 +191,10 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
graphCache.put( config, compiledGraph );
}

compiledGraph.stream(dataMap)
compiledGraph.streamSnapshots(dataMap, runnableConfig(config) )
.forEachAsync(s -> {
try {
LangGraphStreamingServer.log.trace("{}", s);

writer.print("{");
writer.printf("\"node\": \"%s\"", s.node());
Expand Down

0 comments on commit 11fc73b

Please sign in to comment.