Skip to content

Commit

Permalink
feat(server): add resume management
Browse files Browse the repository at this point in the history
work on #34
  • Loading branch information
bsorrentino committed Oct 4, 2024
1 parent 1465519 commit bf030a7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 16 deletions.
6 changes: 3 additions & 3 deletions server-jetty/logging.properties
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
handlers=java.util.logging.ConsoleHandler
.level=INFO
AdaptiveRag.level=FINE
java.util.logging.ConsoleHandler.level=ALL
java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter


.level=INFO
AdaptiveRag.level=FINE
org.bsc.langgraph4j.CompiledGraph.level = FINEST
org.bsc.langgraph4j.LangGraphStreamingServer.level = FINEST
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.MemorySaver;
import org.bsc.langgraph4j.state.AgentState;
Expand All @@ -28,9 +29,12 @@
import java.io.PrintWriter;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static java.util.Optional.ofNullable;


/**
* LangGraphStreamingServer is an interface that represents a server that supports streaming
Expand Down Expand Up @@ -223,29 +227,72 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
var session = request.getSession(true);
Objects.requireNonNull(session, "session cannot be null");

var threadId = request.getParameter("thread");
Objects.requireNonNull(threadId, "thread cannot be null");
var threadId = ofNullable(request.getParameter("thread"))
.orElseThrow(() -> new IllegalStateException("Missing thread id!"));

var resume = ofNullable(request.getParameter("resume"))
.map(Boolean::parseBoolean).orElse(false);


final PrintWriter writer = response.getWriter();

Map<String, Object> dataMap = objectMapper.readValue(request.getInputStream(), new TypeReference<Map<String, Object>>() {
});

// Start asynchronous processing
var asyncContext = request.startAsync();

try {

var config = new PersistentConfig( session.getId(), threadId);
var dataMap = objectMapper.readValue(request.getInputStream(), new TypeReference<Map<String, Object>>() {
});

AsyncGenerator<? extends NodeOutput<? extends AgentState>> generator = null;

var persistentConfig = new PersistentConfig(session.getId(), threadId);

var compiledGraph = graphCache.get(persistentConfig);

if( resume ) {

log.trace( "RESUME REQUEST PREPARE" );

if (compiledGraph == null) {
throw new IllegalStateException( "Missing CompiledGraph in session!" );
}

var checkpointId = ofNullable(request.getParameter("checkpoint"))
.orElseThrow(() -> new IllegalStateException("Missing checkpoint id!"));

var config = RunnableConfig.builder()
.threadId(threadId)
.checkPointId(checkpointId)
.build();

var stateSnapshot = compiledGraph.getState(config);

config = stateSnapshot.config();

log.trace( "RESUME UPDATE STATE USING CONFIG {}\n{}", config, dataMap);

config = compiledGraph.updateState(config, dataMap );

log.trace( "RESUME REQUEST STREAM {}", config);

generator = compiledGraph.streamSnapshots(null, config);


var compiledGraph = graphCache.get(config);
if( compiledGraph == null ) {
compiledGraph = stateGraph.compile( compileConfig(config) );
graphCache.put( config, compiledGraph );
}
else {

compiledGraph.streamSnapshots(dataMap, runnableConfig(config) )
.forEachAsync(s -> {

if (compiledGraph == null) {
compiledGraph = stateGraph.compile(compileConfig(persistentConfig));
graphCache.put(persistentConfig, compiledGraph);
}

generator = compiledGraph.streamSnapshots(dataMap, runnableConfig(persistentConfig));
}

generator.forEachAsync(s -> {
try {
try {
writer.printf("[ \"%s\",", threadId);
Expand All @@ -259,15 +306,16 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
writer.flush();
TimeUnit.SECONDS.sleep(1);
} catch ( InterruptedException e) {
throw new RuntimeException(e);
throw new CompletionException(e);
}

})
.thenAccept(v -> writer.close())
.thenAccept(v -> asyncContext.complete())
;

} catch (Exception e) {
} catch (Throwable e) {
log.error("Error streaming", e);
throw new ServletException(e);
}
}
Expand Down

0 comments on commit bf030a7

Please sign in to comment.