Skip to content

Commit

Permalink
further related refinements in csp_solver
Browse files Browse the repository at this point in the history
  • Loading branch information
alperaltuntas committed Jun 7, 2024
1 parent 15e438e commit 140b7ac
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 40 deletions.
113 changes: 74 additions & 39 deletions ProConPy/csp_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def reboot(self):
self._options_assertions = {}
self._past_options_assertions = []
self._tlock = TraversalLock()
self._checked_assignment = None # A record of the current assignment being processed. This is used
# as a hand-shake mechanism between check_assignment and register_assignment.
self._checked_assignment = None
# ^ A record of the current assignment being processed. This is used
# as a hand-shake mechanism between check_assignment and register_assignment.

@owh.out.capture()
def proceed(self):
Expand Down Expand Up @@ -61,20 +62,19 @@ def revert(self):
self._options_assertions = self._past_options_assertions.pop()
self._refresh_solver()


def _refresh_solver(self):
"""Reset the solver and (re-)apply the relational constraints, the past assignment
assertions, and the past options assertions. This method is called when the user wants
to proceed/revert to a following/previous stage. Resetting the solver turned out to
be more efficient than the initial approach of using push/pop to manage the solver."""
be more efficient than the initial approach of using push/pop to manage the solver.
"""
self._solver.reset()
self._solver.add([asrt for asrt, _ in self._relational_constraints.items()])
for scope in self._past_assignment_assertions:
self._solver.add([asrt for _, asrt in scope.items()])
for scope in self._past_options_assertions:
self._solver.add([asrt for _, asrt in scope.items()])


def initialize(self, cvars, relational_constraints, first_stage):
"""Initialize the CSP solver with relational constraints. The relational constraints are
the constraints that are derived from the relationships between the variables. The
Expand Down Expand Up @@ -117,30 +117,34 @@ def initialize(self, cvars, relational_constraints, first_stage):
logger.info("CspSolver initialized.")

def _determine_variable_ranks(self, stage, cvars):
"""Determine the ranks of the variables. The ranks are determined by checking the
consistency of the variable precedence. The precedence of the variables is determined by
the order in which the variables are assigned in the stage tree. The lower the rank, the
higher the precedence."""

# Solver to check if a consistent ranking of variables is possible
s = Solver()

# Instantiate temporary rank variables for each config variable to determine their ranks
[Int(f'{var}_rank') for var in cvars]
[Int(f"{var}_rank") for var in cvars]

# The maximum rank
max_rank = Int('max_rank')
max_rank = Int("max_rank")

while stage is not None:

varlist = stage._varlist
assert len(varlist) > 0, "Stage has no variables."

curr_rank = Int(f'{varlist[0]}_rank')
curr_rank = Int(f"{varlist[0]}_rank")

# All ranks must be nonnegative and less than or equal to the maximum rank
s.add([0 <= curr_rank, curr_rank <= max_rank])

# All stage vars must have the same rank
for var in varlist[1:]:
s.add(curr_rank == Int(f'{var}_rank'))
s.add(curr_rank == Int(f"{var}_rank"))

# The next stage in stage tree (via full DFS traversal)
dfs_next_stage = stage.get_next(full_dfs=True)
if dfs_next_stage is None:
Expand All @@ -151,16 +155,24 @@ def _determine_variable_ranks(self, stage, cvars):
dfs_next_stage = dfs_next_stage.get_next(full_dfs=True)
# Now, process the guard variables.
if isinstance(condition, BoolRef):
guard_vars = [cvars[var.sexpr()] for var in z3util.get_vars(condition)]
guard_vars = [
cvars[var.sexpr()] for var in z3util.get_vars(condition)
]
for guard_var in guard_vars:
# Mark guard variables
guard_var.is_guard_var = True
# All guard variables must have a lower rank than the variables in the next stage:
s.add(Int(f'{guard_var}_rank') < Int(f'{dfs_next_stage._varlist[0]}_rank'))
s.add(
Int(f"{guard_var}_rank")
< Int(f"{dfs_next_stage._varlist[0]}_rank")
)

# Find out the stage that would follow the current stage in an actual run.
true_next_stage = dfs_next_stage
if not(stage.is_sibling_of(dfs_next_stage) or stage.is_ancestor_of(dfs_next_stage)):
if not (
stage.is_sibling_of(dfs_next_stage)
or stage.is_ancestor_of(dfs_next_stage)
):
ancestor = stage._parent
while ancestor is not None:
if (not ancestor.has_condition()) and ancestor._right is not None:
Expand All @@ -169,14 +181,15 @@ def _determine_variable_ranks(self, stage, cvars):
ancestor = ancestor._parent

# All variables in the current stage must have a lower rank than the variables in the (true) next stage:
s.add(curr_rank < Int(f'{true_next_stage._varlist[0]}_rank'))
s.add(curr_rank < Int(f"{true_next_stage._varlist[0]}_rank"))

