Skip to content

Commit

Permalink
feat(CompiledGraph.java): enhance subgraph processing in state graph
Browse files Browse the repository at this point in the history
It ensures that interruptions (nodes marked as "before") are correctly redirected to the real target ID after subgraph expansion.

work on #73
  • Loading branch information
bsorrentino committed Feb 11, 2025
1 parent 9e14879 commit 4d61965
Showing 1 changed file with 47 additions and 23 deletions.
70 changes: 47 additions & 23 deletions core/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.stream.Collectors.toList;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;

Expand Down Expand Up @@ -138,7 +139,7 @@ public Collection<StateSnapshot<State>> getStateHistory( RunnableConfig config )

return saver.list(config).stream()
.map( checkpoint -> StateSnapshot.of( checkpoint, config, stateGraph.getStateFactory() ) )
.collect(Collectors.toList());
.collect(toList());
}


Expand Down Expand Up @@ -622,8 +623,8 @@ public Data<Output> next() {
record StateGraphNodesEdgesAndConfig<State extends AgentState>(
StateGraph.Nodes<State> nodes,
StateGraph.Edges<State> edges,
List<String> interruptsBefore,
List<String> interruptsAfter) {
Set<String> interruptsBefore,
Set<String> interruptsAfter) {

StateGraphNodesEdgesAndConfig(StateGraph<State> stateGraph, CompileConfig config) {
this( stateGraph.nodes,
Expand All @@ -640,18 +641,18 @@ static <State extends AgentState> StateGraphNodesEdgesAndConfig<State> process(S
return new StateGraphNodesEdgesAndConfig<>( stateGraph, config );
}

var result = new StateGraphNodesEdgesAndConfig<>(
new StateGraph.Nodes<>( stateGraph.nodes.exceptSubStateGraphNodes() ),
new StateGraph.Edges<>( stateGraph.edges.elements),
new ArrayList<>(config.interruptsBefore()),
new ArrayList<>(config.interruptsAfter()) );

var interruptsBefore = config.interruptsBefore();
var interruptsAfter = config.interruptsAfter();
var nodes = new StateGraph.Nodes<>( stateGraph.nodes.exceptSubStateGraphNodes() );
var edges = new StateGraph.Edges<>( stateGraph.edges.elements);

for( var subgraphNode : subgraphNodes ) {

var sgWorkflow = subgraphNode.subGraph();

//
// Process START Node
//
var sgEdgeStart = sgWorkflow.edges.edgeBySourceId(START).orElseThrow();

if( sgEdgeStart.isParallel() ) {
Expand All @@ -666,17 +667,17 @@ static <State extends AgentState> StateGraphNodesEdgesAndConfig<State> process(S

var sgEdgeStartRealTargetId = subgraphNode.formatId( sgEdgeStartTarget.id() );

// Process Interruption (Before)
result.interruptsBefore().replaceAll( interrupt ->
// Process Interruption (Before) Subgraph(s)
interruptsBefore = interruptsBefore.stream().map( interrupt ->
Objects.equals( subgraphNode.id(), interrupt ) ?
sgEdgeStartRealTargetId :
interrupt
);
).collect(Collectors.toUnmodifiableSet());

var edgesWithSubgraphTargetId = stateGraph.edges.edgesByTargetId( subgraphNode.id() );

if( edgesWithSubgraphTargetId.isEmpty() ) {
throw new GraphStateException( format("the node '%s' has not present as target in graph!", subgraphNode.id()) );
throw new GraphStateException( format("the node '%s' is not present as target in graph!", subgraphNode.id()) );
}

for( var edgeWithSubgraphTargetId : edgesWithSubgraphTargetId ) {
Expand All @@ -685,46 +686,69 @@ static <State extends AgentState> StateGraphNodesEdgesAndConfig<State> process(S
Function.identity(),
id -> new EdgeValue<>( (Objects.equals( id, subgraphNode.id() ) ?
subgraphNode.formatId( sgEdgeStartTarget.id() ) : id)));
result.edges().elements.remove(edgeWithSubgraphTargetId);
result.edges().elements.add( newEdge );
edges.elements.remove(edgeWithSubgraphTargetId);
edges.elements.add( newEdge );

}

//
// Process END Nodes
//
var sgEdgesEnd = sgWorkflow.edges.edgesByTargetId(END);

var edgeWithSubgraphSourceId = stateGraph.edges.edgeBySourceId( subgraphNode.id() ).orElseThrow();

if( edgeWithSubgraphSourceId.isParallel() ) {
throw new GraphStateException( "subgraph not support routes to parallel branches yet!" );
}

// Process Interruption (After) Subgraph(s)
if( interruptsAfter.contains(subgraphNode.id()) ) {

var exceptionMessage = ( edgeWithSubgraphSourceId.target().id()==null ) ?
"'interruption after' on subgraph is not supported yet!" :
format("'interruption after' on subgraph is not supported yet! consider to use 'interruption before' node: '%s'",
edgeWithSubgraphSourceId.target().id());
throw new GraphStateException( exceptionMessage );

}

sgEdgesEnd.stream()
.map( e -> e.withSourceAndTargetIdsUpdated( subgraphNode,
subgraphNode::formatId,
id -> (Objects.equals(id,END) ?
edgeWithSubgraphSourceId.target() :
new EdgeValue<>(subgraphNode.formatId(id)) ) )
)
.forEach( result.edges().elements::add);
result.edges().elements.remove(edgeWithSubgraphSourceId);
.forEach( edges.elements::add);
edges.elements.remove(edgeWithSubgraphSourceId);


//
// Process edges
//
sgWorkflow.edges.elements.stream()
.filter( e -> !Objects.equals( e.sourceId(),START) )
.filter( e -> !e.anyMatchByTargetId(END) )
.map( e ->
e.withSourceAndTargetIdsUpdated( subgraphNode,
subgraphNode::formatId,
id -> new EdgeValue<>( subgraphNode.formatId(id))) )
.forEach(result.edges().elements::add);
.forEach(edges.elements::add);

//
// Process nodes

//
sgWorkflow.nodes.elements.stream()
.map( n -> n.withIdUpdated( subgraphNode::formatId) )
.forEach(result.nodes().elements::add);
.forEach(nodes.elements::add);

}


return result;
return new StateGraphNodesEdgesAndConfig<>(
nodes,
edges,
interruptsBefore,
interruptsAfter );

}

Expand Down

0 comments on commit 4d61965

Please sign in to comment.