Skip to content

Commit

Permalink
condition: support graph-scheduler graph structure conditions
Browse files Browse the repository at this point in the history
- modify Scheduler wrappers to support both basic and structural
Conditions
- add wrappers for new graph_scheduler.Scheduler methods:
	- remove_condition
	- add_graph_edge
	- remove_graph_edge
- update requirements to graph-scheduler<1.3.0 to include graph
structure conditions release
  • Loading branch information
kmantel committed Dec 14, 2023
1 parent e10cc07 commit 91b9b98
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 50 deletions.
118 changes: 84 additions & 34 deletions psyneulink/core/scheduling/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import collections
import copy
import functools
import inspect
import numbers
import warnings
Expand All @@ -23,12 +22,44 @@
from psyneulink.core.globals.mdf import MDFSerializable
from psyneulink.core.globals.keywords import MODEL_SPEC_ID_TYPE, comparison_operators
from psyneulink.core.globals.parameters import parse_context
from psyneulink.core.globals.utilities import parse_valid_identifier
from psyneulink.core.globals.utilities import parse_valid_identifier, toposort_key

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
psyneulink.core.globals.utilities
begins an import cycle.

__all__ = copy.copy(graph_scheduler.condition.__all__)
__all__.extend(['Threshold'])


# avoid restricting graph_scheduler versions for this code
# ConditionBase was introduced with graph structure conditions
gs_condition_base_class = graph_scheduler.condition.Condition
condition_class_parents = [graph_scheduler.condition.Condition]


try:
gs_condition_base_class = graph_scheduler.condition.ConditionBase
except AttributeError:
pass
else:
class ConditionBase(graph_scheduler.condition.ConditionBase, MDFSerializable):
def as_mdf_model(self):
raise graph_scheduler.ConditionError(
f'MDF support not yet implemented for {type(self)}'
)
condition_class_parents.append(ConditionBase)


try:
graph_scheduler.condition.GraphStructureCondition
except AttributeError:
graph_structure_conditions_available = False
gsc_unavailable_message = (
'Graph structure conditions are not available'
f'in your installed graph-scheduler v{graph_scheduler.__version__}'
)
else:
graph_structure_conditions_available = True
gsc_unavailable_message = ''


def _create_as_pnl_condition(condition):
import psyneulink as pnl

Expand All @@ -38,12 +69,24 @@ def _create_as_pnl_condition(condition):
return condition

# already a pnl Condition
if isinstance(condition, Condition):
if isinstance(condition, pnl_condition_base_class):
return condition

if not issubclass(pnl_class, graph_scheduler.Condition):
if not issubclass(pnl_class, gs_condition_base_class):
return None

if (
graph_structure_conditions_available
and isinstance(condition, graph_scheduler.condition.GraphStructureCondition)
):
try:
return pnl_class(
*condition.nodes,
**{k: v for k, v in condition.kwargs.items() if k != 'nodes'}
)
except AttributeError:
return pnl_class(**condition.kwargs)

new_args = [_create_as_pnl_condition(a) or a for a in condition.args]
new_kwargs = {k: _create_as_pnl_condition(v) or v for k, v in condition.kwargs.items()}
sig = inspect.signature(pnl_class)
Expand All @@ -58,7 +101,7 @@ def _create_as_pnl_condition(condition):
return res


class Condition(graph_scheduler.Condition, MDFSerializable):
class Condition(*condition_class_parents, MDFSerializable):
@handle_external_context()
def is_satisfied(self, *args, context=None, execution_id=None, **kwargs):
if execution_id is None:
Expand All @@ -81,7 +124,7 @@ def as_mdf_model(self):
def _parse_condition_arg(arg):
if isinstance(arg, Component):
return parse_valid_identifier(arg.name)
elif isinstance(arg, graph_scheduler.Condition):
elif isinstance(arg, Condition):
return arg.as_mdf_model()
elif arg is None or isinstance(arg, numbers.Number):
return arg
Expand Down Expand Up @@ -112,7 +155,7 @@ def _parse_condition_arg(arg):
for a in self.args:
if isinstance(a, Component):
a = parse_valid_identifier(a.name)
elif isinstance(a, graph_scheduler.Condition):
elif isinstance(a, Condition):
a = a.as_mdf_model()
args_list.append(a)
extra_args[name] = args_list
Expand All @@ -139,40 +182,47 @@ def _parse_condition_arg(arg):
# below produces psyneulink versions of each Condition class so that
# they are compatible with the extra changes made in Condition above
# (the scheduler does not handle Context objects or mdf/json export)
cond_dependencies = {}
gs_class_dependencies = {}
gs_classes_to_copy_as_pnl = []
pnl_conditions_module = locals() # inserting into locals defines the classes
pnl_condition_base_class = pnl_conditions_module[gs_condition_base_class.__name__]

