Skip to content

Commit

Permalink
Merge pull request #2364 from kmantel/removenode
Browse files Browse the repository at this point in the history
Add node removal to Compositions
  • Loading branch information
kmantel authored Mar 29, 2022
2 parents d27dc89 + 7dd6a50 commit 3d76263
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 29 deletions.
12 changes: 12 additions & 0 deletions psyneulink/core/components/projections/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,18 @@ def _activate_for_compositions(self, composition):
def _activate_for_all_compositions(self):
self._activate_for_compositions(ConnectionInfo.ALL)

def _deactivate_for_compositions(self, composition):
try:
self.receiver.afferents_info[self].remove_composition(composition)
except KeyError:
warnings.warn(f'{self} was not active for {composition}')

def _deactivate_for_all_compositions(self):
self._deactivate_for_all_compositions(ConnectionInfo.ALL)

def is_active_in_composition(self, composition):
return self.receiver.afferents_info[self].is_active_in_composition(composition)

def _delete_projection(projection, context=None):
"""Delete Projection, its entries in receiver and sender Ports, and in ProjectionRegistry"""
projection.sender._remove_projection_from_port(projection)
Expand Down
114 changes: 85 additions & 29 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3880,11 +3880,12 @@ def scheduler(self):
"""
if self.needs_update_scheduler or not isinstance(self._scheduler, Scheduler):
old_scheduler = self._scheduler
self._scheduler = Scheduler(composition=self)

if old_scheduler is not None:
self._scheduler.add_condition_set(old_scheduler.conditions)
orig_conds = old_scheduler._user_specified_conds
else:
orig_conds = None

self._scheduler = Scheduler(composition=self, conditions=orig_conds)
self.needs_update_scheduler = False

return self._scheduler
Expand Down Expand Up @@ -4111,37 +4112,67 @@ def add_nodes(self, nodes, required_roles=None, context=None):
f"({node}) must be a {Mechanism.__name__}, {Composition.__name__}, "
f"or a tuple containing one of those and a {NodeRole.__name__} or list of them")

def remove_node(self, node):
self._remove_node(node)

def _remove_node(self, node, analyze_graph=True):
for proj in node.afferents + node.efferents:
self.remove_projection(proj)

for param_port in node.parameter_ports:
for proj in param_port.mod_afferents:
self.remove_projection(proj)

# deactivate any shadowed projections
for shadow_target, shadow_port_original in self.shadowing_dict.items():
if shadow_port_original in node.input_ports:
for shadow_proj in shadow_target.all_afferents:
if shadow_proj.sender.owner.composition is self:
self.remove_projection(shadow_proj)

# NOTE: deactivation should be sufficient but
# asserts in OCM _update_state_input_port_names
# need target input ports of shadowed
# projections to be active or not present at all
try:
self.controller.state_input_ports.remove(shadow_target)
except AttributeError:
pass

self.graph.remove_component(node)
del self.nodes_to_roles[node]

# Remove any entries for node in required_node_roles or excluded_node_roles
node_role_pairs = [item for item in self.required_node_roles if item[0] is node]
for item in node_role_pairs:
self.required_node_roles.remove(item)
node_role_pairs = [item for item in self.excluded_node_roles if item[0] is node]
for item in node_role_pairs:
self.excluded_node_roles.remove(item)

del self.nodes[node]
self.node_ordering.remove(node)

for p in self.pathways:
try:
p.pathway.remove(node)
except ValueError:
pass

self.needs_update_graph_processing = True
self.needs_update_scheduler = True

if analyze_graph:
self._analyze_graph()

def remove_nodes(self, nodes):
if not isinstance(nodes, (list, Mechanism, Composition)):
assert False, 'Argument of remove_nodes must be a Mechanism, Composition or list containing either or both'
nodes = convert_to_list(nodes)
for node in nodes:
for proj in node.afferents + node.efferents:
try:
del self.projections[proj]
except ValueError:
# why are these not present?
pass

try:
self.graph.remove_component(proj)
except CompositionError:
# why are these not present?
pass

self.graph.remove_component(node)
del self.nodes_to_roles[node]
self._remove_node(node, analyze_graph=False)

# Remove any entries for node in required_node_roles or excluded_node_roles
node_role_pairs = [item for item in self.required_node_roles if item[0] is node]
for item in node_role_pairs:
self.required_node_roles.remove(item)
node_role_pairs = [item for item in self.excluded_node_roles if item[0] is node]
for item in node_role_pairs:
self.excluded_node_roles.remove(item)

del self.nodes[node]
self.node_ordering.remove(node)
self._analyze_graph()

@handle_external_context()
def _add_required_node_role(self, node, role, context=None):
Expand Down Expand Up @@ -5799,7 +5830,32 @@ def remove_projection(self, projection):
if projection in self.projections:
self.projections.remove(projection)

# step 3 - TBI? remove Projection from afferents & efferents lists of any node
# step 3 - deactivate Projection in this Composition
projection._deactivate_for_compositions(self)

# step 4 - deactivate any learning to this Projection
for param_port in projection.parameter_ports:
for proj in param_port.mod_afferents:
self.remove_projection(proj)
if isinstance(proj.sender.owner, LearningMechanism):
for path in self.pathways:
# TODO: make learning_components values consistent type
try:
learning_mechs = path.learning_components['LEARNING_MECHANISMS']
except KeyError:
continue

if isinstance(learning_mechs, LearningMechanism):
learning_mechs = [learning_mechs]

if proj.sender.owner in learning_mechs:
for mech in learning_mechs:
self.remove_node(mech)
self.remove_node(path.learning_components['objective_mechanism'])
self.remove_node(path.learning_components['TARGET_MECHANISM'])

# step 5 - TBI? remove Projection from afferents & efferents lists of any node


def _validate_projection(self,
projection,
Expand Down
9 changes: 9 additions & 0 deletions psyneulink/core/globals/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def add_composition(self, composition):
else:
self.compositions.add(composition)

def remove_composition(self, composition):
if composition is self.ALL:
self.compositions = set()
else:
try:
self.compositions.remove(composition)
except (AttributeError, KeyError):
logger.info('Attempted to remove composition from {} but was not active'.format(self))

def is_active_in_composition(self, composition):
if self.compositions is None:
return False
Expand Down
4 changes: 4 additions & 0 deletions psyneulink/core/scheduling/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


# ********************************************* Scheduler **************************************************************
import copy
import typing

import graph_scheduler
Expand Down Expand Up @@ -48,6 +49,9 @@ def __init__(
if default_execution_id is None:
default_execution_id = composition.default_execution_id

# TODO: consider integrating something like this into graph-scheduler?
self._user_specified_conds = copy.copy(conditions)

super().__init__(
graph=graph,
conditions=conditions,
Expand Down
143 changes: 143 additions & 0 deletions tests/composition/test_composition.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import functools
import logging
from timeit import timeit
Expand All @@ -18,6 +19,7 @@
from psyneulink.core.components.mechanisms.modulatory.control.optimizationcontrolmechanism import \
OptimizationControlMechanism
from psyneulink.core.components.mechanisms.modulatory.learning.learningmechanism import LearningMechanism
from psyneulink.core.components.mechanisms.processing.compositioninterfacemechanism import CompositionInterfaceMechanism
from psyneulink.core.components.mechanisms.processing.integratormechanism import IntegratorMechanism
from psyneulink.core.components.mechanisms.processing.objectivemechanism import ObjectiveMechanism
from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism
Expand Down Expand Up @@ -7167,6 +7169,147 @@ def test_danglingControlledMech(self):
comp.add_node(Reward)
# no assert, should only complete without error

@pytest.mark.parametrize(
'removed_nodes, expected_dependencies',
[
(['A'], {'B': set(), 'C': set('B'), 'D': set('C'), 'E': set('C')}),
(['C'], {'A': set(), 'B': set(), 'D': set(), 'E': set()}),
(['E'], {'A': set(), 'B': set(), 'C': {'A', 'B'}, 'D': set('C')}),
(['A', 'B'], {'C': set(), 'D': set('C'), 'E': set('C')}),
(['D', 'E'], {'A': set(), 'B': set(), 'C': {'A', 'B'}}),
(['A', 'B', 'C', 'D', 'E'], {}),
]
)
def test_remove_node(self, removed_nodes, expected_dependencies):
def stringify_dependency_dict(dd):
return {node.name: {n.name for n in deps} for node, deps in dd.items()}

A = pnl.TransferMechanism(name='A')
B = pnl.TransferMechanism(name='B')
C = pnl.TransferMechanism(name='C')
D = pnl.TransferMechanism(name='D')
E = pnl.TransferMechanism(name='E')

locs = locals()
removed_nodes = [locs[n] for n in removed_nodes]

comp = pnl.Composition(
pathways=[
[A, C, D],
[A, C, D],
[B, C, D],
[B, C, E],
]
)

comp.remove_nodes(removed_nodes)

assert stringify_dependency_dict(comp.scheduler.dependency_dict) == expected_dependencies
assert stringify_dependency_dict(comp.graph_processing.dependency_dict) == expected_dependencies

proj_dependencies = collections.defaultdict(set)
for node in comp.nodes:
if node not in removed_nodes:
for proj in node.afferents:
if (
proj.sender.owner not in removed_nodes
and proj.receiver.owner not in removed_nodes
and not isinstance(proj.sender.owner, CompositionInterfaceMechanism)
and not isinstance(proj.receiver.owner, CompositionInterfaceMechanism)
):
proj_dependencies[node.name].add(proj.name)
proj_dependencies[proj.name].add(proj.sender.owner.name)
assert stringify_dependency_dict(comp.graph.dependency_dict) == {**expected_dependencies, **proj_dependencies}

for node in removed_nodes:
assert node not in comp.nodes
assert node not in comp.nodes_to_roles
assert node not in comp.graph.comp_to_vertex
assert node not in comp.graph_processing.comp_to_vertex
assert node not in comp.scheduler.conditions

for proj in node.afferents + node.efferents:
assert proj not in comp.projections
assert not proj.is_active_in_composition(comp)

comp.run(inputs={n: [0] for n in comp.get_nodes_by_role(pnl.NodeRole.INPUT)})

@pytest.mark.parametrize('slope_A', [1, (1, pnl.CONTROL)])
@pytest.mark.parametrize('slope_B', [1, (1, pnl.CONTROL)])
@pytest.mark.parametrize('removed_nodes', [['A'], ['B'], ['A', 'B']])
def test_remove_node_control(self, slope_A, slope_B, removed_nodes):
A = pnl.TransferMechanism(name='A', function=pnl.Linear(slope=slope_A))
B = pnl.TransferMechanism(name='B', function=pnl.Linear(slope=slope_B))

locs = locals()
removed_nodes = [locs[n] for n in removed_nodes]

search_space_len = sum(1 if isinstance(s, tuple) else 0 for s in [slope_A, slope_B])

comp = pnl.Composition(pathways=[A, B])
comp.add_controller(
pnl.OptimizationControlMechanism(
agent_rep=comp, search_space=[0, 1] * search_space_len
)
)
comp.remove_nodes(removed_nodes)

for n in removed_nodes:
for proj in n.parameter_ports['slope'].all_afferents:
assert not proj.is_active_in_composition(comp), f'{n.name} {proj.name}'

def test_remove_node_from_conditions(self):
def assert_conditions_do_not_contain(*args):
conds_queue = list(comp.scheduler.conditions.conditions.values())
while len(conds_queue) > 0:
cur_cond = conds_queue.pop()
deps = []

try:
deps = [cur_cond.dependency]
except AttributeError:
try:
deps = cur_cond.dependencies
except AttributeError:
pass

for d in deps:
if isinstance(d, pnl.Condition):
conds_queue.append(d)
assert d not in args

A = pnl.TransferMechanism(name='A')
B = pnl.TransferMechanism(name='B')
C = pnl.TransferMechanism(name='C')
D = pnl.TransferMechanism(name='D')

comp = pnl.Composition(pathways=[[A, D], [B, D], [C, D]])

comp.run(inputs={A: [0], B: [0], C: [0]})
comp.remove_node(A)
comp.run(inputs={B: [0], C: [0]})
assert_conditions_do_not_contain(A)

comp.remove_node(B)
comp.run(inputs={C: [0]})
assert_conditions_do_not_contain(A, B)

def test_remove_node_learning(self):
A = ProcessingMechanism(name='A')
B = ProcessingMechanism(name='B')
C = ProcessingMechanism(name='C')
D = ProcessingMechanism(name='D')

comp = Composition()
comp.add_linear_learning_pathway(pathway=[A, B], learning_function=BackPropagation)
comp.add_linear_learning_pathway(pathway=[C, D], learning_function=Reinforcement)

comp.remove_node(A)
comp.learn(inputs={n: [0] for n in comp.get_nodes_by_role(pnl.NodeRole.INPUT)})

comp.remove_node(D)
comp.learn(inputs={n: [0] for n in comp.get_nodes_by_role(pnl.NodeRole.INPUT)})


class TestInputSpecsDocumentationExamples:

Expand Down

0 comments on commit 3d76263

Please sign in to comment.