for aux_var in stage._aux_varlist:
# All auxiliary variables must have a higher rank than the variables in the current stage:
s.add(curr_rank < Int(f'{aux_var}_rank'))
s.add(curr_rank < Int(f"{aux_var}_rank"))
# All auxiliary variables must have a lower rank than the variables in the (true) next stage:
s.add(Int(f'{aux_var}_rank') < Int(f'{true_next_stage._varlist[0]}_rank'))

s.add(
Int(f"{aux_var}_rank") < Int(f"{true_next_stage._varlist[0]}_rank")
)

# Check if the current stage is consistent
if s.check() == unsat:
Expand All @@ -185,6 +198,15 @@ def _determine_variable_ranks(self, stage, cvars):
# continue dfs traversal:
stage = dfs_next_stage

# Also take options dependencies into account
for var in cvars.values():
for dependent_var in var._dependent_vars:
s.add(Int(f"{var}_rank") < Int(f"{dependent_var}_rank"))
if s.check() == unsat:
raise RuntimeError(
"Inconsistent variable ranks encountered due to options dependencies."
)

# Now minimize the maximum rank (This is optional and can be removed if performance becomes an issue)
opt = Optimize()
opt.add(s.assertions())
Expand All @@ -194,18 +216,19 @@ def _determine_variable_ranks(self, stage, cvars):

for var in cvars:
try:
cvars[var].rank = model.eval(Int(f'{var}_rank')).as_long()
cvars[var].rank = model.eval(Int(f"{var}_rank")).as_long()
except AttributeError:
# This variable is not contained by any stage. Set its rank to max_rank + 1
cvars[var].rank = model.eval(Int('max_rank')).as_long() + 1
cvars[var].rank = model.eval(Int("max_rank")).as_long() + 1

def _process_relational_constraints(self, cvars):
"""Process the relational constraints to construct a constraint graph and add constraints
to the solver. The constraint graph is a directed graph where the nodes are the variables
and the edges are (one or more) relational constraints that connect the variables."""
and the edges are (one or more) relational constraints that connect the variables.
"""

# constraint graph
self._cgraph = {var : set() for var in cvars.values()}
self._cgraph = {var: set() for var in cvars.values()}

warn = (
"The relational_constraints must be a dictionary where keys are the z3 boolean expressions "
Expand All @@ -229,7 +252,13 @@ def _process_relational_constraints(self, cvars):
constr_vars = {cvars[var.sexpr()] for var in z3util.get_vars(constr)}

for var in constr_vars:
self._cgraph[var].update(constr_vars - {var})
self._cgraph[var].update(
set(
var_other
for var_other in constr_vars
if var_other is not var and var_other.rank >= var.rank
)
)

@property
def initialized(self):
Expand Down Expand Up @@ -264,7 +293,9 @@ def check_assignment(self, var, new_value):

# Sanity checks
assert self._initialized, "Must finalize initialization to check assignments."
assert self._checked_assignment is None, "A check/register cycle is in progress."
assert (
self._checked_assignment is None
), "A check/register cycle is in progress."
assert new_value is not None, "None is always a valid assignment."

# Depending on the domain of the variable, check the assignment
Expand Down Expand Up @@ -294,7 +325,6 @@ def _check_assignment_of_finite_domain_var(self, var, new_value):
raise ConstraintViolation(self.retrieve_error_msg(var, new_val))
if validity is None:
raise ConstraintViolation(f"{new_val} not an option for {var}")


def _check_assignment_of_infinite_domain_var(self, var, new_value):
"""Check the assignment of a variable with an infinite domain to a new value. The check
Expand Down Expand Up @@ -340,7 +370,7 @@ def _check_assignment_of_infinite_domain_var(self, var, new_value):
# Set variable value to None, and raise an exception.
var.value = None
raise ConstraintViolation(
f"Your current configuration settings have created infeasible options for future settings. "\
f"Your current configuration settings have created infeasible options for future settings. "
"Please reset or revise your selections."
)

Expand Down Expand Up @@ -421,14 +451,14 @@ def retrieve_error_msg(self, var, new_value):
)

error_messages = [str(err_msg) for err_msg in s.unsat_core()]
msg = f'Invalid assignment of {var} to {new_value}.'
msg = f"Invalid assignment of {var} to {new_value}."
if len(error_messages) == 1:
msg += f' Reason: {error_messages[0]}'
msg += f" Reason: {error_messages[0]}"
else:
msg +=' Reasons:'
msg += " Reasons:"
for i, err_msg in enumerate(error_messages):
msg += f' {i+1}: {err_msg}.'
msg = msg.replace('..', '.')
msg += f" {i+1}: {err_msg}."
msg = msg.replace("..", ".")
return msg

