Skip to content

Commit

Permalink
feat: toggle conditional-edge representation
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Jul 19, 2024
1 parent fbd73f1 commit 4e55eda
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 59 deletions.
19 changes: 17 additions & 2 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,21 @@ public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
return result.reduce((a, b) -> b).map( NodeOutput::state);
}

/**
* Generates a drawable graph representation of the state graph.
*
* @param type the type of graph representation to generate
* @param title the title of the graph
* @param printConditionalEdges whether to print conditional edges
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph( GraphRepresentation.Type type, String title, boolean printConditionalEdges ) {

String content = type.generator.generate( this, title, printConditionalEdges);

return new GraphRepresentation( type, content );
}

/**
* Generates a drawable graph representation of the state graph.
*
Expand All @@ -198,7 +213,7 @@ public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
*/
public GraphRepresentation getGraph( GraphRepresentation.Type type, String title ) {

String content = type.generator.generate( this,title);
String content = type.generator.generate( this, title, true);

return new GraphRepresentation( type, content );
}
Expand All @@ -210,7 +225,7 @@ public GraphRepresentation getGraph( GraphRepresentation.Type type, String title
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph( GraphRepresentation.Type type ) {
return getGraph(type, "Graph Diagram");
return getGraph(type, "Graph Diagram", true);
}

}
27 changes: 18 additions & 9 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/DiagramGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ public abstract class DiagramGenerator {
protected abstract void declareConditionalStart( StringBuilder sb, String name ) ;
protected abstract void declareNode( StringBuilder sb, String name ) ;
protected abstract void declareConditionalEdge( StringBuilder sb, int ordinal ) ;
protected abstract StringBuilder commentLine( StringBuilder sb, boolean yesOrNo );

public final <State extends AgentState> String generate( CompiledGraph<State> compiledGraph,String title ) {
public final <State extends AgentState> String generate( CompiledGraph<State> compiledGraph, String title, boolean printConditionalEdge ) {
StringBuilder sb = new StringBuilder();

appendHeader( sb, title );
Expand All @@ -32,7 +33,7 @@ public final <State extends AgentState> String generate( CompiledGraph<State> co
compiledGraph.getEdges().forEach( (k, v) -> {
if( v.value() != null ) {
conditionalEdgeCount[0] += 1;
declareConditionalEdge( sb, conditionalEdgeCount[0] );
declareConditionalEdge( commentLine(sb, !printConditionalEdge), conditionalEdgeCount[0] );
}
});

Expand All @@ -44,20 +45,20 @@ public final <State extends AgentState> String generate( CompiledGraph<State> co
else if( entryPoint.value() != null ) {
String conditionName = "startcondition";
declareConditionalStart( sb, conditionName );
edgeCondition( sb, entryPoint.value(), "start", conditionName) ;
edgeCondition( sb, entryPoint.value(), "start", conditionName, printConditionalEdge) ;
}

conditionalEdgeCount[0] = 0; // reset

compiledGraph.getEdges().forEach( (k,v) -> {
if( v.id() != null ) {
call( sb, k, v.id() );
return;
}
else if( v.value() != null ) {
conditionalEdgeCount[0] += 1;
String conditionName = format("condition%d", conditionalEdgeCount[0]);
edgeCondition( sb, v.value(), k, conditionName );

edgeCondition( sb, v.value(), k, conditionName, printConditionalEdge );

}
});
Expand All @@ -69,15 +70,23 @@ else if( v.value() != null ) {
return sb.toString();

}
private <State extends AgentState> void edgeCondition(StringBuilder sb, EdgeCondition<State> condition, String key, String conditionName ) {
call( sb, key, conditionName);
private <State extends AgentState> void edgeCondition(StringBuilder sb,
EdgeCondition<State> condition,
String k,
String conditionName,
boolean printConditionalEdge) {
call( commentLine(sb, !printConditionalEdge), k, conditionName);

condition.mappings().forEach( (cond, to) -> {
if( to.equals(StateGraph.END) ) {
finish( sb, conditionName, cond );

finish( commentLine(sb, !printConditionalEdge), conditionName, cond );
finish( commentLine(sb, printConditionalEdge), k, cond );

}
else {
call( sb, conditionName, to, cond );
call( commentLine(sb, !printConditionalEdge), conditionName, to, cond );
call( commentLine(sb, printConditionalEdge), k, to, cond );
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ protected void declareConditionalEdge(StringBuilder sb, int ordinal) {
sb.append( format("\tcondition%d{\"check state\"}\n", ordinal) );
}

@Override
@Override
protected StringBuilder commentLine(StringBuilder sb, boolean yesOrNo) {
return (yesOrNo) ? sb.append( "\t%%" ) : sb;
}

@Override
protected void start(StringBuilder sb, String entryPoint) {
call( sb, "start", entryPoint );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,10 @@ protected void declareConditionalEdge( StringBuilder sb, int ordinal ) {
sb.append( format("hexagon \"check state\" as condition%d<<Condition>>\n", ordinal ) );
}

@Override
protected StringBuilder commentLine(StringBuilder sb, boolean yesOrNo) {
return (yesOrNo) ? sb.append( "'" ) : sb;
}


}
108 changes: 62 additions & 46 deletions core-jdk8/src/test/java/org/bsc/langgraph4j/GraphTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public void testSimpleGraph() throws Exception {
"\"agent_2\" -down-> stop\n" +
"@enduml\n", result.getContent() );

System.out.println( result.getContent() );
// System.out.println( result.getContent() );
}

@Test
Expand All @@ -86,29 +86,33 @@ public void testCorrectionProcessGraph() throws Exception {
assertEquals( GraphRepresentation.Type.PLANTUML, result.getType() );

assertEquals( "@startuml unnamed.puml\n" +
"skinparam usecaseFontSize 14\n" +
"skinparam usecaseStereotypeFontSize 12\n" +
"skinparam hexagonFontSize 14\n" +
"skinparam hexagonStereotypeFontSize 12\n" +
"title \"Graph Diagram\"\n" +
"footer\n" +
"\n" +
"powered by langgraph4j\n" +
"end footer\n" +
"circle start<<input>>\n" +
"circle stop\n" +
"usecase \"evaluate_result\"<<Node>>\n" +
"usecase \"agent_review\"<<Node>>\n" +
"hexagon \"check state\" as condition1<<Condition>>\n" +
"start -down-> \"evaluate_result\"\n" +
"\"agent_review\" -down-> \"evaluate_result\"\n" +
"\"evaluate_result\" -down-> \"condition1\"\n" +
"\"condition1\" --> \"agent_review\": \"ERROR\"\n" +
"\"condition1\" -down-> stop: \"UNKNOWN\"\n" +
"\"condition1\" -down-> stop: \"OK\"\n" +
"@enduml\n", result.getContent() );

System.out.println( result.getContent() );
"skinparam usecaseFontSize 14\n" +
"skinparam usecaseStereotypeFontSize 12\n" +
"skinparam hexagonFontSize 14\n" +
"skinparam hexagonStereotypeFontSize 12\n" +
"title \"Graph Diagram\"\n" +
"footer\n" +
"\n" +
"powered by langgraph4j\n" +
"end footer\n" +
"circle start<<input>>\n" +
"circle stop\n" +
"usecase \"evaluate_result\"<<Node>>\n" +
"usecase \"agent_review\"<<Node>>\n" +
"hexagon \"check state\" as condition1<<Condition>>\n" +
"start -down-> \"evaluate_result\"\n" +
"\"agent_review\" -down-> \"evaluate_result\"\n" +
"\"evaluate_result\" -down-> \"condition1\"\n" +
"\"condition1\" --> \"agent_review\": \"ERROR\"\n" +
"'\"evaluate_result\" --> \"agent_review\": \"ERROR\"\n" +
"\"condition1\" -down-> stop: \"UNKNOWN\"\n" +
"'\"evaluate_result\" -down-> stop: \"UNKNOWN\"\n" +
"\"condition1\" -down-> stop: \"OK\"\n" +
"'\"evaluate_result\" -down-> stop: \"OK\"\n" +
"@enduml\n",
result.getContent() );

// System.out.println( result.getContent() );


}
Expand Down Expand Up @@ -152,11 +156,14 @@ public void GenerateAgentExecutorGraph() throws Exception {
"start -down-> \"agent\"\n" +
"\"agent\" -down-> \"condition1\"\n" +
"\"condition1\" --> \"action\": \"continue\"\n" +
"'\"agent\" --> \"action\": \"continue\"\n" +
"\"condition1\" -down-> stop: \"end\"\n" +
"'\"agent\" -down-> stop: \"end\"\n" +
"\"action\" -down-> \"agent\"\n" +
"@enduml\n", result.getContent() );
"@enduml\n",
result.getContent() );

System.out.println( result.getContent() );
// System.out.println( result.getContent() );
}

@Test
Expand All @@ -180,7 +187,7 @@ public void GenerateImageToDiagramGraph() throws Exception {

var app = workflow.compile();

var result = app.getGraph( GraphRepresentation.Type.PLANTUML );
var result = app.getGraph( GraphRepresentation.Type.PLANTUML);
assertEquals( GraphRepresentation.Type.PLANTUML, result.getType() );

assertEquals( "@startuml unnamed.puml\n" +
Expand All @@ -203,31 +210,40 @@ public void GenerateImageToDiagramGraph() throws Exception {
"start -down-> \"agent_describer\"\n" +
"\"agent_describer\" -down-> \"condition1\"\n" +
"\"condition1\" --> \"agent_sequence_plantuml\": \"sequence\"\n" +
"'\"agent_describer\" --> \"agent_sequence_plantuml\": \"sequence\"\n" +
"\"condition1\" --> \"agent_generic_plantuml\": \"generic\"\n" +
"'\"agent_describer\" --> \"agent_generic_plantuml\": \"generic\"\n" +
"\"agent_sequence_plantuml\" -down-> \"evaluate_result\"\n" +
"\"agent_generic_plantuml\" -down-> \"evaluate_result\"\n" +
"\"evaluate_result\" -down-> stop\n" +
"@enduml\n", result.getContent() );
"@enduml\n",
result.getContent() );

result = app.getGraph( GraphRepresentation.Type.MERMAID );
result = app.getGraph( GraphRepresentation.Type.MERMAID, "Graph Diagram", false );
assertEquals( GraphRepresentation.Type.MERMAID, result.getType() );

// System.out.println( result.getContent() );

assertEquals( "---\n" +
"title: Graph Diagram\n" +
"---\n" +
"flowchart TD\n" +
"\tstart((start))\n" +
"\tstop((stop))\n" +
"\tagent_describer(\"agent_describer\")\n" +
"\tagent_sequence_plantuml(\"agent_sequence_plantuml\")\n" +
"\tagent_generic_plantuml(\"agent_generic_plantuml\")\n" +
"\tevaluate_result(\"evaluate_result\")\n" +
"\tcondition1{\"check state\"}\n" +
"\tstart:::start --> agent_describer:::agent_describer\n" +
"\tagent_describer:::agent_describer --> condition1:::condition1\n" +
"\tcondition1:::condition1 -->|sequence| agent_sequence_plantuml:::agent_sequence_plantuml\n" +
"\tcondition1:::condition1 -->|generic| agent_generic_plantuml:::agent_generic_plantuml\n" +
"\tagent_sequence_plantuml:::agent_sequence_plantuml --> evaluate_result:::evaluate_result\n" +
"\tagent_generic_plantuml:::agent_generic_plantuml --> evaluate_result:::evaluate_result\n" +
"\tevaluate_result:::evaluate_result --> stop:::stop\n", result.getContent() );
"title: Graph Diagram\n" +
"---\n" +
"flowchart TD\n" +
"\tstart((start))\n" +
"\tstop((stop))\n" +
"\tagent_describer(\"agent_describer\")\n" +
"\tagent_sequence_plantuml(\"agent_sequence_plantuml\")\n" +
"\tagent_generic_plantuml(\"agent_generic_plantuml\")\n" +
"\tevaluate_result(\"evaluate_result\")\n" +
"\t%%\tcondition1{\"check state\"}\n" +
"\tstart:::start --> agent_describer:::agent_describer\n" +
"\t%%\tagent_describer:::agent_describer --> condition1:::condition1\n" +
"\t%%\tcondition1:::condition1 -->|sequence| agent_sequence_plantuml:::agent_sequence_plantuml\n" +
"\tagent_describer:::agent_describer -->|sequence| agent_sequence_plantuml:::agent_sequence_plantuml\n" +
"\t%%\tcondition1:::condition1 -->|generic| agent_generic_plantuml:::agent_generic_plantuml\n" +
"\tagent_describer:::agent_describer -->|generic| agent_generic_plantuml:::agent_generic_plantuml\n" +
"\tagent_sequence_plantuml:::agent_sequence_plantuml --> evaluate_result:::evaluate_result\n" +
"\tagent_generic_plantuml:::agent_generic_plantuml --> evaluate_result:::evaluate_result\n" +
"\tevaluate_result:::evaluate_result --> stop:::stop\n",
result.getContent() );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");

GraphRepresentation graph = compiledGraph.getGraph(GraphRepresentation.Type.MERMAID);
GraphRepresentation graph = compiledGraph.getGraph(GraphRepresentation.Type.MERMAID, initData.title(), false);

final Result result = new Result(graph, initData);
String resultJson = objectMapper.writeValueAsString(result);
Expand Down

0 comments on commit 4e55eda

Please sign in to comment.