-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add stage tree and pipeline plotters. also add stage hierarchy test
- Loading branch information
1 parent
1d4313d
commit 0d9b67b
Showing
10 changed files
with
298 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from tools.stage_tree_plotter import initialize, gen_stage_tree | ||
from tools.stage_pipeline_plotter import gen_stage_pipeline | ||
from ProConPy.stage import Stage | ||
from visualCaseGen.cime_interface import CIME_interface | ||
|
||
def test_stage_pipeline(): | ||
"""Confirm that the stage pipeline is a directed acyclic graph.""" | ||
cime = CIME_interface() | ||
initialize(cime) | ||
|
||
# The below call will raise an assertion error if the stage tree is not a forest | ||
gen_stage_tree(Stage.first()) | ||
|
||
# The below call will raise an assertion error if the stage tree is not a directed acyclic graph | ||
gen_stage_pipeline(Stage.first()) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
from ProConPy.config_var import ConfigVar, cvars | ||
from ProConPy.stage import Stage | ||
from ProConPy.csp_solver import csp | ||
from visualCaseGen.cime_interface import CIME_interface | ||
from visualCaseGen.initialize_configvars import initialize_configvars | ||
from visualCaseGen.initialize_widgets import initialize_widgets | ||
from visualCaseGen.initialize_stages import initialize_stages | ||
from visualCaseGen.specs.options import set_options | ||
from visualCaseGen.specs.relational_constraints import get_relational_constraints | ||
|
||
import networkx as nx | ||
from networkx.drawing.nx_pydot import graphviz_layout | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
def initialize(cime): | ||
"""Initializes visualCaseGen""" | ||
ConfigVar.reboot() | ||
Stage.reboot() | ||
initialize_configvars(cime) | ||
initialize_widgets(cime) | ||
initialize_stages(cime) | ||
set_options(cime) | ||
csp.initialize(cvars, get_relational_constraints(cvars), Stage.first()) | ||
|
||
|
||
def gen_stage_pipeline(stage): | ||
"""Generate a directed acyclic graph representing the stage pipeline. | ||
Parameters | ||
---------- | ||
stage : Stage | ||
The starting stage of the pipeline. | ||
Returns | ||
------- | ||
nx.DiGraph | ||
The directed graph representing the stage pipeline.""" | ||
|
||
# Instantiate a directed graph object that will represent the stage pipeline | ||
G = nx.DiGraph() | ||
|
||
# Traverse the entire stage tree using depth-first search | ||
while (next := stage.get_next(full_dfs=True)) is not None: | ||
if stage.children_have_conditions(): | ||
for child in stage._children: | ||
G.add_edge(stage, child) | ||
G.add_edge(child, child._children[0]) | ||
else: | ||
# The actual next stage that would be visited during runtime | ||
runtime_next = stage.get_next(full_dfs=False) | ||
if runtime_next: | ||
G.add_edge(stage, runtime_next) | ||
stage = next | ||
|
||
assert nx.is_directed_acyclic_graph( | ||
G | ||
), "The stage tree is not a directed acyclic graph." | ||
return G | ||
|
||
|
||
def plot_stage_pipeline(G, output_file=None): | ||
"""Plot the stage pipeline.""" | ||
plt.figure(figsize=(6, 12)) | ||
# pos = graphviz_layout(G, prog="sfdp") | ||
pos = graphviz_layout(G, prog="dot") | ||
nx.draw( | ||
G, | ||
pos, | ||
with_labels=False, | ||
edge_color="gray", | ||
node_color="powderblue", | ||
font_size=6, | ||
) | ||
text = nx.draw_networkx_labels(G, pos) | ||
for _, t in text.items(): | ||
t.set_rotation(-15) | ||
t.set_verticalalignment("center") | ||
|
||
if output_file: | ||
plt.savefig(output_file) | ||
else: | ||
plt.show() | ||
|
||
|
||
def generate_path_animation(G, start, end, output_file): | ||
"""Generate an animation of the possible paths from start to end in the stage pipeline. | ||
This will generate a single png image for each path from start to end, highlighting the edges | ||
in the path in red. These png files can be combined into a gif by running: | ||
$ convert -delay 100 -loop 0 stage_pipeline_*.png stage_pipeline.gif | ||
Parameters | ||
---------- | ||
G : nx.DiGraph | ||
The directed graph representing the stage pipeline. | ||
start : Stage | ||
The starting stage. | ||
end : Stage | ||
The ending stage. | ||
""" | ||
# Find all possible paths from start to end | ||
all_paths = nx.all_simple_paths(G, start, end) | ||
|
||
# Iterate over each path | ||
for p, path in enumerate(all_paths): | ||
# Create a copy of the graph to highlight the current path | ||
highlighted_G = G.copy() | ||
|
||
for edge in highlighted_G.edges(): | ||
highlighted_G[edge[0]][edge[1]]["color"] = "gray" | ||
highlighted_G[edge[0]][edge[1]]["penwidth"] = 1.0 | ||
|
||
# Highlight the edges in the current path | ||
for i in range(len(path) - 1): | ||
highlighted_G[path[i]][path[i + 1]]["color"] = "red" | ||
highlighted_G[path[i]][path[i + 1]]["penwidth"] = 2.0 | ||
|
||
colors = [highlighted_G[u][v]["color"] for u, v in highlighted_G.edges()] | ||
weights = [highlighted_G[u][v]["penwidth"] for u, v in highlighted_G.edges()] | ||
|
||
# Generate the plot for the current path | ||
plt.figure(figsize=(6, 12)) | ||
pos = graphviz_layout(highlighted_G, prog="dot") | ||
nx.draw( | ||
highlighted_G, | ||
pos, | ||
edge_color=colors, | ||
with_labels=False, | ||
node_color="powderblue", | ||
font_size=6, | ||
width=weights, | ||
) | ||
text = nx.draw_networkx_labels(highlighted_G, pos) | ||
for _, t in text.items(): | ||
t.set_rotation(-15) | ||
t.set_verticalalignment("center") | ||
|
||
# Save the plot as an image | ||
image_file = f"{output_file}_{p}.png" | ||
print(f"Saving image to {image_file}") | ||
plt.savefig(image_file) | ||
plt.close() | ||
|
||
|
||
def main(): | ||
initialize(CIME_interface()) | ||
G = gen_stage_pipeline(Stage.first()) | ||
plot_stage_pipeline(G) | ||
# generate_path_animation(G, Stage.first(), Stage._top_level[-1], "stage_pipeline") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from ProConPy.config_var import ConfigVar, cvars | ||
from ProConPy.stage import Stage | ||
from ProConPy.csp_solver import csp | ||
from visualCaseGen.cime_interface import CIME_interface | ||
from visualCaseGen.initialize_configvars import initialize_configvars | ||
from visualCaseGen.initialize_widgets import initialize_widgets | ||
from visualCaseGen.initialize_stages import initialize_stages | ||
from visualCaseGen.specs.options import set_options | ||
from visualCaseGen.specs.relational_constraints import get_relational_constraints | ||
|
||
import networkx as nx | ||
from networkx.drawing.nx_pydot import graphviz_layout | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
def initialize(cime): | ||
"""Initializes visualCaseGen""" | ||
ConfigVar.reboot() | ||
Stage.reboot() | ||
initialize_configvars(cime) | ||
initialize_widgets(cime) | ||
initialize_stages(cime) | ||
set_options(cime) | ||
csp.initialize(cvars, get_relational_constraints(cvars), Stage.first()) | ||
|
||
|
||
def gen_stage_tree(stage): | ||
"""Generate the stage tree by traversing all stages using depth-first search. | ||
Parameters | ||
---------- | ||
stage : Stage | ||
The initial stage to start the traversal from. | ||
Returns | ||
------- | ||
G : nx.DiGraph | ||
The directed graph representing the stage tree. | ||
""" | ||
|
||
# Instantiate a graph object that will represent the stage tree | ||
G = nx.Graph() | ||
|
||
while (next := stage.get_next(full_dfs=True)) is not None: | ||
if stage._parent is not None and stage._parent.has_condition(): | ||
G.add_edge(stage._parent, stage) | ||
for child in stage._children: | ||
G.add_edge(stage, child) | ||
stage = next | ||
|
||
assert nx.is_forest(G), "The stage tree is not a tree." | ||
|
||
return G | ||
|
||
def plot_stage_tree(stage): | ||
"""Plot the stage tree.""" | ||
|
||
# Traverse the stage tree using depth-first search | ||
G = gen_stage_tree(stage) | ||
|
||
# Draw the graph | ||
try: | ||
from networkx.drawing.nx_pydot import graphviz_layout | ||
|
||
pos = graphviz_layout(G, prog="dot") | ||
except ImportError: | ||
print( | ||
"WARNING: PyGraphviz is not installed. Drawing the graph using spring layout." | ||
) | ||
pos = nx.spring_layout(G) | ||
nx.draw( | ||
G, | ||
pos, | ||
with_labels=False, | ||
edge_color="gray", | ||
node_color="powderblue", | ||
font_size=9, | ||
) | ||
text = nx.draw_networkx_labels(G, pos) | ||
for _, t in text.items(): | ||
t.set_rotation(20) | ||
t.set_verticalalignment("center") | ||
|
||
# Set the color of Guard nodes | ||
guard_nodes = [node for node in G.nodes if node.has_condition()] | ||
nx.draw_networkx_nodes(G, pos, nodelist=guard_nodes, node_color="wheat") | ||
|
||
plt.show() | ||
|
||
|
||
def main(): | ||
initialize(CIME_interface()) | ||
plot_stage_tree(Stage.first()) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.