Commit
Merge pull request #2864 from kmantel/structural-conditions
Support graph-scheduler graph structure conditions
kmantel authored Dec 15, 2023
2 parents 1b49504 + 1869c2e commit a38768d
Showing 10 changed files with 380 additions and 64 deletions.
29 changes: 29 additions & 0 deletions psyneulink/core/globals/utilities.py
@@ -101,12 +101,14 @@
import collections
import copy
import inspect
import itertools
import logging
import psyneulink
import re
import time
import warnings
import weakref
import toposort
import types
import typing
from beartype import beartype
@@ -2060,3 +2062,30 @@ def get_function_sig_default_value(
return sig.parameters[parameter].default
except KeyError:
return inspect._empty


def toposort_key(
dependency_dict: typing.Dict[typing.Hashable, typing.Iterable[typing.Any]]
) -> typing.Callable[[typing.Any], int]:
"""
Creates a key function for python sorting that causes all items in
**dependency_dict** to be sorted after their dependencies
Args:
dependency_dict (typing.Dict[typing.Hashable, typing.Iterable[typing.Any]]):
a dictionary where values are the dependencies of keys
Returns:
typing.Callable[[typing.Any], int]: a key function for python
sorting
"""
topo_ordering = list(toposort.toposort(dependency_dict))
topo_ordering = list(itertools.chain.from_iterable(topo_ordering))

def _generated_toposort_key(obj):
try:
return topo_ordering.index(obj)
except ValueError:
return -1

return _generated_toposort_key
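
As an editor's illustration (not part of this diff), the generated key can be used directly with sorted(); the dependency dict below is invented for the example, and the toposort package is assumed to be installed:

deps = {'b': {'a'}, 'c': {'b'}}               # 'b' depends on 'a', 'c' depends on 'b'
key = toposort_key(deps)
print(sorted(['c', 'x', 'a', 'b'], key=key))  # ['x', 'a', 'b', 'c']; unknown items return -1, so 'x' sorts first
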
10 changes: 7 additions & 3 deletions psyneulink/core/scheduling/__init__.py
@@ -59,12 +59,15 @@

if cls.__doc__ is None:
try:
cls.__doc__ = f'{getattr(ext_module, cls_name).__doc__}'
ext_cls = getattr(ext_module, cls_name)
except AttributeError:
# PNL-exclusive object
continue
else:
cls.__doc__ = ext_cls.__doc__

cls.__doc__ = re.sub(pattern, repl, cls.__doc__, flags=re.MULTILINE | re.DOTALL)
if cls.__doc__ is not None:
cls.__doc__ = re.sub(pattern, repl, cls.__doc__, flags=re.MULTILINE | re.DOTALL)

for cls, repls in module._doc_subs.items():
if cls is None:
@@ -73,7 +76,8 @@
cls = getattr(module, cls)

for pattern, repl in repls:
cls.__doc__ = re.sub(pattern, repl, cls.__doc__, flags=re.MULTILINE | re.DOTALL)
if cls.__doc__ is not None:
cls.__doc__ = re.sub(pattern, repl, cls.__doc__, flags=re.MULTILINE | re.DOTALL)

del graph_scheduler
del re
118 changes: 84 additions & 34 deletions psyneulink/core/scheduling/condition.py
@@ -10,7 +10,6 @@

import collections
import copy
import functools
import inspect
import numbers
import warnings
@@ -23,12 +22,44 @@
from psyneulink.core.globals.mdf import MDFSerializable
from psyneulink.core.globals.keywords import MODEL_SPEC_ID_TYPE, comparison_operators
from psyneulink.core.globals.parameters import parse_context
from psyneulink.core.globals.utilities import parse_valid_identifier
from psyneulink.core.globals.utilities import parse_valid_identifier, toposort_key

Check notice (Code scanning / CodeQL, Cyclic import): Import of module psyneulink.core.globals.utilities begins an import cycle.

__all__ = copy.copy(graph_scheduler.condition.__all__)
__all__.extend(['Threshold'])


# avoid restricting graph_scheduler versions for this code
# ConditionBase was introduced with graph structure conditions
gs_condition_base_class = graph_scheduler.condition.Condition
condition_class_parents = [graph_scheduler.condition.Condition]


try:
gs_condition_base_class = graph_scheduler.condition.ConditionBase
except AttributeError:
pass
else:
class ConditionBase(graph_scheduler.condition.ConditionBase, MDFSerializable):
def as_mdf_model(self):
raise graph_scheduler.ConditionError(
f'MDF support not yet implemented for {type(self)}'
)
condition_class_parents.append(ConditionBase)


try:
graph_scheduler.condition.GraphStructureCondition
except AttributeError:
graph_structure_conditions_available = False
gsc_unavailable_message = (
'Graph structure conditions are not available '
f'in your installed graph-scheduler v{graph_scheduler.__version__}'
)
else:
graph_structure_conditions_available = True
gsc_unavailable_message = ''


def _create_as_pnl_condition(condition):
import psyneulink as pnl

@@ -38,12 +69,24 @@ def _create_as_pnl_condition(condition):
return condition

# already a pnl Condition
if isinstance(condition, Condition):
if isinstance(condition, pnl_condition_base_class):
return condition

if not issubclass(pnl_class, graph_scheduler.Condition):
if not issubclass(pnl_class, gs_condition_base_class):
return None

if (
graph_structure_conditions_available
and isinstance(condition, graph_scheduler.condition.GraphStructureCondition)
):
try:
return pnl_class(
*condition.nodes,
**{k: v for k, v in condition.kwargs.items() if k != 'nodes'}
)
except AttributeError:
return pnl_class(**condition.kwargs)

new_args = [_create_as_pnl_condition(a) or a for a in condition.args]
new_kwargs = {k: _create_as_pnl_condition(v) or v for k, v in condition.kwargs.items()}
sig = inspect.signature(pnl_class)
@@ -58,7 +101,7 @@ def _create_as_pnl_condition(condition):
return res


class Condition(graph_scheduler.Condition, MDFSerializable):
class Condition(*condition_class_parents, MDFSerializable):
@handle_external_context()
def is_satisfied(self, *args, context=None, execution_id=None, **kwargs):
if execution_id is None:
@@ -81,7 +124,7 @@ def as_mdf_model(self):
def _parse_condition_arg(arg):
if isinstance(arg, Component):
return parse_valid_identifier(arg.name)
elif isinstance(arg, graph_scheduler.Condition):
elif isinstance(arg, Condition):
return arg.as_mdf_model()
elif arg is None or isinstance(arg, numbers.Number):
return arg
@@ -112,7 +155,7 @@ def _parse_condition_arg(arg):
for a in self.args:
if isinstance(a, Component):
a = parse_valid_identifier(a.name)
elif isinstance(a, graph_scheduler.Condition):
elif isinstance(a, Condition):
a = a.as_mdf_model()
args_list.append(a)
extra_args[name] = args_list
@@ -139,40 +182,47 @@ def _parse_condition_arg(arg):
# below produces psyneulink versions of each Condition class so that
# they are compatible with the extra changes made in Condition above
# (the scheduler does not handle Context objects or mdf/json export)
cond_dependencies = {}
gs_class_dependencies = {}
gs_classes_to_copy_as_pnl = []
pnl_conditions_module = locals() # inserting into locals defines the classes
pnl_condition_base_class = pnl_conditions_module[gs_condition_base_class.__name__]

for class_name in graph_scheduler.condition.__dict__:
cls_ = getattr(graph_scheduler.condition, class_name)
if inspect.isclass(cls_):
# don't substitute classes explicitly defined above
if class_name not in pnl_conditions_module:
if issubclass(cls_, gs_condition_base_class):
gs_classes_to_copy_as_pnl.append(class_name)
else:
pnl_conditions_module[class_name] = cls_

for cond_name in graph_scheduler.condition.__all__:
sched_module_cond_obj = getattr(graph_scheduler.condition, cond_name)
cond_dependencies[cond_name] = set(sched_module_cond_obj.__mro__[1:])
gs_class_dependencies[class_name] = {
c.__name__ for c in cls_.__mro__ if c.__name__ != class_name
}

# iterate in order such that superclass types are before subclass types
for cond_name in sorted(
graph_scheduler.condition.__all__,
key=functools.cmp_to_key(lambda a, b: -1 if b in cond_dependencies[a] else 1)
gs_classes_to_copy_as_pnl,
key=toposort_key(gs_class_dependencies)
):
# don't substitute Condition because it is explicitly defined above
if cond_name == 'Condition':
continue

sched_module_cond_obj = getattr(graph_scheduler.condition, cond_name)
if (
inspect.isclass(sched_module_cond_obj)
and issubclass(sched_module_cond_obj, graph_scheduler.Condition)
):
new_mro = []
for cls_ in sched_module_cond_obj.__mro__:
if cls_ is not graph_scheduler.Condition:
try:
new_mro.append(pnl_conditions_module[cls_.__name__])

except KeyError:
new_mro.append(cls_)
else:
new_mro.extend(Condition.__mro__[:-1])
pnl_conditions_module[cond_name] = type(cond_name, tuple(new_mro), {})
elif isinstance(sched_module_cond_obj, type):
pnl_conditions_module[cond_name] = sched_module_cond_obj
new_bases = []
for cls_ in sched_module_cond_obj.__mro__:
try:
new_bases.append(pnl_conditions_module[cls_.__name__])
except KeyError:
new_bases.append(cls_)
if cls_ is gs_condition_base_class:
break

new_meta = type(new_bases[0])
if new_meta is not type:
pnl_conditions_module[cond_name] = new_meta(
cond_name, tuple(new_bases), {'__module__': Condition.__module__}
)
else:
pnl_conditions_module[cond_name] = type(cond_name, tuple(new_bases), {})

pnl_conditions_module[cond_name].__doc__ = sched_module_cond_obj.__doc__

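To make the class-substitution loop above concrete, here is an editor's sketch (not part of this diff) of the same superclass-before-subclass rebuild in miniature; the names A, B, C, source, and mirrored are invented for the example, and toposort_key is the helper added to utilities.py earlier in this commit:

# Mirror classes from a source mapping locally, processing superclasses before subclasses
# so each copy can inherit from the already-mirrored version of its parent.
from psyneulink.core.globals.utilities import toposort_key

class A: pass
class B(A): pass
class C(B): pass

source = {'A': A, 'B': B, 'C': C}
deps = {name: {c.__name__ for c in cls.__mro__ if c.__name__ != name}
        for name, cls in source.items()}

mirrored = {}
for name in sorted(source, key=toposort_key(deps)):
    parent = source[name].__mro__[1]              # immediate base class
    base = mirrored.get(parent.__name__, parent)  # prefer the mirrored copy when one exists
    mirrored[name] = type(name, (base,), {})

assert issubclass(mirrored['C'], mirrored['A'])   # the mirrored hierarchy preserves subclass relationships

The real loop additionally preserves metaclasses and docstrings, as shown in the diff above.
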
75 changes: 57 additions & 18 deletions psyneulink/core/scheduling/scheduler.py
@@ -10,15 +10,17 @@
import copy
import logging
import typing
from typing import Hashable

import graph_scheduler
import pint

import psyneulink as pnl

Check notice (Code scanning / CodeQL): Module 'psyneulink' is imported with both 'import' and 'import from'.

from psyneulink import _unit_registry
from psyneulink.core.globals.context import Context, handle_external_context
from psyneulink.core.globals.mdf import MDFSerializable
from psyneulink.core.globals.utilities import parse_valid_identifier
from psyneulink.core.scheduling.condition import _create_as_pnl_condition
from psyneulink.core.scheduling.condition import _create_as_pnl_condition, graph_structure_conditions_available, gsc_unavailable_message

Check notice (Code scanning / CodeQL, Cyclic import): Import of module psyneulink.core.scheduling.condition begins an import cycle.

__all__ = [
'Scheduler', 'SchedulingMode'
@@ -50,9 +52,12 @@ def __init__(
graph = composition.graph_processing.prune_feedback_edges()[0]
if default_execution_id is None:
default_execution_id = composition.default_execution_id
self.composition = composition

# TODO: consider integrating something like this into graph-scheduler?
self._user_specified_conds = copy.copy(conditions) if conditions is not None else {}
self._user_specified_conds = graph_scheduler.ConditionSet()
if conditions is not None:
self._user_specified_conds.add_condition_set(copy.copy(conditions))
self._user_specified_termination_conds = copy.copy(termination_conds) if termination_conds is not None else {}

super().__init__(
@@ -75,8 +80,15 @@ def replace_term_conds(term_conds):

def _validate_conditions(self):
unspecified_nodes = []

# pre-graph-structure-condition compatibility
try:
conditions_basic = self.conditions.conditions_basic
except AttributeError:
conditions_basic = self.conditions.conditions

for node in self.nodes:
if node not in self.conditions:
if node not in conditions_basic:
dependencies = list(self.dependency_dict[node])
if len(dependencies) == 0:
cond = graph_scheduler.Always()
@@ -88,37 +100,48 @@ def _validate_conditions(self):
# TODO: replace this call in graph-scheduler if adding _user_specified_conds
self._add_condition(node, cond)
unspecified_nodes.append(node)
if len(unspecified_nodes) > 0:
logger.info(
'These nodes have no Conditions specified, and will be scheduled with conditions: {0}'.format(
{node: self.conditions[node] for node in unspecified_nodes}
)
)

super()._validate_conditions()

def add_condition(self, owner, condition):
self._user_specified_conds[owner] = condition
self._user_specified_conds.add_condition(owner, condition)
self._add_condition(owner, condition)

def _add_condition(self, owner, condition):
condition = _create_as_pnl_condition(condition)
super().add_condition(owner, condition)

if graph_structure_conditions_available:
if isinstance(condition, pnl.GraphStructureCondition):
self.composition._analyze_graph()

def add_condition_set(self, conditions):
self._user_specified_conds.update(conditions)
self._user_specified_conds.add_condition_set(conditions)
self._add_condition_set(conditions)

def _add_condition_set(self, conditions):
try:
conditions = conditions.conditions
except AttributeError:
pass

conditions = {
node: _create_as_pnl_condition(cond)
for node, cond in conditions.items()
node: _create_as_pnl_condition(conditions[node])
for node in conditions
}
super().add_condition_set(conditions)

def remove_condition(self, owner_or_condition):
try:
res = super().remove_condition(owner_or_condition)
except AttributeError as e:
if "has no attribute 'remove_condition'" in str(e):
raise graph_scheduler.SchedulerError(
f'remove_condition unavailable in your installed graph-scheduler v{graph_scheduler.__version__}'
)
else:
raise
else:
if isinstance(res, pnl.GraphStructureCondition):
self.composition._analyze_graph()

return res

@graph_scheduler.Scheduler.termination_conds.setter
def termination_conds(self, termination_conds):
if termination_conds is not None:
@@ -158,6 +181,22 @@ def as_mdf_model(self):
def get_clock(self, context):
return super().get_clock(context.execution_id)

def add_graph_edge(self, sender: Hashable, receiver: Hashable) -> 'pnl.AddEdgeTo':
if not graph_structure_conditions_available:
raise graph_scheduler.SchedulerError(gsc_unavailable_message)

cond = pnl.AddEdgeTo(receiver)
self.add_condition(sender, cond)
return cond

def remove_graph_edge(self, sender: Hashable, receiver: Hashable) -> 'pnl.RemoveEdgeFrom':
if not graph_structure_conditions_available:
raise graph_scheduler.SchedulerError(gsc_unavailable_message)

cond = pnl.RemoveEdgeFrom(sender)
self.add_condition(receiver, cond)
return cond


_doc_subs = {
None: [
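For orientation, a hedged usage sketch of the new add_graph_edge and remove_graph_edge methods (editor's illustration, not part of this diff); it assumes an installed graph-scheduler version that provides graph structure conditions, and the mechanism and Composition names are placeholders:

# Assumes graph structure conditions are available in the installed graph-scheduler.
import psyneulink as pnl

A = pnl.ProcessingMechanism(name='A')
B = pnl.ProcessingMechanism(name='B')
comp = pnl.Composition(nodes=[A, B], name='comp')

cond = comp.scheduler.add_graph_edge(A, B)  # creates and applies an AddEdgeTo(B) condition owned by A
print(type(cond).__name__)                  # AddEdgeTo
comp.scheduler.remove_condition(cond)       # undo; raises SchedulerError if the installed graph-scheduler lacks remove_condition
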
