Skip to content

Commit

Permalink
Enable get/set on learning rule parameters (#622)
Browse files Browse the repository at this point in the history
* pre-traces and string for floating

* y params and tests

* tests for fixed pt

* minor cleanup

* lint

* adde lr to initial parameters of LearningDense

* rm unused imports

* avoid using learning lif

* temp: just one stdp test

* enable 2f learning rules

* revert neuron.py

* minor change

* lint

* clean up str representation of learning rule

* cleanup

* lint
  • Loading branch information
weidel-p authored Feb 17, 2023
1 parent 4283428 commit 8cb6787
Show file tree
Hide file tree
Showing 13 changed files with 2,196 additions and 362 deletions.
149 changes: 95 additions & 54 deletions src/lava/magma/compiler/builders/py_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,21 @@

from lava.magma.compiler.channels.interfaces import AbstractCspPort
from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort
from lava.magma.compiler.utils import (PortInitializer, VarInitializer,
VarPortInitializer)
from lava.magma.compiler.utils import (
PortInitializer,
VarInitializer,
VarPortInitializer,
)
from lava.magma.core.model.py.model import AbstractPyProcessModel
from lava.magma.core.model.py.ports import (AbstractPyIOPort,
IdentityTransformer, PyInPort,
PyOutPort, PyRefPort, PyVarPort,
VirtualPortTransformer)
from lava.magma.core.model.py.ports import (
AbstractPyIOPort,
IdentityTransformer,
PyInPort,
PyOutPort,
PyRefPort,
PyVarPort,
VirtualPortTransformer,
)
from lava.magma.core.model.py.type import LavaPyType


Expand Down Expand Up @@ -44,14 +52,12 @@ class variables of a PyProcessModel, creates the corresponding data type
"""

def __init__(
self,
proc_model: ty.Type[AbstractPyProcessModel],
model_id: int,
proc_params: ty.Dict[str, ty.Any] = None):
super().__init__(
proc_model=proc_model,
model_id=model_id
)
self,
proc_model: ty.Type[AbstractPyProcessModel],
model_id: int,
proc_params: ty.Dict[str, ty.Any] = None,
):
super().__init__(proc_model=proc_model, model_id=model_id)
if not issubclass(proc_model, AbstractPyProcessModel):
raise AssertionError("Is not a subclass of AbstractPyProcessModel")
self.vars: ty.Dict[str, VarInitializer] = {}
Expand All @@ -77,10 +83,10 @@ def check_all_vars_and_ports_set(self):
attr = getattr(self.proc_model, attr_name)
if isinstance(attr, LavaPyType):
if (
attr_name not in self.vars
and attr_name not in self.py_ports
and attr_name not in self.ref_ports
and attr_name not in self.var_ports
attr_name not in self.vars
and attr_name not in self.py_ports
and attr_name not in self.ref_ports
and attr_name not in self.var_ports
):
raise AssertionError(
f"No LavaPyType '{attr_name}' found in ProcModel "
Expand Down Expand Up @@ -187,8 +193,12 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
proc_name = self.proc_model.implements_process.__name__
for port_name in new_ports:
if not hasattr(self.proc_model, port_name):
raise AssertionError("PyProcessModel '{}' has \
no port named '{}'.".format(proc_name, port_name))
raise AssertionError(
"PyProcessModel '{}' has \
no port named '{}'.".format(
proc_name, port_name
)
)

if port_name in self.csp_ports:
self.csp_ports[port_name].extend(new_ports[port_name])
Expand All @@ -209,9 +219,9 @@ def add_csp_port_mapping(self, py_port_id: str, csp_port: AbstractCspPort):
a CSP port
"""
# Add or update the mapping
self._csp_port_map.setdefault(
csp_port.name, {}
).update({py_port_id: csp_port})
self._csp_port_map.setdefault(csp_port.name, {}).update(
{py_port_id: csp_port}
)

def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
"""Set RS CSP Ports
Expand Down Expand Up @@ -274,17 +284,21 @@ def build(self):
csp_ports = [csp_ports]

if issubclass(port_cls, PyInPort):
transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()
transformer = (
VirtualPortTransformer(
self._csp_port_map[name], p.transform_funcs
)
if p.transform_funcs
else IdentityTransformer()
)
port_cls = ty.cast(ty.Type[PyInPort], lt.cls)
port = port_cls(csp_ports, pm, p.shape, lt.d_type, transformer)
elif issubclass(port_cls, PyOutPort):
port = port_cls(csp_ports, pm, p.shape, lt.d_type)
else:
raise AssertionError("port_cls must be of type PyInPort or "
"PyOutPort")
raise AssertionError(
"port_cls must be of type PyInPort or " "PyOutPort"
)

# Create dynamic PyPort attribute on ProcModel
setattr(pm, name, port)
Expand All @@ -300,18 +314,28 @@ def build(self):
csp_send = None
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
csp_recv = csp_ports[0] if isinstance(
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
csp_recv = (
csp_ports[0]
if isinstance(csp_ports[0], CspRecvPort)
else csp_ports[1]
)
csp_send = (
csp_ports[0]
if isinstance(csp_ports[0], CspSendPort)
else csp_ports[1]
)

transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()
transformer = (
VirtualPortTransformer(
self._csp_port_map[name], p.transform_funcs
)
if p.transform_funcs
else IdentityTransformer()
)

port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type,
transformer)
port = port_cls(
csp_send, csp_recv, pm, p.shape, lt.d_type, transformer
)

# Create dynamic RefPort attribute on ProcModel
setattr(pm, name, port)
Expand All @@ -327,19 +351,34 @@ def build(self):
csp_send = None
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
csp_recv = csp_ports[0] if isinstance(
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
csp_recv = (
csp_ports[0]
if isinstance(csp_ports[0], CspRecvPort)
else csp_ports[1]
)
csp_send = (
csp_ports[0]
if isinstance(csp_ports[0], CspSendPort)
else csp_ports[1]
)

transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()
transformer = (
VirtualPortTransformer(
self._csp_port_map[name], p.transform_funcs
)
if p.transform_funcs
else IdentityTransformer()
)

port = port_cls(
p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type,
transformer)
p.var_name,
csp_send,
csp_recv,
pm,
p.shape,
p.d_type,
transformer,
)

# Create dynamic VarPort attribute on ProcModel
setattr(pm, name, port)
Expand All @@ -361,13 +400,15 @@ def build(self):
if issubclass(lt.cls, np.ndarray):
var = lt.cls(v.shape, lt.d_type)
var[:] = v.value
elif issubclass(lt.cls, (int, float)):
elif issubclass(lt.cls, (int, float, str)):
var = v.value
else:
raise NotImplementedError("Cannot initiliaze variable "
"datatype, \
only subclasses of int and float are \
supported")
raise NotImplementedError(
"Cannot initiliaze variable "
"datatype, \
only subclasses of int, float and str are \
supported"
)

# Create dynamic variable attribute on ProcModel
setattr(pm, name, var)
Expand Down
2 changes: 1 addition & 1 deletion src/lava/magma/core/learning/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
W_WEIGHTS_S = W_WEIGHTS_U + 1

# Unsigned width of tag 2
W_TAG_2_U = 7
W_TAG_2_U = 8

# Signed width of tag 2
W_TAG_2_S = W_TAG_2_U + 1
Expand Down
Loading

0 comments on commit 8cb6787

Please sign in to comment.