Skip to content

Commit

Permalink
Merge pull request #430 from yygf123/master
Browse files Browse the repository at this point in the history
Add new tests
  • Loading branch information
chaoming0625 authored Aug 2, 2023
2 parents 90ff3bc + d996f86 commit 7d086d2
Show file tree
Hide file tree
Showing 9 changed files with 961 additions and 32 deletions.
56 changes: 34 additions & 22 deletions brainpy/_src/dnn/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class _GeneralConv(Layer):
The name of the object.
"""

supported_modes = (bm.TrainingMode, bm.BatchingMode)
supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode)

def __init__(
self,
Expand All @@ -101,7 +101,6 @@ def __init__(
name: str = None,
):
super(_GeneralConv, self).__init__(name=name, mode=mode)
check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode), self.name)

self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
Expand Down Expand Up @@ -149,14 +148,18 @@ def __init__(
self.b = bm.TrainVar(self.b)

def _check_input_dim(self, x):
if x.ndim != self.num_spatial_dims + 2:
raise ValueError(f"expected {self.num_spatial_dims + 2}D input (got {x.ndim}D input)")
if x.ndim != self.num_spatial_dims + 2 and x.ndim != self.num_spatial_dims + 1:
raise ValueError(f"expected {self.num_spatial_dims + 2}D or {self.num_spatial_dims + 1}D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")

def update(self, x):
self._check_input_dim(x)
nonbatching = False
if x.ndim == self.num_spatial_dims + 1:
nonbatching = True
x = x.unsqueeze(0)
w = self.w.value
if self.mask is not None:
try:
Expand All @@ -172,7 +175,10 @@ def update(self, x):
rhs_dilation=self.rhs_dilation,
feature_group_count=self.groups,
dimension_numbers=self.dimension_numbers)
return y if self.b is None else (y + self.b.value)
if nonbatching:
return y[0] if self.b is None else (y + self.b.value)[0]
else:
return y if self.b is None else (y + self.b.value)

def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
Expand Down Expand Up @@ -265,8 +271,8 @@ def __init__(
name=name)

def _check_input_dim(self, x):
if x.ndim != 3:
raise ValueError(f"expected 3D input (got {x.ndim}D input)")
if x.ndim != 3 and x.ndim != 2:
raise ValueError(f"expected 3D or 2D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")
Expand Down Expand Up @@ -358,8 +364,8 @@ def __init__(
name=name)

def _check_input_dim(self, x):
if x.ndim != 4:
raise ValueError(f"expected 4D input (got {x.ndim}D input)")
if x.ndim != 4 and x.ndim != 3:
raise ValueError(f"expected 4D or 3D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")
Expand Down Expand Up @@ -451,8 +457,8 @@ def __init__(
name=name)

def _check_input_dim(self, x):
if x.ndim != 5:
raise ValueError(f"expected 5D input (got {x.ndim}D input)")
if x.ndim != 5 and x.ndim != 4:
raise ValueError(f"expected 5D or 4D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")
Expand All @@ -464,7 +470,7 @@ def _check_input_dim(self, x):


class _GeneralConvTranspose(Layer):
supported_modes = (bm.TrainingMode, bm.BatchingMode)
supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode)

def __init__(
self,
Expand All @@ -481,9 +487,9 @@ def __init__(
mode: bm.Mode = None,
name: str = None,
):
super().__init__(name=name, mode=mode)
super(_GeneralConvTranspose, self).__init__(name=name, mode=mode)

assert self.mode.is_parent_of(bm.TrainingMode, bm.BatchingMode)
assert self.mode.is_parent_of(bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode)

self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
Expand Down Expand Up @@ -530,7 +536,10 @@ def _check_input_dim(self, x):

def update(self, x):
self._check_input_dim(x)

nonbatching = False
if x.ndim == self.num_spatial_dims + 1:
nonbatching = True
x = x.unsqueeze(0)
w = self.w.value
if self.mask is not None:
try:
Expand All @@ -545,7 +554,10 @@ def update(self, x):
precision=self.precision,
rhs_dilation=None,
dimension_numbers=self.dimension_numbers)
return y if self.b is None else (y + self.b.value)
if nonbatching:
return y[0] if self.b is None else (y + self.b.value)[0]
else:
return y if self.b is None else (y + self.b.value)

def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
Expand Down Expand Up @@ -608,8 +620,8 @@ def __init__(
)

def _check_input_dim(self, x):
if x.ndim != 3:
raise ValueError(f"expected 3D input (got {x.ndim}D input)")
if x.ndim != 3 and x.ndim != 2:
raise ValueError(f"expected 3D or 2D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")
Expand Down Expand Up @@ -664,8 +676,8 @@ def __init__(
)

def _check_input_dim(self, x):
if x.ndim != 4:
raise ValueError(f"expected 4D input (got {x.ndim}D input)")
if x.ndim != 4 and x.ndim != 3:
raise ValueError(f"expected 4D or 3D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")
Expand Down Expand Up @@ -726,8 +738,8 @@ def __init__(
)

def _check_input_dim(self, x):
if x.ndim != 5:
raise ValueError(f"expected 5D input (got {x.ndim}D input)")
if x.ndim != 5 and x.ndim != 4:
raise ValueError(f"expected 5D or 4D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")
3 changes: 2 additions & 1 deletion brainpy/_src/dnn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class BatchNorm(Layer):
.. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
"""
supported_modes = (bm.BatchingMode, bm.TrainingMode)

