Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support graph-scheduler graph structure conditions #2864

Merged
merged 5 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions psyneulink/core/globals/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,14 @@
import collections
import copy
import inspect
import itertools
import logging
import psyneulink
import re
import time
import warnings
import weakref
import toposort
import types
import typing
from beartype import beartype
Expand Down Expand Up @@ -2060,3 +2062,30 @@ def get_function_sig_default_value(
return sig.parameters[parameter].default
except KeyError:
return inspect._empty


def toposort_key(
dependency_dict: typing.Dict[typing.Hashable, typing.Iterable[typing.Any]]
) -> typing.Callable[[typing.Any], int]:
"""
Creates a key function for python sorting that causes all items in
**dependency_dict** to be sorted after their dependencies

Args:
dependency_dict (typing.Dict[typing.Hashable, typing.Iterable[typing.Any]]):
a dictionary where values are the dependencies of keys

Returns:
typing.Callable[[typing.Any], int]: a key function for python
sorting
"""
topo_ordering = list(toposort.toposort(dependency_dict))
topo_ordering = list(itertools.chain.from_iterable(topo_ordering))

def _generated_toposort_key(obj):
try:
return topo_ordering.index(obj)
except ValueError:
return -1

return _generated_toposort_key
10 changes: 7 additions & 3 deletions psyneulink/core/scheduling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@

if cls.__doc__ is None:
try:
cls.__doc__ = f'{getattr(ext_module, cls_name).__doc__}'
ext_cls = getattr(ext_module, cls_name)
except AttributeError:
# PNL-exclusive object
continue
else:
cls.__doc__ = ext_cls.__doc__

cls.__doc__ = re.sub(pattern, repl, cls.__doc__, flags=re.MULTILINE | re.DOTALL)
if cls.__doc__ is not None:
cls.__doc__ = re.sub(pattern, repl, cls.__doc__, flags=re.MULTILINE | re.DOTALL)

for cls, repls in module._doc_subs.items():
if cls is None:
Expand All @@ -73,7 +76,8 @@
cls = getattr(module, cls)

for pattern, repl in repls:
cls.__doc__ = re.sub(pattern, repl, cls.__doc__, flags=re.MULTILINE | re.DOTALL)
if cls.__doc__ is not None:
cls.__doc__ = re.sub(pattern, repl, cls.__doc__, flags=re.MULTILINE | re.DOTALL)

del graph_scheduler
del re
118 changes: 84 additions & 34 deletions psyneulink/core/scheduling/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import collections
import copy
import functools
import inspect
import numbers
import warnings
Expand All @@ -23,12 +22,44 @@
from psyneulink.core.globals.mdf import MDFSerializable
from psyneulink.core.globals.keywords import MODEL_SPEC_ID_TYPE, comparison_operators
from psyneulink.core.globals.parameters import parse_context
from psyneulink.core.globals.utilities import parse_valid_identifier
from psyneulink.core.globals.utilities import parse_valid_identifier, toposort_key

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
psyneulink.core.globals.utilities
begins an import cycle.

__all__ = copy.copy(graph_scheduler.condition.__all__)
__all__.extend(['Threshold'])


# avoid restricting graph_scheduler versions for this code
# ConditionBase was introduced with graph structure conditions
gs_condition_base_class = graph_scheduler.condition.Condition
condition_class_parents = [graph_scheduler.condition.Condition]


try:
gs_condition_base_class = graph_scheduler.condition.ConditionBase
except AttributeError:
pass
else:
class ConditionBase(graph_scheduler.condition.ConditionBase, MDFSerializable):
def as_mdf_model(self):
raise graph_scheduler.ConditionError(
f'MDF support not yet implemented for {type(self)}'
)
condition_class_parents.append(ConditionBase)


try:
graph_scheduler.condition.GraphStructureCondition
except AttributeError:
graph_structure_conditions_available = False
gsc_unavailable_message = (
'Graph structure conditions are not available'
f'in your installed graph-scheduler v{graph_scheduler.__version__}'
)
else:
graph_structure_conditions_available = True
gsc_unavailable_message = ''


def _create_as_pnl_condition(condition):
import psyneulink as pnl

Expand All @@ -38,12 +69,24 @@
return condition

# already a pnl Condition
if isinstance(condition, Condition):
if isinstance(condition, pnl_condition_base_class):
return condition

if not issubclass(pnl_class, graph_scheduler.Condition):
if not issubclass(pnl_class, gs_condition_base_class):
return None

if (
graph_structure_conditions_available
and isinstance(condition, graph_scheduler.condition.GraphStructureCondition)
):
try:
return pnl_class(
*condition.nodes,
**{k: v for k, v in condition.kwargs.items() if k != 'nodes'}
)
except AttributeError:
return pnl_class(**condition.kwargs)

new_args = [_create_as_pnl_condition(a) or a for a in condition.args]
new_kwargs = {k: _create_as_pnl_condition(v) or v for k, v in condition.kwargs.items()}
sig = inspect.signature(pnl_class)
Expand All @@ -58,7 +101,7 @@
return res


class Condition(graph_scheduler.Condition, MDFSerializable):
class Condition(*condition_class_parents, MDFSerializable):
@handle_external_context()
def is_satisfied(self, *args, context=None, execution_id=None, **kwargs):
if execution_id is None:
Expand All @@ -81,7 +124,7 @@
def _parse_condition_arg(arg):
if isinstance(arg, Component):
return parse_valid_identifier(arg.name)
elif isinstance(arg, graph_scheduler.Condition):
elif isinstance(arg, Condition):
return arg.as_mdf_model()
elif arg is None or isinstance(arg, numbers.Number):
return arg
Expand Down Expand Up @@ -112,7 +155,7 @@
for a in self.args:
if isinstance(a, Component):
a = parse_valid_identifier(a.name)
elif isinstance(a, graph_scheduler.Condition):
elif isinstance(a, Condition):
a = a.as_mdf_model()
args_list.append(a)
extra_args[name] = args_list
Expand All @@ -139,40 +182,47 @@
# below produces psyneulink versions of each Condition class so that
# they are compatible with the extra changes made in Condition above
# (the scheduler does not handle Context objects or mdf/json export)
cond_dependencies = {}
gs_class_dependencies = {}
gs_classes_to_copy_as_pnl = []
pnl_conditions_module = locals() # inserting into locals defines the classes
pnl_condition_base_class = pnl_conditions_module[gs_condition_base_class.__name__]

for class_name in graph_scheduler.condition.__dict__:
cls_ = getattr(graph_scheduler.condition, class_name)
if inspect.isclass(cls_):
# don't substitute classes explicitly defined above
if class_name not in pnl_conditions_module:
if issubclass(cls_, gs_condition_base_class):
gs_classes_to_copy_as_pnl.append(class_name)
else:
pnl_conditions_module[class_name] = cls_

for cond_name in graph_scheduler.condition.__all__:
sched_module_cond_obj = getattr(graph_scheduler.condition, cond_name)
cond_dependencies[cond_name] = set(sched_module_cond_obj.__mro__[1:])
gs_class_dependencies[class_name] = {
c.__name__ for c in cls_.__mro__ if c.__name__ != class_name
}

# iterate in order such that superclass types are before subclass types
for cond_name in sorted(
graph_scheduler.condition.__all__,
key=functools.cmp_to_key(lambda a, b: -1 if b in cond_dependencies[a] else 1)
gs_classes_to_copy_as_pnl,
key=toposort_key(gs_class_dependencies)
):
# don't substitute Condition because it is explicitly defined above
if cond_name == 'Condition':
continue

sched_module_cond_obj = getattr(graph_scheduler.condition, cond_name)
if (
inspect.isclass(sched_module_cond_obj)
and issubclass(sched_module_cond_obj, graph_scheduler.Condition)
):
new_mro = []
for cls_ in sched_module_cond_obj.__mro__:
if cls_ is not graph_scheduler.Condition:
try:
new_mro.append(pnl_conditions_module[cls_.__name__])

except KeyError:
new_mro.append(cls_)
else:
new_mro.extend(Condition.__mro__[:-1])
pnl_conditions_module[cond_name] = type(cond_name, tuple(new_mro), {})
elif isinstance(sched_module_cond_obj, type):
pnl_conditions_module[cond_name] = sched_module_cond_obj
new_bases = []
for cls_ in sched_module_cond_obj.__mro__:
try:
new_bases.append(pnl_conditions_module[cls_.__name__])
except KeyError:
new_bases.append(cls_)
if cls_ is gs_condition_base_class:
break

new_meta = type(new_bases[0])
if new_meta is not type:
pnl_conditions_module[cond_name] = new_meta(
cond_name, tuple(new_bases), {'__module__': Condition.__module__}
)
else:
pnl_conditions_module[cond_name] = type(cond_name, tuple(new_bases), {})

pnl_conditions_module[cond_name].__doc__ = sched_module_cond_obj.__doc__

Expand Down
75 changes: 57 additions & 18 deletions psyneulink/core/scheduling/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
import copy
import logging
import typing
from typing import Hashable

import graph_scheduler
import pint

import psyneulink as pnl

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'psyneulink' is imported with both 'import' and 'import from'.
from psyneulink import _unit_registry
from psyneulink.core.globals.context import Context, handle_external_context
from psyneulink.core.globals.mdf import MDFSerializable
from psyneulink.core.globals.utilities import parse_valid_identifier
from psyneulink.core.scheduling.condition import _create_as_pnl_condition
from psyneulink.core.scheduling.condition import _create_as_pnl_condition, graph_structure_conditions_available, gsc_unavailable_message

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
psyneulink.core.scheduling.condition
begins an import cycle.

__all__ = [
'Scheduler', 'SchedulingMode'
Expand Down Expand Up @@ -50,9 +52,12 @@
graph = composition.graph_processing.prune_feedback_edges()[0]
if default_execution_id is None:
default_execution_id = composition.default_execution_id
self.composition = composition

# TODO: consider integrating something like this into graph-scheduler?
self._user_specified_conds = copy.copy(conditions) if conditions is not None else {}
self._user_specified_conds = graph_scheduler.ConditionSet()
if conditions is not None:
self._user_specified_conds.add_condition_set(copy.copy(conditions))
self._user_specified_termination_conds = copy.copy(termination_conds) if termination_conds is not None else {}

super().__init__(
Expand All @@ -75,8 +80,15 @@

def _validate_conditions(self):
unspecified_nodes = []

# pre-graph-structure-condition compatibility
try:
conditions_basic = self.conditions.conditions_basic
except AttributeError:
conditions_basic = self.conditions.conditions

for node in self.nodes:
if node not in self.conditions:
if node not in conditions_basic:
dependencies = list(self.dependency_dict[node])
if len(dependencies) == 0:
cond = graph_scheduler.Always()
Expand All @@ -88,37 +100,48 @@
# TODO: replace this call in graph-scheduler if adding _user_specified_conds
self._add_condition(node, cond)
unspecified_nodes.append(node)
if len(unspecified_nodes) > 0:
logger.info(
'These nodes have no Conditions specified, and will be scheduled with conditions: {0}'.format(
{node: self.conditions[node] for node in unspecified_nodes}
)
)

super()._validate_conditions()

def add_condition(self, owner, condition):
self._user_specified_conds[owner] = condition
self._user_specified_conds.add_condition(owner, condition)
self._add_condition(owner, condition)

def _add_condition(self, owner, condition):
condition = _create_as_pnl_condition(condition)
super().add_condition(owner, condition)

if graph_structure_conditions_available:
if isinstance(condition, pnl.GraphStructureCondition):
self.composition._analyze_graph()

def add_condition_set(self, conditions):
self._user_specified_conds.update(conditions)
self._user_specified_conds.add_condition_set(conditions)
self._add_condition_set(conditions)

def _add_condition_set(self, conditions):
try:
conditions = conditions.conditions
except AttributeError:
pass

conditions = {
node: _create_as_pnl_condition(cond)
for node, cond in conditions.items()
node: _create_as_pnl_condition(conditions[node])
for node in conditions
}
super().add_condition_set(conditions)

def remove_condition(self, owner_or_condition):
try:
res = super().remove_condition(owner_or_condition)
except AttributeError as e:
if "has no attribute 'remove_condition'" in str(e):
raise graph_scheduler.SchedulerError(
f'remove_condition unavailable in your installed graph-scheduler v{graph_scheduler.__version__}'
)
else:
raise
else:
if isinstance(res, pnl.GraphStructureCondition):
self.composition._analyze_graph()

return res

@graph_scheduler.Scheduler.termination_conds.setter
def termination_conds(self, termination_conds):
if termination_conds is not None:
Expand Down Expand Up @@ -158,6 +181,22 @@
def get_clock(self, context):
return super().get_clock(context.execution_id)

def add_graph_edge(self, sender: Hashable, receiver: Hashable) -> 'pnl.AddEdgeTo':
if not graph_structure_conditions_available:
raise graph_scheduler.SchedulerError(gsc_unavailable_message)

cond = pnl.AddEdgeTo(receiver)
self.add_condition(sender, cond)
return cond

def remove_graph_edge(self, sender: Hashable, receiver: Hashable) -> 'pnl.RemoveEdgeFrom':
if not graph_structure_conditions_available:
raise graph_scheduler.SchedulerError(gsc_unavailable_message)

cond = pnl.RemoveEdgeFrom(sender)
self.add_condition(receiver, cond)
return cond


_doc_subs = {
None: [
Expand Down
Loading
Loading