for class_name in graph_scheduler.condition.__dict__:
cls_ = getattr(graph_scheduler.condition, class_name)
if inspect.isclass(cls_):
# don't substitute classes explicitly defined above
if class_name not in pnl_conditions_module:
if issubclass(cls_, gs_condition_base_class):
gs_classes_to_copy_as_pnl.append(class_name)
else:
pnl_conditions_module[class_name] = cls_

for cond_name in graph_scheduler.condition.__all__:
sched_module_cond_obj = getattr(graph_scheduler.condition, cond_name)
cond_dependencies[cond_name] = {c.__name__ for c in sched_module_cond_obj.__mro__ if c.__name__ != cond_name}
gs_class_dependencies[class_name] = {
c.__name__ for c in cls_.__mro__ if c.__name__ != class_name
}

# iterate in order such that superclass types are before subclass types
for cond_name in sorted(
graph_scheduler.condition.__all__,
key=functools.cmp_to_key(lambda a, b: -1 if b in cond_dependencies[a] else 1)
gs_classes_to_copy_as_pnl,
key=toposort_key(gs_class_dependencies)
):
# don't substitute Condition because it is explicitly defined above
if cond_name == 'Condition':
continue

sched_module_cond_obj = getattr(graph_scheduler.condition, cond_name)
if (
inspect.isclass(sched_module_cond_obj)
and issubclass(sched_module_cond_obj, graph_scheduler.Condition)
):
new_mro = []
for cls_ in sched_module_cond_obj.__mro__:
if cls_ is not graph_scheduler.Condition:
try:
new_mro.append(pnl_conditions_module[cls_.__name__])

except KeyError:
new_mro.append(cls_)
else:
new_mro.extend(Condition.__mro__[:-1])
pnl_conditions_module[cond_name] = type(cond_name, tuple(new_mro), {})
elif isinstance(sched_module_cond_obj, type):
pnl_conditions_module[cond_name] = sched_module_cond_obj
new_bases = []
for cls_ in sched_module_cond_obj.__mro__:
try:
new_bases.append(pnl_conditions_module[cls_.__name__])
except KeyError:
new_bases.append(cls_)
if cls_ is gs_condition_base_class:
break

new_meta = type(new_bases[0])
if new_meta is not type:
pnl_conditions_module[cond_name] = new_meta(
cond_name, tuple(new_bases), {'__module__': Condition.__module__}
)
else:
pnl_conditions_module[cond_name] = type(cond_name, tuple(new_bases), {})

pnl_conditions_module[cond_name].__doc__ = sched_module_cond_obj.__doc__

Expand Down
63 changes: 50 additions & 13 deletions psyneulink/core/scheduling/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
import copy
import logging
import typing
from typing import Hashable

import graph_scheduler
import pint

import psyneulink as pnl

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'psyneulink' is imported with both 'import' and 'import from'.
from psyneulink import _unit_registry
from psyneulink.core.globals.context import Context, handle_external_context
from psyneulink.core.globals.mdf import MDFSerializable
from psyneulink.core.globals.utilities import parse_valid_identifier
from psyneulink.core.scheduling.condition import _create_as_pnl_condition
from psyneulink.core.scheduling.condition import _create_as_pnl_condition, graph_structure_conditions_available, gsc_unavailable_message

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
psyneulink.core.scheduling.condition
begins an import cycle.

