Enable get/set on learning rule parameters #622

Merged
17 commits merged on Feb 17, 2023
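Judging by the PR title, the user-facing effect is that learning-rule parameters exposed as Vars on a plastic connection become readable and writable at runtime like any other Var. A minimal usage sketch of that idea follows; every class, argument, and Var name in it (LearningDense, STDPLoihi, x1_tau) is an assumption for illustration, not something taken from this diff.

# Hypothetical usage sketch -- class, argument, and Var names are assumptions.
import numpy as np
from lava.proc.dense.process import LearningDense                   # assumed path
from lava.proc.learning_rules.stdp_learning_rule import STDPLoihi   # assumed path

# A plastic connection driven by an STDP-like learning rule.
stdp = STDPLoihi(learning_rate=1.0,
                 A_plus=1.0,
                 A_minus=-1.0,
                 tau_plus=10.0,
                 tau_minus=10.0,
                 t_epoch=4)
conn = LearningDense(weights=np.eye(2), learning_rule=stdp)

# Once the network is compiled and a runtime is attached, learning-rule
# parameters exposed as Vars should support the usual round trip:
#     tau = conn.x1_tau.get()      # read a learning-rule parameter
#     conn.x1_tau.set(tau * 2)     # write it back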
149 changes: 95 additions & 54 deletions src/lava/magma/compiler/builders/py_builder.py
@@ -9,13 +9,21 @@

from lava.magma.compiler.channels.interfaces import AbstractCspPort
from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort
from lava.magma.compiler.utils import (PortInitializer, VarInitializer,
VarPortInitializer)
from lava.magma.compiler.utils import (
PortInitializer,
VarInitializer,
VarPortInitializer,
)
from lava.magma.core.model.py.model import AbstractPyProcessModel
from lava.magma.core.model.py.ports import (AbstractPyIOPort,
IdentityTransformer, PyInPort,
PyOutPort, PyRefPort, PyVarPort,
VirtualPortTransformer)
from lava.magma.core.model.py.ports import (
AbstractPyIOPort,
IdentityTransformer,
PyInPort,
PyOutPort,
PyRefPort,
PyVarPort,
VirtualPortTransformer,
)
from lava.magma.core.model.py.type import LavaPyType


@@ -44,14 +52,12 @@ class variables of a PyProcessModel, creates the corresponding data type
"""

def __init__(
self,
proc_model: ty.Type[AbstractPyProcessModel],
model_id: int,
proc_params: ty.Dict[str, ty.Any] = None):
super().__init__(
proc_model=proc_model,
model_id=model_id
)
self,
proc_model: ty.Type[AbstractPyProcessModel],
model_id: int,
proc_params: ty.Dict[str, ty.Any] = None,
):
super().__init__(proc_model=proc_model, model_id=model_id)
if not issubclass(proc_model, AbstractPyProcessModel):
raise AssertionError("Is not a subclass of AbstractPyProcessModel")
self.vars: ty.Dict[str, VarInitializer] = {}
@@ -77,10 +83,10 @@ def check_all_vars_and_ports_set(self):
attr = getattr(self.proc_model, attr_name)
if isinstance(attr, LavaPyType):
if (
attr_name not in self.vars
and attr_name not in self.py_ports
and attr_name not in self.ref_ports
and attr_name not in self.var_ports
attr_name not in self.vars
and attr_name not in self.py_ports
and attr_name not in self.ref_ports
and attr_name not in self.var_ports
):
raise AssertionError(
f"No LavaPyType '{attr_name}' found in ProcModel "
@@ -187,8 +193,12 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
proc_name = self.proc_model.implements_process.__name__
for port_name in new_ports:
if not hasattr(self.proc_model, port_name):
raise AssertionError("PyProcessModel '{}' has \
no port named '{}'.".format(proc_name, port_name))
raise AssertionError(
"PyProcessModel '{}' has \
no port named '{}'.".format(
proc_name, port_name
)
)

if port_name in self.csp_ports:
self.csp_ports[port_name].extend(new_ports[port_name])
@@ -209,9 +219,9 @@ def add_csp_port_mapping(self, py_port_id: str, csp_port: AbstractCspPort):
a CSP port
"""
# Add or update the mapping
self._csp_port_map.setdefault(
csp_port.name, {}
).update({py_port_id: csp_port})
self._csp_port_map.setdefault(csp_port.name, {}).update(
{py_port_id: csp_port}
)

def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
"""Set RS CSP Ports
@@ -274,17 +284,21 @@ def build(self):
csp_ports = [csp_ports]

if issubclass(port_cls, PyInPort):
transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()
transformer = (
VirtualPortTransformer(
self._csp_port_map[name], p.transform_funcs
)
if p.transform_funcs
else IdentityTransformer()
)
port_cls = ty.cast(ty.Type[PyInPort], lt.cls)
port = port_cls(csp_ports, pm, p.shape, lt.d_type, transformer)
elif issubclass(port_cls, PyOutPort):
port = port_cls(csp_ports, pm, p.shape, lt.d_type)
else:
raise AssertionError("port_cls must be of type PyInPort or "
"PyOutPort")
raise AssertionError(
"port_cls must be of type PyInPort or " "PyOutPort"
)

# Create dynamic PyPort attribute on ProcModel
setattr(pm, name, port)
@@ -300,18 +314,28 @@ def build(self):
csp_send = None
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
csp_recv = csp_ports[0] if isinstance(
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
csp_recv = (
csp_ports[0]
if isinstance(csp_ports[0], CspRecvPort)
else csp_ports[1]
)
csp_send = (
csp_ports[0]
if isinstance(csp_ports[0], CspSendPort)
else csp_ports[1]
)

transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()
transformer = (
VirtualPortTransformer(
self._csp_port_map[name], p.transform_funcs
)
if p.transform_funcs
else IdentityTransformer()
)

port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type,
transformer)
port = port_cls(
csp_send, csp_recv, pm, p.shape, lt.d_type, transformer
)

# Create dynamic RefPort attribute on ProcModel
setattr(pm, name, port)
@@ -327,19 +351,34 @@ def build(self):
csp_send = None
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
csp_recv = csp_ports[0] if isinstance(
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
csp_recv = (
csp_ports[0]
if isinstance(csp_ports[0], CspRecvPort)
else csp_ports[1]
)
csp_send = (
csp_ports[0]
if isinstance(csp_ports[0], CspSendPort)
else csp_ports[1]
)

transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()
transformer = (
VirtualPortTransformer(
self._csp_port_map[name], p.transform_funcs
)
if p.transform_funcs
else IdentityTransformer()
)

port = port_cls(
p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type,
transformer)
p.var_name,
csp_send,
csp_recv,
pm,
p.shape,
p.d_type,
transformer,
)

# Create dynamic VarPort attribute on ProcModel
setattr(pm, name, port)
@@ -361,13 +400,15 @@ def build(self):
if issubclass(lt.cls, np.ndarray):
var = lt.cls(v.shape, lt.d_type)
var[:] = v.value
elif issubclass(lt.cls, (int, float)):
elif issubclass(lt.cls, (int, float, str)):
var = v.value
else:
raise NotImplementedError("Cannot initiliaze variable "
"datatype, \
only subclasses of int and float are \
supported")
raise NotImplementedError(
"Cannot initiliaze variable "
"datatype, \
only subclasses of int, float and str are \
supported"
)

# Create dynamic variable attribute on ProcModel
setattr(pm, name, var)
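The last hunk above is the functional change in this file: Var initialization now also accepts plain str values (needed for string-valued learning-rule parameters), where previously only ndarray, int, and float subclasses were supported. A simplified standalone sketch of that dispatch; init_var and the example expression are illustrative stand-ins, not the builder's real code.

# Simplified stand-in for the builder's Var initialization branch.
import numpy as np

def init_var(cls, shape, d_type, value):
    if issubclass(cls, np.ndarray):
        var = np.empty(shape, dtype=d_type)   # the builder calls cls(shape, d_type)
        var[:] = value
        return var
    elif issubclass(cls, (int, float, str)):  # str is newly accepted
        return value
    raise NotImplementedError(
        "Only subclasses of ndarray, int, float and str are supported")

# A string-valued Var, e.g. a symbolic learning-rule expression, now passes:
print(init_var(str, (1,), str, "2^-2 * x1 * y0"))
print(init_var(np.ndarray, (2, 2), float, 0.5))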
2 changes: 1 addition & 1 deletion src/lava/magma/core/learning/constants.py
@@ -16,7 +16,7 @@
W_WEIGHTS_S = W_WEIGHTS_U + 1

# Unsigned width of tag 2
W_TAG_2_U = 7
W_TAG_2_U = 8

# Signed width of tag 2
W_TAG_2_S = W_TAG_2_U + 1
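The only change here widens the unsigned tag-2 field from 7 to 8 bits; the signed width still follows as one extra sign bit. A quick check of the representable ranges before and after the change (plain two's-complement arithmetic, not code from the repository):

# Value ranges implied by the width constants (two's complement).
def unsigned_range(bits):
    return 0, 2 ** bits - 1

def signed_range(bits):
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1

W_TAG_2_U = 8                 # was 7 before this change
W_TAG_2_S = W_TAG_2_U + 1     # one extra sign bit

print(unsigned_range(W_TAG_2_U))  # (0, 255)    -- previously (0, 127)
print(signed_range(W_TAG_2_S))    # (-256, 255) -- previously (-128, 127)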