diff --git a/psyneulink/core/components/projections/projection.py b/psyneulink/core/components/projections/projection.py index cc1a1f81a53..6999bca6702 100644 --- a/psyneulink/core/components/projections/projection.py +++ b/psyneulink/core/components/projections/projection.py @@ -978,6 +978,18 @@ def _activate_for_compositions(self, composition): def _activate_for_all_compositions(self): self._activate_for_compositions(ConnectionInfo.ALL) + def _deactivate_for_compositions(self, composition): + try: + self.receiver.afferents_info[self].remove_composition(composition) + except KeyError: + warnings.warn(f'{self} was not active for {composition}') + + def _deactivate_for_all_compositions(self): + self._deactivate_for_all_compositions(ConnectionInfo.ALL) + + def is_active_in_composition(self, composition): + return self.receiver.afferents_info[self].is_active_in_composition(composition) + def _delete_projection(projection, context=None): """Delete Projection, its entries in receiver and sender Ports, and in ProjectionRegistry""" projection.sender._remove_projection_from_port(projection) diff --git a/psyneulink/core/compositions/composition.py b/psyneulink/core/compositions/composition.py index 01c8b69ff29..975deb853a8 100644 --- a/psyneulink/core/compositions/composition.py +++ b/psyneulink/core/compositions/composition.py @@ -3880,11 +3880,12 @@ def scheduler(self): """ if self.needs_update_scheduler or not isinstance(self._scheduler, Scheduler): old_scheduler = self._scheduler - self._scheduler = Scheduler(composition=self) - if old_scheduler is not None: - self._scheduler.add_condition_set(old_scheduler.conditions) + orig_conds = old_scheduler._user_specified_conds + else: + orig_conds = None + self._scheduler = Scheduler(composition=self, conditions=orig_conds) self.needs_update_scheduler = False return self._scheduler @@ -4111,37 +4112,67 @@ def add_nodes(self, nodes, required_roles=None, context=None): f"({node}) must be a {Mechanism.__name__}, {Composition.__name__}, " f"or a tuple containing one of those and a {NodeRole.__name__} or list of them") + def remove_node(self, node): + self._remove_node(node) + + def _remove_node(self, node, analyze_graph=True): + for proj in node.afferents + node.efferents: + self.remove_projection(proj) + + for param_port in node.parameter_ports: + for proj in param_port.mod_afferents: + self.remove_projection(proj) + + # deactivate any shadowed projections + for shadow_target, shadow_port_original in self.shadowing_dict.items(): + if shadow_port_original in node.input_ports: + for shadow_proj in shadow_target.all_afferents: + if shadow_proj.sender.owner.composition is self: + self.remove_projection(shadow_proj) + + # NOTE: deactivation should be sufficient but + # asserts in OCM _update_state_input_port_names + # need target input ports of shadowed + # projections to be active or not present at all + try: + self.controller.state_input_ports.remove(shadow_target) + except AttributeError: + pass + + self.graph.remove_component(node) + del self.nodes_to_roles[node] + + # Remove any entries for node in required_node_roles or excluded_node_roles + node_role_pairs = [item for item in self.required_node_roles if item[0] is node] + for item in node_role_pairs: + self.required_node_roles.remove(item) + node_role_pairs = [item for item in self.excluded_node_roles if item[0] is node] + for item in node_role_pairs: + self.excluded_node_roles.remove(item) + + del self.nodes[node] + self.node_ordering.remove(node) + + for p in self.pathways: + try: + p.pathway.remove(node) + except ValueError: + pass + + self.needs_update_graph_processing = True + self.needs_update_scheduler = True + + if analyze_graph: + self._analyze_graph() + def remove_nodes(self, nodes): if not isinstance(nodes, (list, Mechanism, Composition)): assert False, 'Argument of remove_nodes must be a Mechanism, Composition or list containing either or both' nodes = convert_to_list(nodes) for node in nodes: - for proj in node.afferents + node.efferents: - try: - del self.projections[proj] - except ValueError: - # why are these not present? - pass - - try: - self.graph.remove_component(proj) - except CompositionError: - # why are these not present? - pass - - self.graph.remove_component(node) - del self.nodes_to_roles[node] + self._remove_node(node, analyze_graph=False) - # Remove any entries for node in required_node_roles or excluded_node_roles - node_role_pairs = [item for item in self.required_node_roles if item[0] is node] - for item in node_role_pairs: - self.required_node_roles.remove(item) - node_role_pairs = [item for item in self.excluded_node_roles if item[0] is node] - for item in node_role_pairs: - self.excluded_node_roles.remove(item) - - del self.nodes[node] - self.node_ordering.remove(node) + self._analyze_graph() @handle_external_context() def _add_required_node_role(self, node, role, context=None): @@ -5799,7 +5830,32 @@ def remove_projection(self, projection): if projection in self.projections: self.projections.remove(projection) - # step 3 - TBI? remove Projection from afferents & efferents lists of any node + # step 3 - deactivate Projection in this Composition + projection._deactivate_for_compositions(self) + + # step 4 - deactivate any learning to this Projection + for param_port in projection.parameter_ports: + for proj in param_port.mod_afferents: + self.remove_projection(proj) + if isinstance(proj.sender.owner, LearningMechanism): + for path in self.pathways: + # TODO: make learning_components values consistent type + try: + learning_mechs = path.learning_components['LEARNING_MECHANISMS'] + except KeyError: + continue + + if isinstance(learning_mechs, LearningMechanism): + learning_mechs = [learning_mechs] + + if proj.sender.owner in learning_mechs: + for mech in learning_mechs: + self.remove_node(mech) + self.remove_node(path.learning_components['objective_mechanism']) + self.remove_node(path.learning_components['TARGET_MECHANISM']) + + # step 5 - TBI? remove Projection from afferents & efferents lists of any node + def _validate_projection(self, projection, diff --git a/psyneulink/core/globals/socket.py b/psyneulink/core/globals/socket.py index e995466f98c..54e93c44a8d 100644 --- a/psyneulink/core/globals/socket.py +++ b/psyneulink/core/globals/socket.py @@ -34,6 +34,15 @@ def add_composition(self, composition): else: self.compositions.add(composition) + def remove_composition(self, composition): + if composition is self.ALL: + self.compositions = set() + else: + try: + self.compositions.remove(composition) + except (AttributeError, KeyError): + logger.info('Attempted to remove composition from {} but was not active'.format(self)) + def is_active_in_composition(self, composition): if self.compositions is None: return False diff --git a/psyneulink/core/scheduling/scheduler.py b/psyneulink/core/scheduling/scheduler.py index ce265986788..f163dc19900 100644 --- a/psyneulink/core/scheduling/scheduler.py +++ b/psyneulink/core/scheduling/scheduler.py @@ -7,6 +7,7 @@ # ********************************************* Scheduler ************************************************************** +import copy import typing import graph_scheduler @@ -48,6 +49,9 @@ def __init__( if default_execution_id is None: default_execution_id = composition.default_execution_id + # TODO: consider integrating something like this into graph-scheduler? + self._user_specified_conds = copy.copy(conditions) + super().__init__( graph=graph, conditions=conditions, diff --git a/tests/composition/test_composition.py b/tests/composition/test_composition.py index d201079f20f..959518a67f6 100644 --- a/tests/composition/test_composition.py +++ b/tests/composition/test_composition.py @@ -1,3 +1,4 @@ +import collections import functools import logging from timeit import timeit @@ -18,6 +19,7 @@ from psyneulink.core.components.mechanisms.modulatory.control.optimizationcontrolmechanism import \ OptimizationControlMechanism from psyneulink.core.components.mechanisms.modulatory.learning.learningmechanism import LearningMechanism +from psyneulink.core.components.mechanisms.processing.compositioninterfacemechanism import CompositionInterfaceMechanism from psyneulink.core.components.mechanisms.processing.integratormechanism import IntegratorMechanism from psyneulink.core.components.mechanisms.processing.objectivemechanism import ObjectiveMechanism from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism @@ -7167,6 +7169,147 @@ def test_danglingControlledMech(self): comp.add_node(Reward) # no assert, should only complete without error + @pytest.mark.parametrize( + 'removed_nodes, expected_dependencies', + [ + (['A'], {'B': set(), 'C': set('B'), 'D': set('C'), 'E': set('C')}), + (['C'], {'A': set(), 'B': set(), 'D': set(), 'E': set()}), + (['E'], {'A': set(), 'B': set(), 'C': {'A', 'B'}, 'D': set('C')}), + (['A', 'B'], {'C': set(), 'D': set('C'), 'E': set('C')}), + (['D', 'E'], {'A': set(), 'B': set(), 'C': {'A', 'B'}}), + (['A', 'B', 'C', 'D', 'E'], {}), + ] + ) + def test_remove_node(self, removed_nodes, expected_dependencies): + def stringify_dependency_dict(dd): + return {node.name: {n.name for n in deps} for node, deps in dd.items()} + + A = pnl.TransferMechanism(name='A') + B = pnl.TransferMechanism(name='B') + C = pnl.TransferMechanism(name='C') + D = pnl.TransferMechanism(name='D') + E = pnl.TransferMechanism(name='E') + + locs = locals() + removed_nodes = [locs[n] for n in removed_nodes] + + comp = pnl.Composition( + pathways=[ + [A, C, D], + [A, C, D], + [B, C, D], + [B, C, E], + ] + ) + + comp.remove_nodes(removed_nodes) + + assert stringify_dependency_dict(comp.scheduler.dependency_dict) == expected_dependencies + assert stringify_dependency_dict(comp.graph_processing.dependency_dict) == expected_dependencies + + proj_dependencies = collections.defaultdict(set) + for node in comp.nodes: + if node not in removed_nodes: + for proj in node.afferents: + if ( + proj.sender.owner not in removed_nodes + and proj.receiver.owner not in removed_nodes + and not isinstance(proj.sender.owner, CompositionInterfaceMechanism) + and not isinstance(proj.receiver.owner, CompositionInterfaceMechanism) + ): + proj_dependencies[node.name].add(proj.name) + proj_dependencies[proj.name].add(proj.sender.owner.name) + assert stringify_dependency_dict(comp.graph.dependency_dict) == {**expected_dependencies, **proj_dependencies} + + for node in removed_nodes: + assert node not in comp.nodes + assert node not in comp.nodes_to_roles + assert node not in comp.graph.comp_to_vertex + assert node not in comp.graph_processing.comp_to_vertex + assert node not in comp.scheduler.conditions + + for proj in node.afferents + node.efferents: + assert proj not in comp.projections + assert not proj.is_active_in_composition(comp) + + comp.run(inputs={n: [0] for n in comp.get_nodes_by_role(pnl.NodeRole.INPUT)}) + + @pytest.mark.parametrize('slope_A', [1, (1, pnl.CONTROL)]) + @pytest.mark.parametrize('slope_B', [1, (1, pnl.CONTROL)]) + @pytest.mark.parametrize('removed_nodes', [['A'], ['B'], ['A', 'B']]) + def test_remove_node_control(self, slope_A, slope_B, removed_nodes): + A = pnl.TransferMechanism(name='A', function=pnl.Linear(slope=slope_A)) + B = pnl.TransferMechanism(name='B', function=pnl.Linear(slope=slope_B)) + + locs = locals() + removed_nodes = [locs[n] for n in removed_nodes] + + search_space_len = sum(1 if isinstance(s, tuple) else 0 for s in [slope_A, slope_B]) + + comp = pnl.Composition(pathways=[A, B]) + comp.add_controller( + pnl.OptimizationControlMechanism( + agent_rep=comp, search_space=[0, 1] * search_space_len + ) + ) + comp.remove_nodes(removed_nodes) + + for n in removed_nodes: + for proj in n.parameter_ports['slope'].all_afferents: + assert not proj.is_active_in_composition(comp), f'{n.name} {proj.name}' + + def test_remove_node_from_conditions(self): + def assert_conditions_do_not_contain(*args): + conds_queue = list(comp.scheduler.conditions.conditions.values()) + while len(conds_queue) > 0: + cur_cond = conds_queue.pop() + deps = [] + + try: + deps = [cur_cond.dependency] + except AttributeError: + try: + deps = cur_cond.dependencies + except AttributeError: + pass + + for d in deps: + if isinstance(d, pnl.Condition): + conds_queue.append(d) + assert d not in args + + A = pnl.TransferMechanism(name='A') + B = pnl.TransferMechanism(name='B') + C = pnl.TransferMechanism(name='C') + D = pnl.TransferMechanism(name='D') + + comp = pnl.Composition(pathways=[[A, D], [B, D], [C, D]]) + + comp.run(inputs={A: [0], B: [0], C: [0]}) + comp.remove_node(A) + comp.run(inputs={B: [0], C: [0]}) + assert_conditions_do_not_contain(A) + + comp.remove_node(B) + comp.run(inputs={C: [0]}) + assert_conditions_do_not_contain(A, B) + + def test_remove_node_learning(self): + A = ProcessingMechanism(name='A') + B = ProcessingMechanism(name='B') + C = ProcessingMechanism(name='C') + D = ProcessingMechanism(name='D') + + comp = Composition() + comp.add_linear_learning_pathway(pathway=[A, B], learning_function=BackPropagation) + comp.add_linear_learning_pathway(pathway=[C, D], learning_function=Reinforcement) + + comp.remove_node(A) + comp.learn(inputs={n: [0] for n in comp.get_nodes_by_role(pnl.NodeRole.INPUT)}) + + comp.remove_node(D) + comp.learn(inputs={n: [0] for n in comp.get_nodes_by_role(pnl.NodeRole.INPUT)}) + class TestInputSpecsDocumentationExamples: