Skip to content

Commit

Permalink
Composition: remove_projection: disable learning pathway if applicable
Browse files Browse the repository at this point in the history
if a projection being learned is removed, disable the entire pathway
because the removal will break it
  • Loading branch information
kmantel committed Mar 29, 2022
1 parent 685af56 commit d329ed7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
23 changes: 22 additions & 1 deletion psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5819,7 +5819,28 @@ def remove_projection(self, projection):
# step 3 - deactivate Projection in this Composition
projection._deactivate_for_compositions(self)

# step 4 - TBI? remove Projection from afferents & efferents lists of any node
# step 4 - deactivate any learning to this Projection
for param_port in projection.parameter_ports:
for proj in param_port.mod_afferents:
self.remove_projection(proj)
if isinstance(proj.sender.owner, LearningMechanism):
for path in self.pathways:
# TODO: make learning_components values consistent type
try:
learning_mechs = path.learning_components['LEARNING_MECHANISMS']
except KeyError:
continue

if isinstance(learning_mechs, LearningMechanism):
learning_mechs = [learning_mechs]

if proj.sender.owner in learning_mechs:
for mech in learning_mechs:
self.remove_node(mech)
self.remove_node(path.learning_components['objective_mechanism'])
self.remove_node(path.learning_components['TARGET_MECHANISM'])

# step 5 - TBI? remove Projection from afferents & efferents lists of any node


def _validate_projection(self,
Expand Down
16 changes: 16 additions & 0 deletions tests/composition/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7293,6 +7293,22 @@ def assert_conditions_do_not_contain(*args):
comp.run(inputs={C: [0]})
assert_conditions_do_not_contain(A, B)

def test_remove_node_learning(self):
A = ProcessingMechanism(name='A')
B = ProcessingMechanism(name='B')
C = ProcessingMechanism(name='C')
D = ProcessingMechanism(name='D')

comp = Composition()
comp.add_linear_learning_pathway(pathway=[A, B], learning_function=BackPropagation)
comp.add_linear_learning_pathway(pathway=[C, D], learning_function=Reinforcement)

comp.remove_node(A)
comp.learn(inputs={n: [0] for n in comp.get_nodes_by_role(pnl.NodeRole.INPUT)})

comp.remove_node(D)
comp.learn(inputs={n: [0] for n in comp.get_nodes_by_role(pnl.NodeRole.INPUT)})


class TestInputSpecsDocumentationExamples:

Expand Down

0 comments on commit d329ed7

Please sign in to comment.