Skip to content

Commit

Permalink
Fix/input port spec w mapping projection (#2724)
Browse files Browse the repository at this point in the history
* • port.py and projection.py:
  - fix bug in which specification using deferred init MappingProjection to specify an InputPort failed

• test_iput_state_spec.py:
  - rename as test_input_port_spec.py
  - add test for above
  • Loading branch information
jdcpni authored Jul 9, 2023
1 parent fdab14f commit 03008f7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 39 deletions.
83 changes: 47 additions & 36 deletions psyneulink/core/components/ports/port.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,9 +1281,7 @@ def _instantiate_projections(self, projections, context=None):
ModulatorySignal: _instantiate_projections_from_port (.efferents)
"""

raise PortError("{} must implement _instantiate_projections (called for {})".
format(self.__class__.__name__,
self.name))
raise PortError("f{self.__class__.__name__} must implement _instantiate_projections (called for {self.name}).")

# FIX: MOVE TO InputPort AND ParameterPort OR...
# IMPLEMENTATION NOTE: MOVE TO COMPOSITION ONCE THAT IS IMPLEMENTED
Expand Down Expand Up @@ -1372,22 +1370,26 @@ def _instantiate_projections_to_port(self, projections, context=None):

# validate receiver
if proj_receiver is not None and proj_receiver != self:
raise PortError("Projection ({}) assigned to {} of {} already has a receiver ({})".
format(projection_type.__name__, self.name, self.owner.name, proj_receiver.name))
raise PortError(f"Projection ({projection_type.__name__}) "
f"assigned to '{self.name}' of '{self.owner.name}' "
f"already has a receiver ('{proj_receiver.owner.name}[{proj_receiver.name}]').")
projection._init_args[RECEIVER] = self


# parse/validate sender
if proj_sender:
# If the Projection already has Port as its sender,
# it must be the same as the one specified in the connection spec
if isinstance(proj_sender, Port) and proj_sender != port:
raise PortError("Projection assigned to {} of {} from {} already has a sender ({})".
format(self.name, self.owner.name, port.name, proj_sender.name))
if isinstance(proj_sender, Port):
if proj_sender == port:
sender = port
else:
raise PortError(
f"Projection assigned to '{self.name}' of '{self.owner.name}' from {port.name} "
f"already has a sender ('{proj_sender.owner.name}[{proj_sender.name}]').")
# If the Projection has a Mechanism specified as its sender:
elif isinstance(port, Port):
# Connection spec (port) is specified as a Port,
# so validate that Port belongs to Mechanism and is of the correct type
# Connection spec (port) is specified as a Port, so validate that
# Port belongs to proj_sender Mechanism and is of the correct type
sender = _get_port_for_socket(owner=self.owner,
mech=proj_sender,
port_spec=port,
Expand Down Expand Up @@ -1420,8 +1422,8 @@ def _instantiate_projections_to_port(self, projections, context=None):
elif inspect.isclass(sender) and issubclass(sender, Port):
sender_name = sender.__name__
else:
raise PortError("SENDER of {} to {} of {} is neither a Port or Port class".
format(projection_type.__name__, self.name, self.owner.name))
raise PortError(f"SENDER of {projection_type.__name__} to {self.name} of {self.owner.name} "
f"is neither a Port or Port class.")
projection._assign_default_projection_name(port=self,
sender_name=sender_name,
receiver_name=self.name)
Expand Down Expand Up @@ -2883,13 +2885,21 @@ def _parse_port_spec(port_type=None,
# If it is a Port specification dictionary
if isinstance(port_spec[PORT_SPEC_ARG], dict):

# If the Port specification is a Projection that has a sender already assigned,
# If the Port specification has a Projection that has a sender already assigned,
# then return that Port with the Projection assigned to it
# (this occurs, for example, if an instantiated ControlSignal is used to specify a parameter
# FIX: JDC 7/8/23 ??WHAT IF PORT SPECIFICATION DICT HAS OTHER SPECS, SUCH AS SIZE?
# POSSIBLY THIS SHOULD ONLY BE CALLED IF DICT CONTAINS *ONLY* A PROJECTION SPEC?
try:
assert len(port_spec[PORT_SPEC_ARG][PROJECTIONS])==1
projection = port_spec[PORT_SPEC_ARG][PROJECTIONS][0]
port = projection.sender
projection = port_spec[PORT_SPEC_ARG][PROJECTIONS]
if isinstance(projection, list):
assert len(port_spec[PORT_SPEC_ARG][PROJECTIONS])==1
projection = port_spec[PORT_SPEC_ARG][PROJECTIONS][0]
port = projection.sender
elif projection.initialization_status == ContextFlags.DEFERRED_INIT:
port = projection._init_args[SENDER]
else:
port = projection.sender
if port.initialization_status == ContextFlags.DEFERRED_INIT:
port._init_args[PROJECTIONS] = projection
else:
Expand Down Expand Up @@ -3149,9 +3159,9 @@ def _parse_port_spec(port_type=None,

if isinstance(port_specification, (list, set)):
port_specific_specs = ProjectionTuple(port=port_specification,
weight=None,
exponent=None,
projection=port_type)
weight=None,
exponent=None,
projection=port_type)

# Port specification is a tuple
elif isinstance(port_specification, tuple):
Expand Down Expand Up @@ -3221,13 +3231,8 @@ def _parse_port_spec(port_type=None,
port = port_attr[port]
except:
name = owner.name if 'unnamed' not in owner.name else 'a ' + owner.__class__.__name__
raise PortError("Unrecognized name ({}) for {} "
"of {} in specification of {} "
"for {}".format(port,
PORTS,
mech.name,
port_type.__name__,
name))
raise PortError("Unrecognized name ({port}) for {PORTS} of {mech.name} "
"in specification of {port_type.__name__} for {name}.")
# If port_spec was a tuple, put port back in as its first item and use as projection spec
if isinstance(port_spec, tuple):
port = (port,) + port_spec[1:]
Expand Down Expand Up @@ -3363,7 +3368,7 @@ def _parse_port_spec(port_type=None,
# port_dict[OWNER].name, spec_function_value, spec_function))

if port_dict[REFERENCE_VALUE] is not None and not iscompatible(port_dict[VALUE], port_dict[REFERENCE_VALUE]):
port_name = f"the {port_dict[NAME]}" if (NAME in port_dict and port_dict[NAME]) else f"an"
port_name = f"the {port_dict[NAME]}" if (NAME in port_dict and port_dict[NAME]) else f"a"
raise PortError(f"The value ({port_dict[VALUE]}) for {port_name} {port_type.__name__} of "
f"{owner.name} does not match the reference_value ({port_dict[REFERENCE_VALUE]}) "
f"used for it at construction.")
Expand Down Expand Up @@ -3442,16 +3447,26 @@ def _get_port_for_socket(owner,
else:
proj_type = proj_spec[PROJECTION_TYPE]

# Get Port type if it is appropriate for the specified socket of the
# Projection's type
# Get Port type if it is appropriate for the specified socket of the Projection's type
s = next((s for s in port_types if
s.__name__ in getattr(proj_type.sockets, projection_socket)),
None)
# If there is a port_type for the projection_socket, try to get the actual Port and return it;
# otherwise return first in the list of allowable Port types for that socket
if s:
try:
# Return Port associated with projection_socket if proj_spec is an actual Projection
port = getattr(proj_spec, projection_socket)
return port
if proj_spec.initialization_status == ContextFlags.DEFERRED_INIT:
port = proj_spec._init_args[projection_socket]
if port is None:
raise AttributeError
elif isinstance(port, Mechanism):
# Mechanism specifiea as sender or receiver of Projection, so get corresponding primary Port
port = port.output_port
return port
else:
port = getattr(proj_spec, projection_socket)
return port
except AttributeError:
# Otherwise, return first port_type (s)
return s
Expand Down Expand Up @@ -3530,10 +3545,6 @@ def _get_port_for_socket(owner,
raise PortError("PROGRAM ERROR: {} attribute(s) not found on {}'s type ({})".
format(mech_port_attribute, mech.name, mech.__class__.__name__))

# # Get
# elif isinstance(port_spec, type) and issubclass(port_spec, Mechanism):


# Get port from Projection specification (exclude matrix spec in test as it can't be used to determine the port)
elif _is_projection_spec(port_spec, include_matrix_spec=False):
_validate_connection_request(owner=owner,
Expand Down
5 changes: 4 additions & 1 deletion psyneulink/core/components/projections/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@
from psyneulink.core.globals.registry import register_category, remove_instance_from_registry
from psyneulink.core.globals.socket import ConnectionInfo
from psyneulink.core.globals.utilities import \
ContentAddressableList, is_matrix, is_numeric, parse_valid_identifier
ContentAddressableList, is_matrix, is_numeric, parse_valid_identifier, convert_to_list

__all__ = [
'Projection_Base', 'projection_keywords', 'PROJECTION_SPEC_KEYWORDS',
Expand Down Expand Up @@ -1776,6 +1776,9 @@ def _parse_connection_specs(connectee_port_type,
mech=mech,
mech_port_attribute=mech_port_attribute,
projection_socket=projection_socket)
assert isinstance(port, Port) or all([p in port_types for p in convert_to_list(port)]), \
f'PROGRAM ERROR: ' \
f'projection._get_port_for_socket() returned {port} which is not a Port or allowed Port Type.'
except PortError as e:
raise ProjectionError(f"Problem with specification for {Port.__name__} in {Projection.__name__} "
f"specification{(' for ' + owner.name) if owner else ' '}: " + e.error_value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,9 @@ def test_projection_tuple_with_matrix_spec(self):

# ------------------------------------------------------------------------------------------------
# TEST 14
# Standalone Projection specification
# Standalone Projection specification with Mechanism as sender

def test_projection_list(self):
def test_projection_list_mech_as_send(self):
R2 = TransferMechanism(size=3)
P = MappingProjection(sender=R2)
T = TransferMechanism(
Expand All @@ -352,6 +352,23 @@ def test_projection_list(self):
assert len(T.input_port.defaults.variable) == 2
T.execute()

# ------------------------------------------------------------------------------------------------
# TEST 14b
# Standalone Projection specification with Port as sender

def test_projection_list_port_as_sender(self):
R2 = TransferMechanism(size=3)
P = MappingProjection(sender=R2.output_port)
T = TransferMechanism(
size=2,
input_ports=[P]
)
np.testing.assert_array_equal(T.defaults.variable, np.array([[0, 0]]))
assert len(T.input_ports) == 1
assert len(T.input_port.path_afferents[0].sender.defaults.variable) == 3
assert len(T.input_port.defaults.variable) == 2
T.execute()

# ------------------------------------------------------------------------------------------------
# TEST 15
# Projection specification in Tuple
Expand Down

0 comments on commit 03008f7

Please sign in to comment.