Skip to content

Commit

Permalink
add stage tree and pipeline plotters. also add stage hierarchy test
Browse files Browse the repository at this point in the history
  • Loading branch information
alperaltuntas committed Jun 6, 2024
1 parent 1d4313d commit 0d9b67b
Show file tree
Hide file tree
Showing 10 changed files with 298 additions and 18 deletions.
16 changes: 15 additions & 1 deletion ProConPy/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class Node:
# Top level nodes, i.e., nodes that have no parent
_top_level = []

# Set of titles of all the nodes in the stage tree
_titles = set()

def __init__(self, title, parent=None, condition=None):
"""Initialize a node.
Expand All @@ -36,6 +39,10 @@ def __init__(self, title, parent=None, condition=None):
The logical condition that must be satisfied for the node to be enabled.
"""

if title in Node._titles:
raise ValueError(f"The title {title} is already used.")
Node._titles.add(title)

self._title = title
self._children = []
self._parent = parent
Expand All @@ -62,6 +69,13 @@ def __init__(self, title, parent=None, condition=None):
def __str__(self):
return self._title

@classmethod
def reboot(cls):
"""Class method to reset the Node class so that it can be re-initialized.
This is useful for testing purposes and need not be utilized in production."""
cls._top_level = []
cls._titles.clear()

