Skip to content

Commit

Permalink
Sparsity netx pr (#238)
Browse files Browse the repository at this point in the history
* Added sparsity aware netx code

* Fixed misnamed variable

* sparsity_map should default as False rather than 0

* Was not passing sparse_synapse into create_dense. Fixed problem

* Have to sparsify the delay before converting weight to csr_matrix

* If sparse_synapse not in kwargs default it to False

* Renamed the sparsity_map Network argument to sparse_fc_layer. Added simple unit tests to check that sparse_fc_layer creates Sparse types in lava

* Slight modification of docs for hdf5.Network

* Made fix to sparse hdf5.Network tests. Added netx test for sparse axonal delay

---------

Co-authored-by: Michael Jurado <[email protected]>
Co-authored-by: bamsumit <[email protected]>
  • Loading branch information
3 people authored Oct 3, 2023
1 parent 3754e24 commit c5fcf7b
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 9 deletions.
40 changes: 36 additions & 4 deletions src/lava/lib/dl/netx/blocks/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from lava.magma.core.process.ports.ports import InPort, OutPort
from lava.magma.core.process.process import AbstractProcess
from lava.proc.dense.process import Dense as DenseSynapse
from lava.proc.sparse.process import Sparse as SparseSynapse
from lava.proc.sparse.process import DelaySparse as DelaySparseSynapse

from lava.proc.dense.process import DelayDense as DelayDenseSynapse
from lava.proc.conv.process import Conv as ConvSynapse
from scipy.sparse import csr_matrix


class AbstractBlock(AbstractProcess):
Expand Down Expand Up @@ -127,6 +131,8 @@ class Dense(AbstractBlock):
number of weight bits. Defaults to 8.
weight_exponent : int
weight exponent value. Defaults to 0.
sparse_synapse : bool
connection is sparse
input_message_bits : int, optional
number of message bits in input spike. Defaults to 0 meaning unary
spike.
Expand All @@ -139,16 +145,32 @@ def __init__(self, **kwargs: Union[dict, tuple, list, int, bool]) -> None:
delay = kwargs.pop('delay', None)
num_weight_bits = kwargs.pop('num_weight_bits', 8)
weight_exponent = kwargs.pop('weight_exponent', 0)
sparse_synapse = kwargs.pop('sparse_synapse', False)

if delay is None:
self.synapse = DenseSynapse(
if sparse_synapse:
Synapse = SparseSynapse
weight = csr_matrix(weight)
else:
Synapse = DenseSynapse

self.synapse = Synapse(
weights=weight,
weight_exp=weight_exponent,
num_weight_bits=num_weight_bits,
num_message_bits=self.input_message_bits,
)
else:
self.synapse = DelayDenseSynapse(
# TODO test this in greater detail
if sparse_synapse:
Synapse = DelaySparseSynapse
delay[weight == 0] = 0
weight = csr_matrix(weight)
delay = csr_matrix(delay)
else:
Synapse = DelayDenseSynapse

self.synapse = Synapse(
weights=weight,
delays=delay.astype(int),
max_delay=62,
Expand Down Expand Up @@ -199,6 +221,8 @@ class ComplexDense(AbstractBlock):
real weight exponent value. Defaults to 0.
weight_exponent_imag : int
imag weight exponent value. Defaults to 0.
sparse_synapse : bool
connection is sparse
input_message_bits : int, optional
number of message bits in input spike. Defaults to 0 meaning unary
spike.
Expand All @@ -214,15 +238,23 @@ def __init__(self, **kwargs: Union[dict, tuple, list, int, bool]) -> None:
weight_exponent_imag = kwargs.pop('weight_exponent_imag', 0)
weight_real = kwargs.pop('weight_real')
weight_imag = kwargs.pop('weight_imag')
sparse_synapse = kwargs.pop('sparse_synapse', False)

self.neuron = self._neuron(None)
self.real_synapse = DenseSynapse(
if sparse_synapse:
Synapse = SparseSynapse
weight_real = csr_matrix(weight_real)
weight_imag = csr_matrix(weight_imag)
else:
Synapse = DenseSynapse

self.real_synapse = Synapse(
weights=weight_real,
weight_exp=weight_exponent_real,
num_weight_bits=num_weight_bits_real,
num_message_bits=self.input_message_bits,
)
self.imag_synapse = DenseSynapse(
self.imag_synapse = Synapse(
weights=weight_imag,
weight_exp=weight_exponent_imag,
num_weight_bits=num_weight_bits_imag,
Expand Down
19 changes: 14 additions & 5 deletions src/lava/lib/dl/netx/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class Network(AbstractProcess):
neuron's reset parameter. None means no reset. Defaults to None.
reset_offset: int
determines the phase shift of network reset if enabled. Defaults to 0.
sparse_fc_layer : boolean, optional
If True, all fully-connected layer synapses will be interpreted as
Sparse types in Lava.
"""

def __init__(self,
Expand All @@ -62,7 +65,8 @@ def __init__(self,
input_message_bits: Optional[int] = 0,
input_shape: Optional[Tuple[int, ...]] = None,
reset_interval: Optional[int] = None,
reset_offset: int = 0) -> None:
reset_offset: int = 0,
sparse_fc_layer: bool = False) -> None:
super().__init__(net_config=net_config,
num_layers=num_layers,
input_message_bits=input_message_bits)
Expand All @@ -75,6 +79,7 @@ def __init__(self,
self.input_shape = input_shape
self.reset_interval = reset_interval
self.reset_offset = reset_offset
self.sparse_fc_layer = sparse_fc_layer

self.net_str = ''
self.layers = self._create()
Expand Down Expand Up @@ -298,7 +303,8 @@ def create_input(layer_config: h5py.Group,
def create_dense(layer_config: h5py.Group,
input_message_bits: int = 0,
reset_interval: Optional[int] = None,
reset_offset: int = 0) -> Tuple[Dense, str]:
reset_offset: int = 0,
sparse_synapse: bool = 0) -> Tuple[Dense, str]:
"""Creates dense layer from layer configuration
Parameters
Expand Down Expand Up @@ -361,7 +367,8 @@ def create_dense(layer_config: h5py.Group,
'weight_exponent_imag': weight_exponent_imag,
'sign_mode_real': sign_mode_real,
'sign_mode_imag': sign_mode_imag,
'input_message_bits': input_message_bits}
'input_message_bits': input_message_bits,
"sparse_synapse": sparse_synapse}

proc = ComplexDense(**params)

Expand All @@ -380,7 +387,8 @@ def create_dense(layer_config: h5py.Group,
'num_weight_bits': num_weight_bits,
'weight_exponent': weight_exponent,
'sign_mode': sign_mode,
'input_message_bits': input_message_bits}
'input_message_bits': input_message_bits,
"sparse_synapse": sparse_synapse}

if 'delay' in layer_config.keys():
delay = layer_config['delay']
Expand Down Expand Up @@ -591,7 +599,8 @@ def _create(self) -> List[AbstractProcess]:
layer_config=layer_config[i],
input_message_bits=input_message_bits,
reset_interval=reset_interval,
reset_offset=reset_offset)
reset_offset=reset_offset,
sparse_synapse=self.sparse_fc_layer)
if i >= self.skip_layers:
layers.append(layer)
reset_offset += 1
Expand Down
20 changes: 20 additions & 0 deletions tests/lava/lib/dl/netx/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from lava.proc.rf.process import RF
from lava.proc.rf_iz.process import RF_IZ
from lava.proc.sdn.process import Sigma, Delta, SigmaDelta
from lava.proc.sparse.process import Sparse
from lava.proc.conv import utils

from lava.lib.dl.netx.blocks.process import Dense, Conv, Input, ComplexDense,\
Expand Down Expand Up @@ -330,6 +331,25 @@ def test_dense(self) -> None:
f'Error was {s_error}.'
)

def test_sparse(self) -> None:
    """Checks that a ComplexDense block with sparse_synapse=True builds
    its real and imaginary synapses as Lava Sparse processes."""
    neuron_args = {'vth': 25,
                   'period': 11,
                   'state_exp': 6,
                   'decay_bits': 12,
                   'alpha': .05}

    block = ComplexDense(
        shape=(256,),
        neuron_params={'neuron_proc': RF, **neuron_args},
        weight_real=np.load(root + '/gts/complex_dense/weight_r.npy'),
        weight_imag=np.load(root + '/gts/complex_dense/weight_img.npy'),
        sparse_synapse=True
    )

    # Both synapse halves must be Sparse when sparse_synapse is set.
    self.assertTrue(isinstance(block.real_synapse, Sparse))
    self.assertTrue(isinstance(block.imag_synapse, Sparse))


class TestSDNBlocks(unittest.TestCase):
def test_input(self) -> None:
Expand Down
31 changes: 31 additions & 0 deletions tests/lava/lib/dl/netx/test_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from lava.magma.core.run_conditions import RunSteps
from lava.proc import io
from lava.proc.conv import utils
from lava.proc.sparse.process import Sparse, DelaySparse
from lava.proc.dense.process import Dense, DelayDense

from lava.lib.dl import netx

Expand Down Expand Up @@ -174,6 +176,19 @@ def test_pilotnet_sdnn(self) -> None:
f'Error was {error}.'
)

def test_sparse_pilotnet_sdnn(self) -> None:
    """Tests sparse_fc_layer Network arg on Dense blocks"""
    net_config = root + '/gts/pilotnet_sdnn/network.net'
    net = netx.hdf5.Network(net_config=net_config, sparse_fc_layer=True)

    # Every fully-connected block must have been built with a Sparse
    # synapse when sparse_fc_layer is enabled.
    for layer in net.layers:
        if isinstance(layer, netx.blocks.process.Dense):
            self.assertTrue(isinstance(layer.synapse, Sparse))

def test_axonal_delay_ntidigits(self) -> None:
"""Tests the output of ntidigits hdf5 description. This network
consists of axonal delay. So this tests specifically tests for
Expand Down Expand Up @@ -222,6 +237,22 @@ def test_axonal_delay_ntidigits(self) -> None:
f'Error was {error}.'
)

def test_sparse_axonal_delay_ntidigits(self) -> None:
    """Tests that sparse axonal delays work on Dense Blocks."""
    net_config = root + '/gts/ntidigits/ntidigits.net'
    # Skip the trailing average layer, which is not supported, by
    # loading only the first five layers.
    net = netx.hdf5.Network(net_config=net_config, num_layers=5,
                            sparse_fc_layer=True)

    # With delays present, a Dense block's synapse may be either
    # Sparse or DelaySparse; it must never fall back to a dense type.
    for layer in net.layers:
        if isinstance(layer, netx.blocks.process.Dense):
            self.assertTrue(
                isinstance(layer.synapse, (Sparse, DelaySparse)))


if __name__ == '__main__':
unittest.main()

0 comments on commit c5fcf7b

Please sign in to comment.