Skip to content

Commit

Permalink
feat(core/CompiledGraph.java): enhance subgraph processing in state g…
Browse files Browse the repository at this point in the history
…raph

It ensures that interruptions (nodes marked as "before") are correctly redirected to the real target ID after subgraph expansion.

work on #73
  • Loading branch information
bsorrentino committed Feb 11, 2025
1 parent b267d71 commit 00679e9
Showing 1 changed file with 51 additions and 26 deletions.
77 changes: 51 additions & 26 deletions core/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,37 @@ public enum StreamMode {
*/
protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfig ) throws GraphStateException {
this.stateGraph = stateGraph;
this.compileConfig = compileConfig;

var stateGraphNodesAndEdges = StateGraphNodesAndEdges.process( stateGraph );
var stateGraphNodesEdgesAndConfig = StateGraphNodesEdgesAndConfig.process( stateGraph, compileConfig );


// CHECK INTERRUPTIONS
for (String interruption : compileConfig.getInterruptBefore() ) {
if (!stateGraphNodesAndEdges.nodes().anyMatchById( interruption )) {
for (String interruption : stateGraphNodesEdgesAndConfig.interruptsBefore() ) {
if (!stateGraphNodesEdgesAndConfig.nodes().anyMatchById( interruption )) {
throw StateGraph.Errors.interruptionNodeNotExist.exception(interruption);
}
}
for (String interruption : compileConfig.getInterruptBefore() ) {
if (!stateGraphNodesAndEdges.nodes().anyMatchById( interruption )) {
for (String interruption : stateGraphNodesEdgesAndConfig.interruptsBefore() ) {
if (!stateGraphNodesEdgesAndConfig.nodes().anyMatchById( interruption )) {
throw StateGraph.Errors.interruptionNodeNotExist.exception(interruption);
}
}

// RE-CREATE THE EVENTUALLY UPDATED COMPILE CONFIG
this.compileConfig = CompileConfig.builder(compileConfig)
.interruptsBefore(stateGraphNodesEdgesAndConfig.interruptsBefore())
.interruptsAfter(stateGraphNodesEdgesAndConfig.interruptsAfter())
.build();

// EVALUATES NODES
for (var n : stateGraphNodesAndEdges.nodes().elements ) {
for (var n : stateGraphNodesEdgesAndConfig.nodes().elements ) {
var factory = n.actionFactory();
Objects.requireNonNull(factory, format("action factory for node id '%s' is null!", n.id()));
nodes.put(n.id(), factory.apply(compileConfig));
}

// EVALUATE EDGES
for( var e : stateGraphNodesAndEdges.edges().elements ) {
for( var e : stateGraphNodesEdgesAndConfig.edges().elements ) {
var targets = e.targets();
if (targets.size() == 1) {
edges.put(e.sourceId(), targets.get(0));
Expand All @@ -86,9 +92,9 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi

var parallelNodeEdges = parallelNodeStream.get()
.map( target -> new Edge<State>(target.id()))
.filter( ee -> stateGraphNodesAndEdges.edges().elements.contains( ee ) )
.map( ee -> stateGraphNodesAndEdges.edges().elements.indexOf( ee ) )
.map( index -> stateGraphNodesAndEdges.edges().elements.get(index) )
.filter( ee -> stateGraphNodesEdgesAndConfig.edges().elements.contains( ee ) )
.map( ee -> stateGraphNodesEdgesAndConfig.edges().elements.indexOf( ee ) )
.map( index -> stateGraphNodesEdgesAndConfig.edges().elements.get(index) )
.toList();

var parallelNodeTargets = parallelNodeEdges.stream()
Expand Down Expand Up @@ -613,23 +619,33 @@ public Data<Output> next() {

}

record StateGraphNodesAndEdges<State extends AgentState>(StateGraph.Nodes<State> nodes, StateGraph.Edges<State> edges ) {

StateGraphNodesAndEdges( StateGraph<State> stateGraph ) {
this( stateGraph.nodes, stateGraph.edges );
record StateGraphNodesEdgesAndConfig<State extends AgentState>(
StateGraph.Nodes<State> nodes,
StateGraph.Edges<State> edges,
List<String> interruptsBefore,
List<String> interruptsAfter) {

StateGraphNodesEdgesAndConfig(StateGraph<State> stateGraph, CompileConfig config) {
this( stateGraph.nodes,
stateGraph.edges,
config.interruptsBefore(),
config.interruptsAfter() );
}

static <State extends AgentState> StateGraphNodesAndEdges<State> process(StateGraph<State> stateGraph ) throws GraphStateException {
static <State extends AgentState> StateGraphNodesEdgesAndConfig<State> process(StateGraph<State> stateGraph, CompileConfig config ) throws GraphStateException {

var subgraphNodes = stateGraph.nodes.onlySubStateGraphNodes();

if( subgraphNodes.isEmpty() ) {
return new StateGraphNodesAndEdges<>( stateGraph );
return new StateGraphNodesEdgesAndConfig<>( stateGraph, config );
}

var result = new StateGraphNodesAndEdges<>(
new StateGraph.Nodes<>( stateGraph.nodes.exceptSubStateGraphNodes() ),
new StateGraph.Edges<>( stateGraph.edges.elements) );
var result = new StateGraphNodesEdgesAndConfig<>(
new StateGraph.Nodes<>( stateGraph.nodes.exceptSubStateGraphNodes() ),
new StateGraph.Edges<>( stateGraph.edges.elements),
new ArrayList<>(config.interruptsBefore()),
new ArrayList<>(config.interruptsAfter()) );


for( var subgraphNode : subgraphNodes ) {

Expand All @@ -642,6 +658,21 @@ static <State extends AgentState> StateGraphNodesAndEdges<State> process(StateGr
throw new GraphStateException( "subgraph not support start with parallel branches yet!" );
}

var sgEdgeStartTarget = sgEdgeStart.target();

if( sgEdgeStartTarget.id() == null ) {
throw new GraphStateException( format("the target for node '%s' is null!", subgraphNode.id()) );
}

var sgEdgeStartRealTargetId = subgraphNode.formatId( sgEdgeStartTarget.id() );

// Process Interruption (Before)
result.interruptsBefore().replaceAll( interrupt ->
Objects.equals( subgraphNode.id(), interrupt ) ?
sgEdgeStartRealTargetId :
interrupt
);

var edgesWithSubgraphTargetId = stateGraph.edges.edgesByTargetId( subgraphNode.id() );

if( edgesWithSubgraphTargetId.isEmpty() ) {
Expand All @@ -650,12 +681,6 @@ static <State extends AgentState> StateGraphNodesAndEdges<State> process(StateGr

for( var edgeWithSubgraphTargetId : edgesWithSubgraphTargetId ) {

var sgEdgeStartTarget = sgEdgeStart.target();

if( sgEdgeStartTarget.id() == null ) {
throw new GraphStateException( format("the target for node '%s' is null!", subgraphNode.id()) );
}

var newEdge = edgeWithSubgraphTargetId.withSourceAndTargetIdsUpdated( subgraphNode,
Function.identity(),
id -> new EdgeValue<>( (Objects.equals( id, subgraphNode.id() ) ?
Expand Down

0 comments on commit 00679e9

Please sign in to comment.