From d3e9774d73894a8c6d82bd67534f7e57141c9b29 Mon Sep 17 00:00:00 2001 From: Joyesh Mishra Date: Mon, 9 Jan 2023 11:34:48 -0800 Subject: [PATCH] Add var port polling in host phase to flush out any pending reads (#563) * Add var port polling in host phase to flush out any pending reads * Added unit test --- src/lava/magma/core/model/py/model.py | 3 +- .../lava/magma/runtime/test_ref_var_ports.py | 67 ++++++++++++++++++- 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/src/lava/magma/core/model/py/model.py b/src/lava/magma/core/model/py/model.py index d1f025d0e..31ada187f 100644 --- a/src/lava/magma/core/model/py/model.py +++ b/src/lava/magma/core/model/py/model.py @@ -435,7 +435,8 @@ def add_ports_for_polling(self): Add various ports to poll for communication on ports """ if enum_equal(self.phase, PyLoihiProcessModel.Phase.PRE_MGMT) or \ - enum_equal(self.phase, PyLoihiProcessModel.Phase.POST_MGMT): + enum_equal(self.phase, PyLoihiProcessModel.Phase.POST_MGMT) \ + or enum_equal(self.phase, PyLoihiProcessModel.Phase.HOST): for var_port in self.var_ports: for csp_port in var_port.csp_ports: if isinstance(csp_port, CspRecvPort): diff --git a/tests/lava/magma/runtime/test_ref_var_ports.py b/tests/lava/magma/runtime/test_ref_var_ports.py index 4291c7fdb..13db2689f 100644 --- a/tests/lava/magma/runtime/test_ref_var_ports.py +++ b/tests/lava/magma/runtime/test_ref_var_ports.py @@ -4,9 +4,9 @@ import numpy as np import unittest -from lava.magma.core.decorator import implements, requires +from lava.magma.core.decorator import implements, requires, tag from lava.magma.core.model.py.model import PyLoihiProcessModel -from lava.magma.core.model.py.ports import PyRefPort, PyVarPort +from lava.magma.core.model.py.ports import PyRefPort, PyVarPort, PyInPort from lava.magma.core.model.py.type import LavaPyType from lava.magma.core.model.sub.model import AbstractSubProcessModel from lava.magma.core.process.ports.ports import RefPort, VarPort @@ -319,5 +319,68 @@ def test_hierarchical_ref_ports(self): recv.stop() +class TestPortsInProcess(unittest.TestCase): + """Tests PyPorts in Processes.""" + + def test_refport_write_to_varport(self) -> None: + """Tests writing from a RefPort to a VarPort.""" + num_steps = 1 + shape = (4, 3, 2) + np.random.seed(7739) + input_data = np.random.randint(256, size=shape) + + source = RefPortWriteProcess(data=input_data) + sink = VarPortProcess(data=np.zeros(shape)) + + source.ref_port.connect(sink.var_port) + + try: + sink.run(condition=RunSteps(num_steps=num_steps), + run_cfg=Loihi1SimCfg(select_tag='floating_pt')) + output = sink.data.get() + finally: + sink.stop() + + np.testing.assert_array_equal(output, input_data) + + +class RefPortWriteProcess(AbstractProcess): + def __init__(self, data: np.ndarray) -> None: + super().__init__() + self.data = Var(shape=data.shape, init=data) + self.ref_port = RefPort(shape=data.shape) + + +class VarPortProcess(AbstractProcess): + def __init__(self, data: np.ndarray) -> None: + super().__init__() + self.data = Var(shape=data.shape, init=data) + self.var_port = VarPort(self.data) + + +@implements(proc=RefPortWriteProcess, protocol=LoihiProtocol) +@requires(CPU) +@tag('floating_pt') +class PyRefPortWriteProcessModelFloat(PyLoihiProcessModel): + ref_port: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, np.int32) + data: np.ndarray = LavaPyType(np.ndarray, np.int32) + + def post_guard(self): + return True + + def run_post_mgmt(self): + self.ref_port.write(self.data) + self.log.info("Sent output data of RefPortWriteProcess: ", + str(self.data)) + + +@implements(proc=VarPortProcess, protocol=LoihiProtocol) +@requires(CPU) +@tag('floating_pt') +class PyVarPortProcessModelFloat(PyLoihiProcessModel): + var_port: PyInPort = LavaPyType(PyVarPort.VEC_DENSE, np.int32) + data: np.ndarray = LavaPyType(np.ndarray, np.int32) + + if __name__ == '__main__': unittest.main()