@classmethod
def first(cls):
"""Class method that returns the first top-level node"""
Expand Down Expand Up @@ -265,7 +279,7 @@ def __init__(
def reboot(cls):
"""Class method to reset the Stage class so that it can be re-initialized.
This is useful for testing purposes and should not be utilized in production."""
Node._top_level = []
Node.reboot()
cls._completed_stages = []
cls._active_stage = None
# todo: remove all instances of Stage
Expand Down
2 changes: 1 addition & 1 deletion tests/2_integration/test_constraint_violation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_constraint_violation_detection():
# Grid
assert Stage.active().title.startswith('2. Grid')
cvars['GRID_MODE'].value = 'Custom'
assert Stage.active().title.startswith('Custom Grid Generator')
assert Stage.active().title.startswith('Custom Grid')

custom_grid_path = Path(temp_dir) / "custom_grid"
cvars['CUSTOM_GRID_PATH'].value = str(custom_grid_path)
Expand Down
4 changes: 2 additions & 2 deletions tests/2_integration/test_custom_compset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def configure_custom_compset():
assert not Stage.first().enabled

# The next stge is Custom Component Set, whose first child is Model Time Period
assert Stage.active().title.startswith('Model Time Period')
assert Stage.active().title.startswith('Time Period')
cvars['INITTIME'].value = '2000'

# Set components
Expand Down Expand Up @@ -92,7 +92,7 @@ def configure_custom_compset():
# COMP_?_OPTIONS variables have been set, so the next stage is Grid:
assert Stage.active().title.startswith('2. Grid')
cvars['GRID_MODE'].value = 'Standard'
assert Stage.active().title.startswith('Standard Grid Selector')
assert Stage.active().title.startswith('Standard Grid')

# change of mind, revert and pick new components
Stage.active().revert()
Expand Down
6 changes: 3 additions & 3 deletions tests/3_system/test_f2000_custom_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def construct_custom_f2000_compset():
assert Stage.first().enabled
cvars["COMPSET_MODE"].value = "Custom"

assert Stage.active().title.startswith("Model Time Period")
assert Stage.active().title.startswith("Time Period")
cvars["INITTIME"].value = "2000"

assert Stage.active().title.startswith("Components")
Expand Down Expand Up @@ -222,7 +222,7 @@ def construct_custom_res_from_modified_clm_grid(cime):
# click the "Run Mesh Mask Modifier" button
mesh_mask_modifier_launcher._on_launch_clicked(b=None)

assert Stage.active().title.startswith("Surface Data Modifier")
assert Stage.active().title.startswith("fsurdat")
assert cvars["INPUT_FSURDAT"].value is not None, "INPUT_FSURDAT should be auto-filled"
assert cvars["FSURDAT_AREA_SPEC"].value.startswith("mask_file:"), "FSURDAT_AREA_SPEC should be auto-filled"
cvars["FSURDAT_IDEALIZED"].value = "True"
Expand Down Expand Up @@ -338,7 +338,7 @@ def construct_custom_res_from_new_mom6_grid_modified_clm_grid(cime):
assert Stage.active().title.startswith("Base Land Grid")
cvars["CUSTOM_LND_GRID"].value = "4x5"

assert Stage.active().title.startswith("Surface Data Modifier")
assert Stage.active().title.startswith("fsurdat")
assert cvars["INPUT_FSURDAT"].value is not None, "INPUT_FSURDAT should be auto-filled"
cvars["FSURDAT_AREA_SPEC"].value = "mask_file:/glade/work/altuntas/cesm.input/vcg/mask_fillIO_f45.nc"
cvars["FSURDAT_IDEALIZED"].value = "True"
Expand Down
4 changes: 2 additions & 2 deletions tests/3_system/test_fhist_custom_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def construct_custom_fhist_compset():
assert Stage.first().enabled
cvars["COMPSET_MODE"].value = "Custom"

assert Stage.active().title.startswith("Model Time Period")
assert Stage.active().title.startswith("Time Period")
cvars["INITTIME"].value = "HIST"

assert Stage.active().title.startswith("Components")
Expand Down Expand Up @@ -192,7 +192,7 @@ def construct_custom_res_from_modified_clm_grid(cime):
# click the "Run Mesh Mask Modifier" button
mesh_mask_modifier_launcher._on_launch_clicked(b=None)

assert Stage.active().title.startswith("Surface Data Modifier")
assert Stage.active().title.startswith("fsurdat")
assert cvars["INPUT_FSURDAT"].value is not None, "INPUT_FSURDAT should be auto-filled"
assert cvars["FSURDAT_AREA_SPEC"].value.startswith("mask_file:"), "FSURDAT_AREA_SPEC should be auto-filled"
cvars["FSURDAT_IDEALIZED"].value = "True"
Expand Down
16 changes: 16 additions & 0 deletions tests/4_static/test_stage_hierarchy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from tools.stage_tree_plotter import initialize, gen_stage_tree
from tools.stage_pipeline_plotter import gen_stage_pipeline
from ProConPy.stage import Stage
from visualCaseGen.cime_interface import CIME_interface

def test_stage_pipeline():
"""Confirm that the stage pipeline is a directed acyclic graph."""
cime = CIME_interface()
initialize(cime)

# The below call will raise an assertion error if the stage tree is not a forest
gen_stage_tree(Stage.first())

# The below call will raise an assertion error if the stage tree is not a directed acyclic graph
gen_stage_pipeline(Stage.first())

153 changes: 153 additions & 0 deletions tools/stage_pipeline_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from ProConPy.config_var import ConfigVar, cvars
from ProConPy.stage import Stage
from ProConPy.csp_solver import csp
from visualCaseGen.cime_interface import CIME_interface
from visualCaseGen.initialize_configvars import initialize_configvars
from visualCaseGen.initialize_widgets import initialize_widgets
from visualCaseGen.initialize_stages import initialize_stages
from visualCaseGen.specs.options import set_options
from visualCaseGen.specs.relational_constraints import get_relational_constraints

import networkx as nx
from networkx.drawing.nx_pydot import graphviz_layout
import matplotlib.pyplot as plt


def initialize(cime):
"""Initializes visualCaseGen"""
ConfigVar.reboot()
Stage.reboot()
initialize_configvars(cime)
initialize_widgets(cime)
initialize_stages(cime)
set_options(cime)
csp.initialize(cvars, get_relational_constraints(cvars), Stage.first())


def gen_stage_pipeline(stage):
"""Generate a directed acyclic graph representing the stage pipeline.
Parameters
----------
stage : Stage
The starting stage of the pipeline.
Returns
-------
nx.DiGraph
The directed graph representing the stage pipeline."""

# Instantiate a directed graph object that will represent the stage pipeline
G = nx.DiGraph()

# Traverse the entire stage tree using depth-first search
while (next := stage.get_next(full_dfs=True)) is not None:
if stage.children_have_conditions():
for child in stage._children:
G.add_edge(stage, child)
G.add_edge(child, child._children[0])
else:
# The actual next stage that would be visited during runtime
runtime_next = stage.get_next(full_dfs=False)
if runtime_next:
G.add_edge(stage, runtime_next)
stage = next

assert nx.is_directed_acyclic_graph(
G
), "The stage tree is not a directed acyclic graph."
return G


def plot_stage_pipeline(G, output_file=None):
"""Plot the stage pipeline."""
plt.figure(figsize=(6, 12))
# pos = graphviz_layout(G, prog="sfdp")
pos = graphviz_layout(G, prog="dot")
nx.draw(
G,
pos,
with_labels=False,
edge_color="gray",
node_color="powderblue",
font_size=6,
)
text = nx.draw_networkx_labels(G, pos)
for _, t in text.items():
t.set_rotation(-15)
t.set_verticalalignment("center")

if output_file:
plt.savefig(output_file)
else:
plt.show()


def generate_path_animation(G, start, end, output_file):
"""Generate an animation of the possible paths from start to end in the stage pipeline.
This will generate a single png image for each path from start to end, highlighting the edges
in the path in red. These png files can be combined into a gif by running:
$ convert -delay 100 -loop 0 stage_pipeline_*.png stage_pipeline.gif
Parameters
----------
G : nx.DiGraph
The directed graph representing the stage pipeline.
start : Stage
The starting stage.
end : Stage
The ending stage.
"""
# Find all possible paths from start to end
all_paths = nx.all_simple_paths(G, start, end)

# Iterate over each path
for p, path in enumerate(all_paths):
# Create a copy of the graph to highlight the current path
highlighted_G = G.copy()

for edge in highlighted_G.edges():
highlighted_G[edge[0]][edge[1]]["color"] = "gray"
highlighted_G[edge[0]][edge[1]]["penwidth"] = 1.0

# Highlight the edges in the current path
for i in range(len(path) - 1):
highlighted_G[path[i]][path[i + 1]]["color"] = "red"
highlighted_G[path[i]][path[i + 1]]["penwidth"] = 2.0

colors = [highlighted_G[u][v]["color"] for u, v in highlighted_G.edges()]
weights = [highlighted_G[u][v]["penwidth"] for u, v in highlighted_G.edges()]

# Generate the plot for the current path
plt.figure(figsize=(6, 12))
pos = graphviz_layout(highlighted_G, prog="dot")
nx.draw(
highlighted_G,
pos,
edge_color=colors,
with_labels=False,
node_color="powderblue",
font_size=6,
width=weights,
)
text = nx.draw_networkx_labels(highlighted_G, pos)
for _, t in text.items():
t.set_rotation(-15)
t.set_verticalalignment("center")

# Save the plot as an image
image_file = f"{output_file}_{p}.png"
print(f"Saving image to {image_file}")
plt.savefig(image_file)
plt.close()


def main():
initialize(CIME_interface())
G = gen_stage_pipeline(Stage.first())
plot_stage_pipeline(G)
# generate_path_animation(G, Stage.first(), Stage._top_level[-1], "stage_pipeline")


if __name__ == "__main__":
main()
97 changes: 97 additions & 0 deletions tools/stage_tree_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from ProConPy.config_var import ConfigVar, cvars
from ProConPy.stage import Stage
from ProConPy.csp_solver import csp
from visualCaseGen.cime_interface import CIME_interface
from visualCaseGen.initialize_configvars import initialize_configvars
from visualCaseGen.initialize_widgets import initialize_widgets
from visualCaseGen.initialize_stages import initialize_stages
from visualCaseGen.specs.options import set_options
from visualCaseGen.specs.relational_constraints import get_relational_constraints

import networkx as nx
from networkx.drawing.nx_pydot import graphviz_layout
import matplotlib.pyplot as plt


def initialize(cime):
"""Initializes visualCaseGen"""
ConfigVar.reboot()
Stage.reboot()
initialize_configvars(cime)
initialize_widgets(cime)
initialize_stages(cime)
set_options(cime)
csp.initialize(cvars, get_relational_constraints(cvars), Stage.first())


def gen_stage_tree(stage):
"""Generate the stage tree by traversing all stages using depth-first search.
Parameters
----------
stage : Stage
The initial stage to start the traversal from.
Returns
-------
G : nx.DiGraph
The directed graph representing the stage tree.
"""

# Instantiate a graph object that will represent the stage tree
G = nx.Graph()

while (next := stage.get_next(full_dfs=True)) is not None:
if stage._parent is not None and stage._parent.has_condition():
G.add_edge(stage._parent, stage)
for child in stage._children:
G.add_edge(stage, child)
stage = next

assert nx.is_forest(G), "The stage tree is not a tree."

return G

def plot_stage_tree(stage):
"""Plot the stage tree."""

# Traverse the stage tree using depth-first search
G = gen_stage_tree(stage)

# Draw the graph
try:
from networkx.drawing.nx_pydot import graphviz_layout

pos = graphviz_layout(G, prog="dot")
except ImportError:
print(
"WARNING: PyGraphviz is not installed. Drawing the graph using spring layout."
)
pos = nx.spring_layout(G)
nx.draw(
G,
pos,
with_labels=False,
edge_color="gray",
node_color="powderblue",
font_size=9,
)
text = nx.draw_networkx_labels(G, pos)
for _, t in text.items():
t.set_rotation(20)
t.set_verticalalignment("center")

# Set the color of Guard nodes
guard_nodes = [node for node in G.nodes if node.has_condition()]
nx.draw_networkx_nodes(G, pos, nodelist=guard_nodes, node_color="wheat")

plt.show()


def main():
initialize(CIME_interface())
plot_stage_tree(Stage.first())


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion visualCaseGen/stages/compset_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def initialize_compset_stages(cime):
)

stg_inittime = Stage(
title="Model Time Period:",
title="Time Period",
description="Select the initialization time for the experiment. This "
"influences the initial conditions and forcings used in the simulation. 1850 "
"corresponds to pre-industrial conditions and is appropriate for fixed-time-period "
Expand Down
Loading

0 comments on commit 0d9b67b

Please sign in to comment.