def __init__(
self,
Expand All @@ -100,7 +101,7 @@ def __init__(
name: Optional[str] = None,
):
super(BatchNorm, self).__init__(name=name, mode=mode)
check.is_subclass(self.mode, (bm.BatchingMode, bm.TrainingMode), self.name)
# check.is_subclass(self.mode, (bm.BatchingMode, bm.TrainingMode), self.name)

# parameters
self.num_features = num_features
Expand Down
32 changes: 26 additions & 6 deletions brainpy/_src/dnn/rnncells.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class RNNCell(Layer):
Parameters
----------
num_in: int
The dimension of the input vector
num_out: int
The number of hidden unit in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
Expand Down Expand Up @@ -111,7 +113,7 @@ def __init__(
self.state[:] = self.state2train

def reset_state(self, batch_size=None):
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False)
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out,), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
self.state[:] = self.state2train
Expand Down Expand Up @@ -149,6 +151,8 @@ class GRUCell(Layer):
Parameters
----------
num_in: int
The dimension of the input vector
num_out: int
The number of hidden unit in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
Expand Down Expand Up @@ -280,6 +284,8 @@ class LSTMCell(Layer):
Parameters
----------
num_in: int
The dimension of the input vector
num_out: int
The number of hidden unit in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
Expand Down Expand Up @@ -363,15 +369,15 @@ def reset_state(self, batch_size=None):
self.state[:] = self.state2train

def update(self, x):
h, c = jnp.split(self.state.value, 2, axis=-1)
h, c = bm.split(self.state.value, 2, axis=-1)
gated = x @ self.Wi
if self.b is not None:
gated += self.b
gated += h @ self.Wh
i, g, f, o = jnp.split(gated, indices_or_sections=4, axis=-1)
i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1)
c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * self.activation(g)
h = bm.sigmoid(o) * self.activation(c)
self.state.value = jnp.concatenate([h, c], axis=-1)
self.state.value = bm.concatenate([h, c], axis=-1)
return h

@property
Expand Down Expand Up @@ -531,7 +537,8 @@ def __init__(
rhs_dilation=rhs_dilation,
groups=groups,
w_initializer=w_initializer,
b_initializer=b_initializer, )
b_initializer=b_initializer,
mode=mode)
self.hidden_to_hidden = _GeneralConv(num_spatial_dims=num_spatial_dims,
in_channels=out_channels,
out_channels=out_channels * 4,
Expand All @@ -542,7 +549,8 @@ def __init__(
rhs_dilation=rhs_dilation,
groups=groups,
w_initializer=w_initializer,
b_initializer=b_initializer, )
b_initializer=b_initializer,
mode=mode)
self.reset_state()

def reset_state(self, batch_size: int = 1):
Expand Down Expand Up @@ -599,6 +607,10 @@ def __init__(
):
"""Constructs a 1-D convolutional LSTM.
Input: [Batch_Size, Input_Data_Size, Input_Channel_Size]
Output: [Batch_Size, Output_Data_Size, Output_Channel_Size]
Args:
input_shape: Shape of the inputs excluding batch size.
out_channels: Number of output channels.
Expand Down Expand Up @@ -656,6 +668,10 @@ def __init__(
):
"""Constructs a 2-D convolutional LSTM.
Input: [Batch_Size, Input_Data_Size_Dim1,Input_Data_Size_Dim2, Input_Channel_Size]
Output: [Batch_Size, Output_Data_Size_Dim1,Output_Data_Size_Dim2 , Output_Channel_Size]
Args:
input_shape: Shape of the inputs excluding batch size.
out_channels: Number of output channels.
Expand Down Expand Up @@ -713,6 +729,10 @@ def __init__(
):
"""Constructs a 3-D convolutional LSTM.
Input: [Batch_Size, Input_Data_Size_Dim1,Input_Data_Size_Dim2,Input_Data_Size_Dim3 ,Input_Channel_Size]
Output: [Batch_Size, Output_Data_Size_Dim1,Output_Data_Size_Dim2,Output_Data_Size_Dim3,Output_Channel_Size]
Args:
input_shape: Shape of the inputs excluding batch size.
out_channels: Number of output channels.
Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/dnn/tests/test_activation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import brainpy.math as bm
from absl.testing import parameterized
from absl.testing import absltest
import brainpy as bp
Expand Down
Loading

0 comments on commit 7d086d2

Please sign in to comment.