def register_assignment(self, var, new_value):
Expand All @@ -450,17 +480,20 @@ def register_assignment(self, var, new_value):

logger.debug(f"Registering assignment of {var} to {new_value}.")
if new_value is not None:
assert self._checked_assignment == (var, new_value), (
"The assignment to be registered does not match the latest checked assignment."
)
assert self._checked_assignment == (
var,
new_value,
), "The assignment to be registered does not match the latest checked assignment."
# Handshake complete. Reset the checked assignment:
self._checked_assignment = None

if not (var.has_dependent_vars() or self._cgraph[var] or var.is_guard_var):
logger.debug("%s has no dependent or related variables. Returning.", var)
return

assert not self._tlock.is_locked(), "Traversal lock is acquired. Cannot register assignment."
assert (
not self._tlock.is_locked()
), "Traversal lock is acquired. Cannot register assignment."

with self._tlock: # acquire the lock to detect recursive traversal of constraint hypergraph

Expand Down Expand Up @@ -496,7 +529,9 @@ def _update_options_of_dependent_vars(var, new_value):
"""

if new_value is None:
new_options_and_tooltips = {dependent_var: (None, None) for dependent_var in var._dependent_vars}
new_options_and_tooltips = {
dependent_var: (None, None) for dependent_var in var._dependent_vars
}
else:
new_options_and_tooltips = {}
for dependent_var in var._dependent_vars:
Expand All @@ -507,7 +542,7 @@ def _update_options_of_dependent_vars(var, new_value):
new_tooltips,
)
# Note: For variables with infinite domain, the options_spec methods are called both
# within check_assignment and register_assignment. This doesn't appear to lead to
# within check_assignment and register_assignment. This doesn't appear to lead to
# noticeable performance issues, but it may be worth revisiting in the future.

for dependent_var, (
Expand Down Expand Up @@ -535,7 +570,7 @@ def _refresh_options_validities(self, var):
visited = set()

# Queue of variables to be visited
queue = [neig for neig in self._cgraph[var] if neig.has_options() and var.rank <= neig.rank]
queue = [neig for neig in self._cgraph[var] if neig.has_options()]

# Traverse the constraint graph to refresh the options validities of all possibly affected variables
while queue:
Expand All @@ -552,7 +587,7 @@ def _refresh_options_validities(self, var):
[
neig
for neig in self._cgraph[var]
if neig.has_options() and var.rank <= neig.rank and neig not in visited
if neig.has_options() and neig not in visited
]
)

Expand Down
76 changes: 76 additions & 0 deletions tools/cgraph_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from ProConPy.config_var import ConfigVar, cvars
from ProConPy.stage import Stage
from ProConPy.csp_solver import csp
from visualCaseGen.cime_interface import CIME_interface
from visualCaseGen.initialize_configvars import initialize_configvars
from visualCaseGen.initialize_widgets import initialize_widgets
from visualCaseGen.initialize_stages import initialize_stages
from visualCaseGen.specs.options import set_options
from visualCaseGen.specs.relational_constraints import get_relational_constraints

import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout


def initialize(cime):
"""Initializes visualCaseGen"""
ConfigVar.reboot()
Stage.reboot()
initialize_configvars(cime)
initialize_widgets(cime)
initialize_stages(cime)
set_options(cime)
csp.initialize(cvars, get_relational_constraints(cvars), Stage.first())


def gen_cgraph():
"""Generates the constraint graph based on relational constraints and dependent variables."""

G = nx.DiGraph()
for _, cvar in cvars.items():
for related_var in csp._cgraph[cvar]:
G.add_edge(cvar, related_var)
for dependent_var in cvar._dependent_vars:
G.add_edge(cvar, dependent_var)

# TODO: remove this manual addition and make sure it is added automatically.
G.add_edge(cvars["COMPSET_ALIAS"], cvars["COMPSET_LNAME"])

return G


def plot_cgraph():
"""Plots the constraint graph."""

G = gen_cgraph()
pos = graphviz_layout(G, prog="sfdp")
nx.draw(
G,
pos,
with_labels=False,
node_size=100,
node_color="skyblue",
font_color="black",
edge_color="gray",
linewidths=0.5,
width=0.5,
alpha=0.5,
)

text = nx.draw_networkx_labels(G, pos)
for _, t in text.items():
# t.set_rotation(20)
# t.set_verticalalignment("center_baseline")
t.set_fontsize(8)

plt.show()


def main():
initialize(CIME_interface())
plot_cgraph()


if __name__ == "__main__":
main()
1 change: 0 additions & 1 deletion tools/stage_tree_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from visualCaseGen.specs.relational_constraints import get_relational_constraints

import networkx as nx
from networkx.drawing.nx_pydot import graphviz_layout
import matplotlib.pyplot as plt


Expand Down

0 comments on commit 140b7ac

Please sign in to comment.