update LayerNorm, RMSNorm, GroupNorm and other dropout models
chaoming0625 committed Nov 29, 2024
1 parent a98b31d commit 046dded
Showing 2 changed files with 31 additions and 24 deletions.
8 changes: 4 additions & 4 deletions brainstate/nn/_elementwise/_dropout.py
@@ -88,7 +88,7 @@ def __init__(
name: Optional[str] = None
) -> None:
super().__init__(name=name)
- assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
self.prob = prob
self.channel_axis = channel_axis

@@ -112,7 +112,7 @@ def __call__(self, x):
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')

# generate mask
- if fit_phase:
+ if fit_phase and self.prob < 1.:
dtype = u.math.get_dtype(x)
keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
return jnp.where(keep_mask,
@@ -396,7 +396,7 @@ def __init__(
name: Optional[str] = None
) -> None:
super().__init__(name=name)
- assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
self.prob = prob
self.in_size = in_size
self.out_size = in_size
@@ -407,7 +407,7 @@ def init_state(self, batch_size=None, **kwargs):
def update(self, x):
dtype = u.math.get_dtype(x)
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
- if fit_phase:
+ if fit_phase and self.prob < 1.:
if self.mask.value.shape != x.shape:
raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
f"Please call `init_state()` method first.")
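The dropout hunks above share one intent: `prob` is the keep probability, the relaxed assertion admits `prob == 1.`, and the added `self.prob < 1.` guard turns that case into a plain pass-through even while fitting. Below is a minimal standalone sketch of that behaviour; it uses plain `jax.random` instead of brainstate's `random`/`environ` helpers, and the `1/prob` rescaling is assumed from conventional inverted dropout rather than visible in the truncated diff.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, prob, fit_phase=True):
    # `prob` is the KEEP probability; prob == 1.0 is now allowed and disables dropout.
    assert 0. <= prob <= 1.
    if not (fit_phase and prob < 1.):       # mirrors `if fit_phase and self.prob < 1.:`
        return x                            # evaluation phase, or prob == 1: identity
    keep_mask = jax.random.bernoulli(key, prob, x.shape)
    return jnp.where(keep_mask, x / prob, 0.)   # inverted-dropout rescaling (assumed)

x = jnp.ones((2, 4))
key = jax.random.PRNGKey(0)
print(dropout(key, x, prob=0.8))   # roughly 80% of entries kept, scaled by 1/0.8
print(dropout(key, x, prob=1.0))   # exactly x: the new guard skips masking entirely
```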
47 changes: 27 additions & 20 deletions brainstate/nn/_interaction/_normalizations.py
@@ -214,7 +214,7 @@ def _normalize(
y = y * mul
args = []
if weights is not None:
- y, args = _scale_operation(y, weights.value, feature_shape)
+ y, args = _scale_operation(y, weights.value)
dtype = canonicalize_dtype(x, *args, dtype=dtype)
else:
assert var is None, 'mean and val must be both None or not None.'
@@ -223,13 +223,13 @@
return jnp.asarray(y, dtype)


- def _scale_operation(x: jax.Array, param: Dict, feature_shape: Axes):
+ def _scale_operation(x: jax.Array, param: Dict):
args = []
if 'scale' in param:
- x = x * param['scale'].reshape(feature_shape)
+ x = x * param['scale']
args.append(param['scale'])
if 'bias' in param:
- x = x + param['bias'].reshape(feature_shape)
+ x = x + param['bias']
args.append(param['bias'])
return x, args
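The dropped `.reshape(feature_shape)` calls were redundant: as the constructors later in this diff show, scale and bias are created directly with `feature_shape`, which keeps size 1 on every non-feature axis, so they already broadcast against the normalized array. A toy shape check (sizes invented for illustration):

```python
import jax.numpy as jnp

in_size, feature_axes = (4, 8), (1,)                   # feature axis = last axis
feature_shape = tuple(ax if i in feature_axes else 1   # -> (1, 8), as in the constructors
                      for i, ax in enumerate(in_size))
scale = jnp.ones(feature_shape)
x = jnp.zeros((32, 4, 8))                              # input with a leading batch axis
print((x * scale).shape)                               # (32, 4, 8): broadcasting, no reshape
```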

@@ -258,8 +258,8 @@ def __init__(
super().__init__(name=name)

# parameters
- self.in_size = tuple(in_size)
- self.out_size = tuple(in_size)
+ self.in_size = in_size
+ self.out_size = in_size
self.affine = affine
self.bias_initializer = bias_initializer
self.scale_initializer = scale_initializer
@@ -271,13 +271,13 @@ def __init__(

# parameters about axis
feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
- self.feature_axes = _canonicalize_axes(len(in_size), feature_axis)
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups

# variables
feature_shape = tuple([(ax if i in self.feature_axes else 1)
- for i, ax in enumerate(in_size)])
+ for i, ax in enumerate(self.in_size)])
if self.track_running_stats:
self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
@@ -499,7 +499,9 @@ class LayerNorm(Module):
by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
- reduction_axes: Axes for computing normalization statistics.
+ reduction_axes: Axes for computing normalization statistics. Negative indices
+ are recommended, since positive indices can point at the wrong axes once the
+ input carries a leading batch dimension.
feature_axes: Feature axes for learned bias and scaling.
axis_name: the axis name used to combine batch statistics from multiple
devices. See ``jax.pmap`` for a description of axis names (default: None).
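The same note is added to RMSNorm and GroupNorm below. A toy shape check (shapes invented for illustration) of why a negative axis survives an extra leading batch dimension while a positive one does not:

```python
import jax.numpy as jnp

# Module built for in_size=(6, 8); the intended reduction axis is the last one.
x = jnp.ones((32, 6, 8))          # at call time the input gains a batch axis of 32
print(x.mean(axis=-1).shape)      # (32, 6): -1 still selects the feature axis
print(x.mean(axis=1).shape)       # (32, 8): positive 1 now hits the wrong axis
```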
@@ -537,14 +539,14 @@ def __init__(

# parameters about axis
feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
- self.feature_axes = _canonicalize_axes(len(in_size), feature_axes)
- self.reduction_axes = reduction_axes
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups

# variables
feature_shape = tuple([(ax if i in self.feature_axes else 1)
- for i, ax in enumerate(in_size)])
+ for i, ax in enumerate(self.in_size)])

weights = dict()
if use_scale:
@@ -622,7 +624,9 @@ class RMSNorm(Module):
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
scale_init: Initializer for scale, by default, one.
- reduction_axes: Axes for computing normalization statistics.
+ reduction_axes: Axes for computing normalization statistics. Negative indices
+ are recommended, since positive indices can point at the wrong axes once the
+ input carries a leading batch dimension.
feature_axes: Feature axes for learned bias and scaling.
axis_name: the axis name used to combine batch statistics from multiple
devices. See ``jax.pmap`` for a description of axis names (default: None).
@@ -654,19 +658,20 @@ def __init__(
super().__init__()

self.in_size = in_size
+ self.out_size = in_size

# parameters about axis
feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
- self.feature_axes = _canonicalize_axes(len(in_size), feature_axes)
- self.reduction_axes = reduction_axes
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
+ self.reduction_axes = (reduction_axes, ) if isinstance(reduction_axes, int) else reduction_axes
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups

# variables
feature_shape = tuple([(ax if i in self.feature_axes else 1)
- for i, ax in enumerate(in_size)])
+ for i, ax in enumerate(self.in_size)])
if use_scale:
- self.scale = ParamState(init.param(scale_init, feature_shape))
+ self.scale = ParamState({'scale': init.param(scale_init, feature_shape)})
else:
self.scale = None
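Wrapping the scale in a dict keyed by `'scale'` lines RMSNorm's parameter state up with `_scale_operation` above, which dispatches on the `'scale'`/`'bias'` keys. A simplified stand-in (not the repository function) showing that convention:

```python
import jax.numpy as jnp

def scale_operation(x, param):            # simplified stand-in for _scale_operation
    if 'scale' in param:
        x = x * param['scale']
    if 'bias' in param:
        x = x + param['bias']
    return x

weights = {'scale': jnp.full((1, 8), 2.0)}          # RMSNorm keeps a scale but no bias
print(scale_operation(jnp.ones((4, 8)), weights))   # every entry becomes 2.0
```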

@@ -755,6 +760,8 @@ class GroupNorm(Module):
feature axis. Furthermore, if the input used at call time has additional
leading axes compared to the data used for initialisation, for example due
to batching, then the reduction axes need to be defined explicitly.
+ Negative indices are recommended, since positive indices can point at the
+ wrong axes once the input carries a leading batch dimension.
axis_name: the axis name used to combine batch statistics from multiple
devices. See ``jax.pmap`` for a description of axis names (default: None).
This is only needed if the model is subdivided across devices, i.e. the
@@ -796,7 +803,8 @@ def __init__(

# parameters about axis
feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
- self.feature_axes = _canonicalize_axes(len(in_size), feature_axis)
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups

@@ -811,7 +819,7 @@ def __init__(
)

feature_shape = tuple([(ax if i in self.feature_axes else 1)
- for i, ax in enumerate(in_size)])
+ for i, ax in enumerate(self.in_size)])
assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
num_features = feature_shape[0]
if group_size is not None:
Expand Down Expand Up @@ -851,7 +859,6 @@ def __init__(
self.use_scale = use_scale
self.bias_init = bias_init
self.scale_init = scale_init
- self.reduction_axes = reduction_axes
self.use_fast_variance = use_fast_variance

def update(self, x, *, mask: Optional[jax.Array] = None):