__all__ = [
'Scheduler', 'SchedulingMode'
Expand Down Expand Up @@ -50,6 +52,7 @@ def __init__(
graph = composition.graph_processing.prune_feedback_edges()[0]
if default_execution_id is None:
default_execution_id = composition.default_execution_id
self.composition = composition

# TODO: consider integrating something like this into graph-scheduler?
self._user_specified_conds = graph_scheduler.ConditionSet()
Expand Down Expand Up @@ -77,8 +80,15 @@ def replace_term_conds(term_conds):

def _validate_conditions(self):
unspecified_nodes = []

# pre-graph-structure-condition compatibility
try:
conditions_basic = self.conditions.conditions_basic
except AttributeError:
conditions_basic = self.conditions.conditions

for node in self.nodes:
if node not in self.conditions:
if node not in conditions_basic:
dependencies = list(self.dependency_dict[node])
if len(dependencies) == 0:
cond = graph_scheduler.Always()
Expand All @@ -90,12 +100,8 @@ def _validate_conditions(self):
# TODO: replace this call in graph-scheduler if adding _user_specified_conds
self._add_condition(node, cond)
unspecified_nodes.append(node)
if len(unspecified_nodes) > 0:
logger.info(
'These nodes have no Conditions specified, and will be scheduled with conditions: {0}'.format(
{node: self.conditions[node] for node in unspecified_nodes}
)
)

super()._validate_conditions()

def add_condition(self, owner, condition):
self._user_specified_conds.add_condition(owner, condition)
Expand All @@ -105,22 +111,37 @@ def _add_condition(self, owner, condition):
condition = _create_as_pnl_condition(condition)
super().add_condition(owner, condition)

if graph_structure_conditions_available:
if isinstance(condition, pnl.GraphStructureCondition):
self.composition._analyze_graph()

def add_condition_set(self, conditions):
self._user_specified_conds.add_condition_set(conditions)
self._add_condition_set(conditions)

def _add_condition_set(self, conditions):
try:
conditions = conditions.conditions
except AttributeError:
pass

conditions = {
node: _create_as_pnl_condition(conditions[node])
for node in conditions
}
super().add_condition_set(conditions)

def remove_condition(self, owner_or_condition):
try:
res = super().remove_condition(owner_or_condition)
except AttributeError as e:
if "has no attribute 'remove_condition'" in str(e):
raise graph_scheduler.SchedulerError(
f'remove_condition unavailable in your installed graph-scheduler v{graph_scheduler.__version__}'
)
else:
raise
else:
if isinstance(res, pnl.GraphStructureCondition):
self.composition._analyze_graph()

return res

@graph_scheduler.Scheduler.termination_conds.setter
def termination_conds(self, termination_conds):
if termination_conds is not None:
Expand Down Expand Up @@ -160,6 +181,22 @@ def as_mdf_model(self):
def get_clock(self, context):
return super().get_clock(context.execution_id)

def add_graph_edge(self, sender: Hashable, receiver: Hashable) -> 'pnl.AddEdgeTo':
if not graph_structure_conditions_available:
raise graph_scheduler.SchedulerError(gsc_unavailable_message)

cond = pnl.AddEdgeTo(receiver)
self.add_condition(sender, cond)
return cond

def remove_graph_edge(self, sender: Hashable, receiver: Hashable) -> 'pnl.RemoveEdgeFrom':
if not graph_structure_conditions_available:
raise graph_scheduler.SchedulerError(gsc_unavailable_message)

cond = pnl.RemoveEdgeFrom(sender)
self.add_condition(receiver, cond)
return cond


_doc_subs = {
None: [
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ autograd<1.7
beartype<0.16.0
dill<0.3.8
fastkde>=1.0.24, <1.0.31
graph-scheduler>=1.1.1, <1.1.3
graph-scheduler>=1.1.1, <1.3.0
graphviz<0.21.0
grpcio<1.60.0
leabra-psyneulink<0.3.3
Expand Down
19 changes: 18 additions & 1 deletion tests/composition/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
NAME, PROJECTIONS, RESULT, OBJECTIVE_MECHANISM, OUTPUT_MECHANISM, OVERRIDE,
PARAMS, SLOPE, TARGET_MECHANISM,
VARIABLE, VARIANCE)
from psyneulink.core.scheduling.condition import AtTimeStep, AtTrial, Never, TimeInterval
from psyneulink.core.scheduling.condition import AtTimeStep, AtTrial, Never, TimeInterval, graph_structure_conditions_available, gsc_unavailable_message
from psyneulink.core.scheduling.condition import EveryNCalls
from psyneulink.core.scheduling.scheduler import Scheduler, SchedulingMode
from psyneulink.core.scheduling.time import TimeScale
Expand Down Expand Up @@ -7671,6 +7671,23 @@ def test_feedback_projection_added_by_pathway(self):
B: {NodeRole.TERMINAL, NodeRole.OUTPUT, NodeRole.FEEDBACK_SENDER},
}

@pytest.mark.skipif(
not graph_structure_conditions_available,
reason=gsc_unavailable_message
)
def test_graph_structure_condition_role_changes(self):
A = pnl.ProcessingMechanism(name='A')
B = pnl.ProcessingMechanism(name='B')
C = pnl.ProcessingMechanism(name='C')

comp = Composition(pathways=[A, B, C])

comp.scheduler.add_condition(C, pnl.BeforeNode(A))

assert comp.nodes_to_roles[A] == {NodeRole.INPUT}
assert comp.nodes_to_roles[B] == {NodeRole.INTERNAL, NodeRole.TERMINAL}
assert comp.nodes_to_roles[C] == {NodeRole.OUTPUT, NodeRole.ORIGIN}


class TestMisc:

Expand Down
22 changes: 22 additions & 0 deletions tests/scheduling/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,25 @@ def three_node_linear_composition():
comp.add_linear_processing_pathway([A, B, C])

return comp.nodes, comp


@pytest.helpers.register
def composition_from_string_pathways(pathways):
mechanisms = {}
pathways_as_mechs = []

for p in pathways:
p_as_mechs = []
assert not isinstance(p, str), 'pathways must be a list of lists'
for m in p:
try:
mech = mechanisms[m]
except KeyError:
mech = pnl.ProcessingMechanism(name=m)
mechanisms[m] = mech
p_as_mechs.append(mech)
pathways_as_mechs.append(p_as_mechs)

comp = pnl.Composition(pathways=pathways_as_mechs)

return comp, mechanisms, mechanisms.values()
Loading

0 comments on commit 91b9b98

Please sign in to comment.