Feat/learning nested #2801

Merged: 88 commits, Sep 18, 2023
Changes shown are from 85 of the 88 commits.

Commits (88)
393c31a  [skip ci]  jdcpni, Sep 1, 2023
00ab289  [skip ci]  jdcpni, Sep 1, 2023
253114a  [skip ci]  jdcpni, Sep 1, 2023
410bdb3  [skip ci]  Sep 1, 2023
48ecc4e  [skip ci]  Sep 1, 2023
f86aa13  [skip ci]  Sep 2, 2023
d255db8  [skip ci]  Sep 2, 2023
63deff6  [skip ci]  Sep 2, 2023
e68d3cb  [skip ci]  Sep 2, 2023
97289fc  [skip ci]  jdcpni, Sep 2, 2023
a42de9e  [skip ci]  jdcpni, Sep 2, 2023
52d0b8a  [skip ci]  jdcpni, Sep 2, 2023
157e125  [skip ci]  Sep 2, 2023
e41cb3c  [skip ci]  Sep 2, 2023
5340782  [skip ci]  jdcpni, Sep 2, 2023
dab69e4  [skip ci]  jdcpni, Sep 2, 2023
5d91b2c  [skip ci]  Sep 2, 2023
9aa1727  [skip ci]  jdcpni, Sep 3, 2023
21d3a55  [skip ci]  jdcpni, Sep 3, 2023
7825bff  [skip ci]  Sep 3, 2023
45ef790  [skip ci]  Sep 3, 2023
80b7537  [skip ci]  Sep 3, 2023
ea8484d  [skip ci]  Sep 3, 2023
523d80c  [skip ci]  Sep 4, 2023
a21f1e7  [skip ci]  jdcpni, Sep 4, 2023
4e0ef0d  [skip ci]  jdcpni, Sep 4, 2023
35d3229  [skip ci]  jdcpni, Sep 4, 2023
1c6ec16  [skip ci]  jdcpni, Sep 4, 2023
317d55b  [skip ci]  jdcpni, Sep 4, 2023
b1e21b5  [skip ci]  jdcpni, Sep 4, 2023
30f5425  [skip ci]  Sep 4, 2023
93ce161  [skip ci]  Sep 4, 2023
b6addbe  [skip ci]  Sep 4, 2023
dae137f  [skip ci]  jdcpni, Sep 5, 2023
43d9bbd  [skip ci]  jdcpni, Sep 5, 2023
6477acd  [skip ci]  jdcpni, Sep 5, 2023
c8459be  Merge branch 'feat/learning_nested' of https://github.com/PrincetonUn…  jdcpni, Sep 5, 2023
0ea667f  [skip ci]  jdcpni, Sep 5, 2023
89bb55c  [skip ci]  jdcpni, Sep 5, 2023
d7c0b42  [skip ci]  jdcpni, Sep 6, 2023
289eaf8  [skip ci]  jdcpni, Sep 6, 2023
b3d9c80  [skip ci]  jdcpni, Sep 7, 2023
101be90  [skip ci]  Sep 7, 2023
efe06c3  [skip ci]  Sep 7, 2023
02d8b34  [skip ci]  Sep 7, 2023
8417f82  [skip ci]  Sep 7, 2023
a3853ca  [skip ci]  Sep 8, 2023
e34fc47  [skip ci]  Sep 8, 2023
8b4fb68  [skip ci]  jdcpni, Sep 8, 2023
3ed4f5d  [skip ci]  jdcpni, Sep 8, 2023
32dcaac  [skip ci]  jdcpni, Sep 8, 2023
e39466d  [skip ci]  Sep 8, 2023
b91d12d  Merge branch 'feat/learning_nested' of https://github.com/PrincetonUn…  Sep 8, 2023
896c662  [skip ci]  Sep 8, 2023
d31d232  [skip ci]  Sep 8, 2023
277c43c  [skip ci]  Sep 9, 2023
98351c2  [skip ci]  Sep 10, 2023
58e3e16  [skip ci]  Sep 10, 2023
ade87d5  [skip ci]  jdcpni, Sep 12, 2023
3008de8  [skip ci]  jdcpni, Sep 12, 2023
0051641  [skip ci]  jdcpni, Sep 12, 2023
f354e7b  [skip ci]  jdcpni, Sep 12, 2023
52fb20c  [skip ci]  jdcpni, Sep 13, 2023
bcd3e48  Merge branch 'devel' of https://github.com/PrincetonUniversity/PsyNeu…  jdcpni, Sep 13, 2023
0e876f3  [skip ci]  jdcpni, Sep 13, 2023
d9864ac  [skip ci]  jdcpni, Sep 13, 2023
64de176  [skip ci]  jdcpni, Sep 16, 2023
b5f0b7d  Merge remote-tracking branch 'upstream/feat/learning_nested' into fea…  jdcpni, Sep 16, 2023
8d74edd  [skip ci]  jdcpni, Sep 16, 2023
015f86e  [skip ci]  jdcpni, Sep 16, 2023
5bceec1  [skip ci]  jdcpni, Sep 16, 2023
b59bc02  [skip ci]  jdcpni, Sep 16, 2023
73d6dac  [skip ci]  jdcpni, Sep 16, 2023
b66ec11  [skip ci]  Sep 17, 2023
67c4064  Merge branch 'feat/learning_nested' of https://github.com/PrincetonUn…  Sep 17, 2023
749ff5d  [skip ci]  Sep 17, 2023
ee95470  [skip ci]  Sep 17, 2023
e7fba45  [skip ci]  jdcpni, Sep 17, 2023
1ac34e5  [skip ci]  jdcpni, Sep 17, 2023
a10ce61  [skip ci]  jdcpni, Sep 17, 2023
e6c7952  [skip ci]  jdcpni, Sep 18, 2023
396bf0b  [skip ci]  jdcpni, Sep 18, 2023
4a07829  [skip ci]  jdcpni, Sep 18, 2023
dfa182d  [skip ci]  jdcpni, Sep 18, 2023
b6e5faf  -  jdcpni, Sep 18, 2023
5a7f5ba  • autodiffcomposition.py  jdcpni, Sep 18, 2023
93f7e6e  [skip ci]  jdcpni, Sep 18, 2023
931dfac  Merge branch 'feat/learning_nested' of https://github.com/PrincetonUn…  jdcpni, Sep 18, 2023
4 changes: 4 additions & 0 deletions Scripts/Models (Under Development)/EGO/EGO Model - MDP.py
@@ -9,6 +9,9 @@
 
 # FIX: TERMINATION CONDITION IS GETTING TRIGGERED AFTER 1st TRIAL
 
+# FOR INPUT NODES: scheduler.add_condition(A, BeforeNCalls(A, 1))
+# Termination: AfterNCalls(Ctl, 2)
+
 """
 QUESTIONS:
 
@@ -105,6 +108,7 @@
 Use of SweetPea for stimulus generation requires it be installed::
     >> pip install sweetpea
 
+
 .. _EGO_training:
 
 *Training*
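The two comments added above name scheduler Conditions without showing how they are wired up. Below is a minimal runnable sketch of what they describe, not part of the PR; the nodes A and Ctl are generic stand-ins for the model's input and control nodes.

    import psyneulink as pnl

    A = pnl.ProcessingMechanism(name='A')      # stand-in for an input node
    Ctl = pnl.ProcessingMechanism(name='Ctl')  # stand-in for the node that gates termination
    comp = pnl.Composition(pathways=[A, Ctl], name='scheduling_sketch')

    # Execute the input node only until it has run once in the trial ...
    comp.scheduler.add_condition(A, pnl.BeforeNCalls(A, 1))

    # ... and terminate each trial once Ctl has executed twice.
    comp.run(inputs={A: [[0.0]]},
             termination_processing={pnl.TimeScale.TRIAL: pnl.AfterNCalls(Ctl, 2)})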
@@ -1256,12 +1256,16 @@ def _validate_variable(self, variable, context=None):
                 old_length = 1
             else:
                 old_length = len(variable[i - 1])
-            if isinstance(variable[i], numbers.Number):
+            if variable[i] is None:
+                owner_str = f"'{self.owner.name}' " if self.owner else ''
+                raise FunctionError(f"One of the elements of variable for {self.__class__.__name__} function "
+                                    f"of {owner_str}is None; variable: {variable}.")
+            elif isinstance(variable[i], numbers.Number):
                 new_length = 1
             else:
                 new_length = len(variable[i])
             if old_length != new_length:
-                owner_str = f"'{self.owner.name }'" if self.owner else ''
+                owner_str = f"'{self.owner.name }' " if self.owner else ''
                 raise FunctionError(f"Length of all arrays in variable for {self.__class__.__name__} function "
                                     f"of {owner_str}must be the same; variable: {variable}.")
         return variable
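The hunk above makes the combination-function check fail fast with an informative FunctionError when an element of variable is None, instead of crashing on len(None). A standalone sketch of the same logic, simplified and outside PsyNeuLink (the function and owner names are placeholders):

    import numbers

    def validate_equal_lengths(variable, func_name="SomeFunction", owner_name=None):
        """Simplified stand-in for the _validate_variable check in the diff above:
        reject None elements explicitly, then require all elements to have equal length."""
        owner_str = f"'{owner_name}' " if owner_name else ''
        lengths = []
        for element in variable:
            if element is None:
                raise ValueError(f"One of the elements of variable for {func_name} function "
                                 f"of {owner_str}is None; variable: {variable}.")
            # scalars count as length 1, just as in the diff
            lengths.append(1 if isinstance(element, numbers.Number) else len(element))
        if len(set(lengths)) > 1:
            raise ValueError(f"Length of all arrays in variable for {func_name} function "
                             f"of {owner_str}must be the same; variable: {variable}.")
        return variable

    validate_equal_lengths([[1, 2], [3, 4]])    # OK
    # validate_equal_lengths([[1, 2], None])    # raises the new, explicit error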
@@ -2480,8 +2480,11 @@ def _function(self,
         self.parameters.error_matrix._set(error_matrix, context)
         # self._check_args(variable=variable, context=context, params=params, context=context)
 
-        # If learning_rate was not specified for instance or composition, use default value
-        learning_rate = self._get_current_parameter_value(LEARNING_RATE, context)
+        # If learning_rate was not specified for instance or composition or in params, use default value
+        if params and LEARNING_RATE in params and params[LEARNING_RATE] is not None:
+            learning_rate = params[LEARNING_RATE]
+        else:
+            learning_rate = self._get_current_parameter_value(LEARNING_RATE, context)
         if learning_rate is None:
             learning_rate = self.defaults.learning_rate
 
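The hunk above establishes a precedence chain: a runtime params dict wins, then the instance/composition value, then the class default. The same chain in isolation (illustrative only; the function and argument names are stand-ins for the attributes used in the diff):

    LEARNING_RATE = 'learning_rate'  # key used in the diff above

    def resolve_learning_rate(params, current_value, default_value):
        """Runtime params beat the instance/composition value, which beats the default."""
        if params and LEARNING_RATE in params and params[LEARNING_RATE] is not None:
            return params[LEARNING_RATE]
        if current_value is not None:
            return current_value
        return default_value

    assert resolve_learning_rate({'learning_rate': 0.5}, 0.1, 0.05) == 0.5   # runtime override
    assert resolve_learning_rate({}, 0.1, 0.05) == 0.1                       # instance value
    assert resolve_learning_rate(None, None, 0.05) == 0.05                   # class default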
@@ -1154,9 +1154,17 @@ def _check_type_and_timing(self):
                                      repr(LEARNING_TIMING)))
 
     def _parse_function_variable(self, variable, context=None):
+        # # MODIFIED 9/11/23 OLD:
         # Return values of ACTIVATION_INPUT_INDEX, ACTIVATION_OUTPUT_INDEX, and first ERROR_SIGNAL_INDEX InputPorts
         # in variable; remaining inputs (additional error signals and/or COVARIATES) are passed in kwargs
         return variable[range(min(len(self.input_ports),3))]
+        # MODIFIED 9/11/23 NEW:
+        # Return values of ACTIVATION_INPUT, ACTIVATION_OUTPUT, and ERROR_SIGNAL InputPorts in variable;
+        # remaining inputs (additional error signals and/or COVARIATES) are passed in kwargs
+        # return np.array(self.input_values, dtype=object)
+        # FIX: SHOULD EXTRA ERROR_SIGNAL (AND ERROR_MATRIX) BE PUT IN params? CF WHAT HAPPENS WITH RUMELHART NETWORK
+        # return variable[range(2 + len(self.error_signal_input_ports))]
+        # MODIFIED 9/11/23 END
 
     def _validate_variable(self, variable, context=None):
         """Validate that variable has exactly three items: activation_input, activation_output and error_signal
@@ -1395,10 +1403,11 @@ def add_ports(self, error_sources, context=None):
                 error_source = input_port.path_afferents[0].sender.owner
                 self.error_matrices.append(error_source.primary_learned_projection.parameter_ports[MATRIX])
             if ERROR_SIGNAL in input_port.name:
+                # self._error_signal_input_ports.append(input_port)
                 self.error_signal_input_ports.append(input_port)
             instantiated_input_ports.append(input_port)
 
         assert True
 
         # TODO: enable this. fails because LearningMechanism does not have a
         # consistent _parse_function_variable
         # self._update_default_variable(np.asarray(self.input_values, dtype=int), context)
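The retained line, return variable[range(min(len(self.input_ports),3))], relies on numpy accepting a range object as an index, which selects whole rows. A small demonstration (the values are made up; the port labels follow the comments in the diff):

    import numpy as np

    # variable holds one row per InputPort: activation_input, activation_output,
    # error_signal, plus any additional error signals.
    variable = np.array([[0.1, 0.2],    # ACTIVATION_INPUT
                         [0.3, 0.4],    # ACTIVATION_OUTPUT
                         [0.5, 0.6],    # ERROR_SIGNAL (first)
                         [0.7, 0.8]])   # an additional ERROR_SIGNAL

    n_input_ports = len(variable)
    # range objects are valid numpy indices: this keeps only the first three rows,
    # which is exactly what the retained line in the diff does.
    first_three = variable[range(min(n_input_ports, 3))]
    print(first_three.shape)  # (3, 2)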
@@ -274,12 +274,13 @@ def _get_destination_info_from_input_CIM(self, port, comp=None):
         comp = comp or self.composition
         port_map = port.owner.port_map
         idx = 0 if isinstance(port, InputPort) else 1
-        output_port = [port_map[k][1] for k in port_map if port_map[k][idx] is port]
-        assert len(output_port)==1, f"PROGRAM ERROR: Expected exactly 1 output_port for {port.name} " \
-                                    f"in port_map for {port.owner}; found {len(output_port)}."
-        assert len(output_port[0].efferents)==1, f"PROGRAM ERROR: Port ({output_port.name}) expected to have " \
-                                                 f"just one efferent; has {len(output_port.efferents)}."
-        receiver = output_port[0].efferents[0].receiver
+        output_ports = [port_map[k][1] for k in port_map if port_map[k][idx] is port]
+        assert len(output_ports)==1, f"PROGRAM ERROR: Expected exactly 1 output_port for {port.name} " \
+                                     f"in port_map for {port.owner}; found {len(output_ports)}."
+        output_port = output_ports[0]
+        assert len(output_port.efferents)==1, f"PROGRAM ERROR: Port ({output_port.name}) expected to have " \
+                                              f"just one efferent; has {len(output_port.efferents)}."
+        receiver = output_port.efferents[0].receiver
         if not isinstance(receiver.owner, CompositionInterfaceMechanism):
             return receiver, receiver.owner, comp
         return self._get_destination_info_from_input_CIM(receiver, receiver.owner.composition)
@@ -347,16 +348,53 @@ def _get_source_info_from_output_CIM(self, port, comp=None):
         comp = comp or self.composition
         port_map = port.owner.port_map
         idx = 0 if isinstance(port, InputPort) else 1
-        input_port = [port_map[k][0] for k in port_map if port_map[k][idx] is port]
-        assert len(input_port)==1, f"PROGRAM ERROR: Expected exactly 1 input_port for {port.name} " \
-                                   f"in port_map for {port.owner}; found {len(input_port)}."
-        assert len(input_port[0].path_afferents)==1, f"PROGRAM ERROR: Port ({input_port.name}) expected to have " \
-                                                     f"just one path_afferent; has {len(input_port.path_afferents)}."
-        sender = input_port[0].path_afferents[0].sender
+        input_ports = [port_map[k][0] for k in port_map if port_map[k][idx] is port]
+        assert len(input_ports)==1, f"PROGRAM ERROR: Expected exactly 1 input_port for {port.name} " \
+                                    f"in port_map for {port.owner}; found {len(input_ports)}."
+        assert len(input_ports[0].path_afferents)==1, f"PROGRAM ERROR: Port ({input_ports[0].name}) expected to have " \
+                                                      f"just one path_afferent; has {len(input_ports[0].path_afferents)}."
+        sender = input_ports[0].path_afferents[0].sender
         if not isinstance(sender.owner, CompositionInterfaceMechanism):
             return sender, sender.owner, comp
         return self._get_source_info_from_output_CIM(sender, sender.owner.composition)
 
+    def _get_destination_info_for_output_CIM(self, port, comp=None) -> list:
+        """Return Port, Node and Composition for "ultimate" destination(s) of projection to **port**.
+        **port**: InputPort or OutputPort of the output_CIM to which the projection of interest projects;
+            used to find source (key=SENDER PORT) of the projection to the output_CIM.
+        **comp**: Composition at which to begin the search (or continue it when called recursively);
+            defaults to the Composition of the output_CIM to which **port** belongs.
+        If there is more than one destination, return a list of tuples, one for each destination;
+            this occurs if the source of the projection to the output_CIM (SENDER PORT) is a Node in a nested
+            Composition that is specified to project to more than one Node in the outer Composition.
+        """
+        from psyneulink.core.compositions.composition import get_composition_for_node
+
+        # Ensure method is being called on an output_CIM
+        assert self == self.composition.output_CIM
+        # CIM MAP ENTRIES: [SENDER PORT, [output_CIM InputPort, output_CIM OutputPort]]
+        # Get receiver of output_port of output_CIM
+        comp = comp or self.composition
+        port_map = port.owner.port_map
+        idx = 0 if isinstance(port, InputPort) else 1
+        output_ports = [port_map[k][1] for k in port_map if port_map[k][idx] is port]
+        assert len(output_ports)==1, f"PROGRAM ERROR: Expected exactly 1 output_port for {port.name} " \
+                                     f"in port_map for {port.owner}; found {len(output_ports)}."
+        output_port = output_ports[0]
+        receivers_info = []
+        if not output_port.efferents:
+            return None
+        for efferent in output_port.efferents:
+            receiver = efferent.receiver
+            if not isinstance(efferent.receiver.owner, CompositionInterfaceMechanism):
+                assert comp.is_nested
+                receiver_comp = get_composition_for_node(receiver.owner)
+                receivers_info.append((efferent.receiver, efferent.receiver.owner, receiver_comp))
+            else:
+                receivers_info.append(self._get_destination_info_for_output_CIM(efferent.receiver,
+                                                                                efferent.receiver.owner.composition))
+        return receivers_info
+
     def _sender_is_probe(self, output_port):
         """Return True if source of output_port is a PROBE Node of the Composition to which it belongs"""
         from psyneulink.core.compositions.composition import NodeRole
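The new _get_destination_info_for_output_CIM method recurses whenever a projection lands on another CompositionInterfaceMechanism, so destinations are resolved through arbitrarily deep nesting. The shape of that recursion, reduced to a toy example (the classes here are hypothetical stand-ins, not PsyNeuLink types):

    # Follow each outgoing edge; if it lands on another interface node, recurse
    # into that node's targets until a real destination node is reached.

    class Node:
        def __init__(self, name, targets=(), is_interface=False):
            self.name = name
            self.targets = list(targets)   # nodes this node projects to
            self.is_interface = is_interface

    def ultimate_destinations(node):
        """Return the non-interface nodes ultimately reached from node."""
        found = []
        for target in node.targets:
            if target.is_interface:
                found.extend(ultimate_destinations(target))  # keep drilling through CIMs
            else:
                found.append(target)
        return found

    inner_out = Node('inner.output_CIM', is_interface=True)
    outer_out = Node('outer.output_CIM', targets=[Node('X'), Node('Y')], is_interface=True)
    inner_out.targets = [outer_out]
    print([n.name for n in ultimate_destinations(inner_out)])  # ['X', 'Y']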