Skip to content

Commit

Permalink
refactor: GraphState to StateGraph
Browse files Browse the repository at this point in the history
make compliant to original LangGraph
  • Loading branch information
bsorrentino committed May 18, 2024
1 parent 6ca0618 commit cfa7c92
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
import dev.langchain4j.model.output.FinishReason;
import lombok.var;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.GraphState;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppendableValue;

import java.util.*;
import java.util.stream.Collectors;

import static java.util.Collections.unmodifiableMap;
import static org.bsc.langgraph4j.GraphState.END;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;
Expand Down Expand Up @@ -108,7 +107,7 @@ public AsyncGenerator<NodeOutput<State>> execute(ChatLanguageModel chatLanguageM
.tools( toolSpecifications )
.build();

var workflow = new GraphState<>(State::new);
var workflow = new StateGraph<>(State::new);

workflow.setEntryPoint("agent");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import lombok.extern.slf4j.Slf4j;
import lombok.var;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.GraphState;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.NodeOutput;

import java.util.Map;
import java.util.concurrent.CompletableFuture;

import static java.util.Optional.ofNullable;
import static org.bsc.langgraph4j.GraphState.END;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;

Expand Down Expand Up @@ -108,7 +108,7 @@ private String routeEvaluationResult( State state ) {
@Override
public AsyncGenerator<NodeOutput<State>> execute(Map<String, Object> inputs) throws Exception {

var workflow = new GraphState<>(State::new);
var workflow = new StateGraph<>(State::new);

workflow.addNode( "evaluate_result", this::evaluateResult);
workflow.addNode( "agent_review", this::reviewResult );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import lombok.extern.slf4j.Slf4j;
import lombok.var;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.GraphState;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.NodeOutput;

import java.net.URI;
Expand Down Expand Up @@ -169,7 +169,7 @@ public AsyncGenerator<NodeOutput<State>> execute( Map<String, Object> inputs )
.maxTokens(2000)
.build();

var workflow = new GraphState<>(State::new);
var workflow = new StateGraph<>(State::new);

workflow.addNode("agent_describer", node_async( state ->
describeDiagramImage( llmVision, imageUrlOrData, state )) );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;

public class GraphState<State extends AgentState> {
public class StateGraph<State extends AgentState> {
enum Errors {
invalidNodeIdentifier( "END is not a valid node id!"),
invalidEdgeIdentifier( "END is not a valid edge sourceId!"),
Expand Down Expand Up @@ -68,11 +68,11 @@ public class Runnable {

Runnable() {

GraphState.this.nodes.forEach( n ->
StateGraph.this.nodes.forEach(n ->
nodes.put(n.id(), n.action())
);

GraphState.this.edges.forEach( e ->
StateGraph.this.edges.forEach(e ->
edges.put(e.sourceId(), e.target())
);
}
Expand Down Expand Up @@ -165,7 +165,7 @@ public Optional<State> invoke( Map<String,Object> inputs ) throws Exception {

AgentStateFactory<State> stateFactory;

public GraphState( AgentStateFactory<State> stateFactory ) {
public StateGraph(AgentStateFactory<State> stateFactory ) {
this.stateFactory = stateFactory;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import java.util.Map;
import java.util.stream.Collectors;

import static org.bsc.langgraph4j.GraphState.END;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;
Expand All @@ -27,7 +27,7 @@ public static <T> List<Map.Entry<String,T>> sortMap(Map<String,T> map ) {
@Test
void testValidation() throws Exception {

var workflow = new GraphState<>(AgentState::new);
var workflow = new StateGraph<>(AgentState::new);
var exception = assertThrows(GraphStateException.class, workflow::compile);
System.out.println(exception.getMessage());
assertEquals( "missing Entry Point", exception.getMessage());
Expand Down Expand Up @@ -81,7 +81,7 @@ void testValidation() throws Exception {
@Test
public void testRunningOneNode() throws Exception {

var workflow = new GraphState<>(AgentState::new);
var workflow = new StateGraph<>(AgentState::new);
workflow.setEntryPoint("agent_1");

workflow.addNode("agent_1", node_async( state -> {
Expand Down

0 comments on commit cfa7c92

Please sign in to comment.