Commit
adding some docstrings, fixing unused imports
epaxon committed Oct 24, 2023
1 parent 691e3fb commit 5998dae
Showing 4 changed files with 20 additions and 28 deletions.
6 changes: 5 additions & 1 deletion src/lava/proc/graded/models.py
@@ -3,7 +3,6 @@
 # See: https://spdx.org/licenses/

 import numpy as np
-import typing as ty

 from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
 from lava.magma.core.model.py.ports import PyInPort, PyOutPort
@@ -16,6 +15,8 @@


 class AbstractGradedVecModel(PyLoihiProcessModel):
+    """Implementation of GradedVec"""
+
     a_in = None
     s_out = None

@@ -83,6 +84,7 @@ def run_spk(self) -> None:
 @requires(CPU)
 @tag('float')
 class InvSqrtModelFloat(PyLoihiProcessModel):
+    """Implementation of InvSqrt in floating point"""
     a_in = LavaPyType(PyInPort.VEC_DENSE, float)
     s_out = LavaPyType(PyOutPort.VEC_DENSE, float)

@@ -133,6 +135,8 @@ def inv_sqrt(s_fp, n_iters=5, b_fraction=12):
 @requires(CPU)
 @tag('fixed_pt')
 class InvSqrtModelFP(PyLoihiProcessModel):
+    """Implementation of InvSqrt in fixed point"""
+
     a_in = LavaPyType(PyInPort.VEC_DENSE, np.int32, precision=24)
     s_out = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24)
     fp_base: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
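Note: the `inv_sqrt(s_fp, n_iters=5, b_fraction=12)` signature in the hunk header suggests a Newton-Raphson refinement of a fixed-point inverse square root with `b_fraction` fractional bits. A minimal sketch of that idea follows; the seed choice and the float round-trip here are assumptions, not the (truncated) library implementation:

import numpy as np

def inv_sqrt_sketch(s_fp, n_iters=5, b_fraction=12):
    """Approximate 1/sqrt of a fixed-point input with b_fraction
    fractional bits (a sketch, not the lava implementation)."""
    x = np.asarray(s_fp, dtype=np.float64) / (1 << b_fraction)
    # Crude seed: halving the exponent lands within ~sqrt(2) of the
    # true inverse square root, so the Newton iteration converges.
    y = 2.0 ** (-np.floor(np.log2(x)) / 2.0)
    for _ in range(n_iters):
        # Newton step for f(y) = 1/y**2 - x:  y <- y * (3 - x*y^2) / 2
        y = y * (3.0 - x * y * y) / 2.0
    # Re-quantize to the same fixed-point format as the input.
    return np.round(y * (1 << b_fraction)).astype(np.int32)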
10 changes: 0 additions & 10 deletions src/lava/proc/graded/process.py
@@ -2,23 +2,13 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # See: https://spdx.org/licenses/

-import os
 import numpy as np
-import typing as ty
-from typing import Any, Dict
-from enum import IntEnum, unique

 from lava.magma.core.process.process import AbstractProcess
 from lava.magma.core.process.variable import Var
 from lava.magma.core.process.ports.ports import InPort, OutPort

-from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
-from lava.magma.core.model.py.ports import PyInPort, PyOutPort
-from lava.magma.core.model.py.type import LavaPyType
-from lava.magma.core.resources import CPU
-from lava.magma.core.decorator import implements, requires, tag
-from lava.magma.core.model.py.model import PyLoihiProcessModel


 def loihi2round(vv):
     """
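Note: the surviving `loihi2round(vv)` helper is truncated in this hunk. A plausible sketch, assuming it rounds float values to integers the way Loihi 2 fixed-point arithmetic would (ties away from zero; the exact rule is not visible in this diff):

import numpy as np

def loihi2round_sketch(vv):
    # Assumed semantics: round to the nearest integer, with ties going
    # away from zero, then cast to a hardware-friendly integer type.
    return np.fix(vv + np.sign(vv) * 0.5).astype(np.int32)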
27 changes: 14 additions & 13 deletions src/lava/proc/prodneuron/process.py
@@ -10,25 +10,26 @@


 class ProdNeuron(AbstractProcess):
+    """ProdNeuron
+    Multiplies two graded inputs.
+    Parameters
+    ----------
+    shape : tuple(int)
+        Number and topology of ProdNeuron neurons.
+    vth : int
+        Threshold
+    exp : int
+        Fixed-point base
+    """
     def __init__(
             self,
             shape: ty.Tuple[int, ...],
             vth=1,
             exp=0) -> None:
-        """ProdNeuron
-        Multiplies two graded inputs.
-        Parameters
-        ----------
-
-        shape : tuple(int)
-            Number and topology of ProdNeuron neurons.
-        vth : int
-            Threshold
-        exp : int
-            Fixed-point base
-        """
         super().__init__(shape=shape)

         self.a_in1 = InPort(shape=shape)
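Note: given the new class docstring and the `a_in1` port visible above, a hypothetical instantiation would look like the following (the second input `a_in2` and the output `s_out` are assumptions, since the diff is truncated here):

from lava.proc.prodneuron.process import ProdNeuron

# 10 product neurons, unit threshold, no extra fixed-point scaling.
prod = ProdNeuron(shape=(10,), vth=1, exp=0)
# Assumed port layout: prod.a_in1 and prod.a_in2 carry the two graded
# inputs; prod.s_out carries their elementwise product.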
5 changes: 1 addition & 4 deletions tests/lava/proc/graded/test_graded.py
@@ -4,14 +4,13 @@

 import unittest
 import numpy as np
-from scipy.sparse import csr_matrix, find
+from scipy.sparse import csr_matrix

 from lava.proc.graded.process import GradedVec, NormVecDelay, InvSqrt
 from lava.proc.graded.models import inv_sqrt
 from lava.proc.dense.process import Dense
 from lava.proc.sparse.process import Sparse
 from lava.proc import io
-from lava.proc import embedded_io as eio

 from lava.magma.core.run_conditions import RunSteps
 from lava.magma.core.run_configs import Loihi2SimCfg
@@ -145,7 +144,6 @@ def test_invsqrt_calc(self):
 class TestNormVecDelayProc(unittest.TestCase):

     def test_norm_vec_delay_out1(self):
-        fp_base = 12  # base of the decimal point
         weight_exp = 7
         num_steps = 10

@@ -208,7 +206,6 @@ def test_norm_vec_delay_out1(self):
         self.assertTrue(np.all(expected_out[:, :-1] == out_data[:, 1:]))

     def test_norm_vec_delay_out2(self):
-        fp_base = 12  # base of the decimal point
         weight_exp = 7
         num_steps = 10

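Note: the imports above indicate that `test_invsqrt_calc` exercises the `inv_sqrt` helper directly against a floating-point reference. A hypothetical check in that spirit (the input value, fixed-point encoding, and tolerance are all assumptions):

import numpy as np
from lava.proc.graded.models import inv_sqrt

b_fraction = 12
x = 2.5
x_fp = int(x * (1 << b_fraction))  # encode in fixed point
y_fp = inv_sqrt(x_fp, n_iters=5, b_fraction=b_fraction)
# Decode and compare with the float reference within ~1%.
assert abs(y_fp / (1 << b_fraction) - 1 / np.sqrt(x)) < 0.01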
