ENH: Bayesian neural network architectures with training requisites #126

Merged: 46 commits merged into master from bayesian-training on Jun 19, 2021
Commits (46)
8573fdd
asdnet_added
Aakanksha-Rana Nov 18, 2020
d33c2d0
Create bayesian_utils.py
Aakanksha-Rana Mar 22, 2021
f289a3e
Create bayesian-meshnet.py
Aakanksha-Rana Mar 23, 2021
0738d29
Merge branch 'master' into bayesian-training
satra Mar 25, 2021
2d79c20
Create bayesian_vnet_semi.py
Aakanksha-Rana May 20, 2021
5deedf8
Create bayesian_vnet.py
Aakanksha-Rana May 20, 2021
8766a3d
Create vnet.py
Aakanksha-Rana May 20, 2021
39531dd
Merge branch 'master' into bayesian-training
satra Jun 13, 2021
20a8930
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2021
89a12ab
Merge branch 'master' into bayesian-training
satra Jun 13, 2021
27588c6
fix: reorder positional arg
satra Jun 13, 2021
c47f728
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2021
39f22f2
Merge branch 'master' into bayesian-training
satra Jun 13, 2021
7748f34
Update vnet.py
Aakanksha-Rana Jun 18, 2021
e1e9442
Update models_test.py
Aakanksha-Rana Jun 18, 2021
ff96c14
Update vnet.py
Aakanksha-Rana Jun 19, 2021
215d2ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
62b4ab2
Update bayesian_vnet_semi.py
Aakanksha-Rana Jun 19, 2021
5894d85
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
649f4db
Update bayesian_vnet_semi.py
Aakanksha-Rana Jun 19, 2021
d7f5eb4
Update bayesian_utils.py
Aakanksha-Rana Jun 19, 2021
dc450bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
9299fb5
Update bayesian_vnet_semi.py
Aakanksha-Rana Jun 19, 2021
5ddb2e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
7073c84
Update bayesian_vnet.py
Aakanksha-Rana Jun 19, 2021
f3eada5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
b119f84
Update models_test.py
Aakanksha-Rana Jun 19, 2021
b703288
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
af9d99b
Add files via upload
Aakanksha-Rana Jun 19, 2021
e97c1ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
2a0100d
Update models_test.py
Aakanksha-Rana Jun 19, 2021
c7c161a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
8692b3e
Update models_test.py
Aakanksha-Rana Jun 19, 2021
28ed033
Update models_test.py
Aakanksha-Rana Jun 19, 2021
18a5753
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
c671a2a
Update bayesian_vnet_semi.py
Aakanksha-Rana Jun 19, 2021
ed220ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
2ba38f8
Update bayesian_utils.py
Aakanksha-Rana Jun 19, 2021
8eeda35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
b4777ca
Update bayesian-meshnet.py
Aakanksha-Rana Jun 19, 2021
f40ac54
Update bayesian_utils.py
Aakanksha-Rana Jun 19, 2021
313685e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2021
11ddbce
Update bayesian_utils.py
Aakanksha-Rana Jun 19, 2021
8660197
Merge branch 'master' into bayesian-training
satra Jun 19, 2021
7d24326
Delete asdnet.py
Aakanksha-Rana Jun 19, 2021
e1072d5
using import paths relative to file
satra Jun 19, 2021
92 changes: 92 additions & 0 deletions nobrainer/bayesian_utils.py
@@ -0,0 +1,92 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf2
import tensorflow_probability as tfp
from tensorflow_probability.python.distributions import (
    deterministic as deterministic_lib,
)
from tensorflow_probability.python.distributions import independent as independent_lib
from tensorflow_probability.python.distributions import normal as normal_lib

tfd = tfp.distributions


def default_mean_field_normal_fn(
    is_singular=False,
    loc_initializer=tf.keras.initializers.he_normal(),
    untransformed_scale_initializer=tf1.initializers.random_normal(
        mean=-3.0, stddev=0.1
    ),
    loc_regularizer=tf.keras.regularizers.l2(),  # tfp's default is None
    untransformed_scale_regularizer=None,
    loc_constraint=tf.keras.constraints.UnitNorm(axis=[0, 1, 2, 3]),
    untransformed_scale_constraint=None,
):
    """Build a mean-field (independent normal) posterior factory for TFP layers."""
    loc_scale_fn = tfp.layers.default_loc_scale_fn(
        is_singular=is_singular,
        loc_initializer=loc_initializer,
        untransformed_scale_initializer=untransformed_scale_initializer,
        loc_regularizer=loc_regularizer,
        untransformed_scale_regularizer=untransformed_scale_regularizer,
        loc_constraint=loc_constraint,
        untransformed_scale_constraint=untransformed_scale_constraint,
    )

    def _fn(dtype, shape, name, trainable, add_variable_fn):
        loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)
        if scale is None:
            dist = deterministic_lib.Deterministic(loc=loc)
        else:
            dist = normal_lib.Normal(loc=loc, scale=scale)
        batch_ndims = tf2.size(dist.batch_shape_tensor())
        return independent_lib.Independent(dist, reinterpreted_batch_ndims=batch_ndims)

    return _fn


def divergence_fn_bayesian(prior_std, examples_per_epoch):
    """KL divergence with a log-normal hyperprior on the prior scale, scaled per example."""

    def divergence_fn(q, p, _):
        log_probs = tfd.LogNormal(0.0, prior_std).log_prob(p.stddev())
        out = tfd.kl_divergence(q, p) - tf.reduce_sum(log_probs)
        return out / examples_per_epoch

    return divergence_fn


def prior_fn_for_bayesian(init_scale_mean=-1, init_scale_std=0.1):
    """Build a trainable normal prior whose scale is learned through a softplus."""

    def prior_fn(dtype, shape, name, _, add_variable_fn):
        untransformed_scale = add_variable_fn(
            name=name + "_untransformed_scale",
            shape=(1,),
            initializer=tf.compat.v1.initializers.random_normal(
                mean=init_scale_mean, stddev=init_scale_std
            ),
            dtype=dtype,
            trainable=True,
        )
        loc = add_variable_fn(
            name=name + "_loc",
            initializer=tf.keras.initializers.Zeros(),
            shape=shape,
            dtype=dtype,
            trainable=True,
        )
        scale = 1e-4 + tf.nn.softplus(untransformed_scale)
        dist = tfd.Normal(loc=loc, scale=scale)
        batch_ndims = tf.size(input=dist.batch_shape_tensor())
        return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)

    return prior_fn


def normal_prior(prior_std=1.0):
    """Define a fixed zero-mean normal prior for Bayesian neural network weights."""

    def prior_fn(dtype, shape, name, trainable, add_variable_fn):
        dist = tfd.Normal(
            loc=tf.zeros(shape, dtype), scale=dtype.as_numpy_dtype(prior_std)
        )
        batch_ndims = tf.size(input=dist.batch_shape_tensor())
        return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)

    return prior_fn
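
Note (not part of this diff): these helpers are designed to be passed to TensorFlow Probability layers as the posterior, prior, and divergence functions. A minimal sketch of how they could be wired into a `tfp.layers.Convolution3DFlipout` layer follows; the Flipout layer choice, filter count, input shape, and `examples_per_epoch` value are illustrative assumptions, and the PR's own Bayesian models may wire them differently.

import tensorflow as tf
import tensorflow_probability as tfp

from nobrainer.bayesian_utils import (
    default_mean_field_normal_fn,
    divergence_fn_bayesian,
    prior_fn_for_bayesian,
)

examples_per_epoch = 1000  # hypothetical dataset size used to scale the KL term
layer = tfp.layers.Convolution3DFlipout(
    filters=8,
    kernel_size=3,
    padding="same",
    activation="relu",
    kernel_posterior_fn=default_mean_field_normal_fn(),
    kernel_prior_fn=prior_fn_for_bayesian(),
    kernel_divergence_fn=divergence_fn_bayesian(
        prior_std=1.0, examples_per_epoch=examples_per_epoch
    ),
)
volume = tf.random.normal((1, 32, 32, 32, 1))  # (batch, D, H, W, channels)
features = layer(volume)  # per-layer KL terms accumulate in layer.losses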
223 changes: 223 additions & 0 deletions nobrainer/layers/groupnorm.py
@@ -0,0 +1,223 @@
# MIT License
#
# Copyright (c) 2019 Somshubra Majumdar
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Retrieved from https://github.com/titu1994/Keras-Group-Normalization/blob/master/group_norm.py

from tensorflow.keras import backend as K
from tensorflow.keras import constraints, initializers, regularizers
from tensorflow.keras.layers import InputSpec, Layer
from tensorflow.keras.utils import get_custom_objects


class GroupNormalization(Layer):
    """Group normalization layer.

    Group Normalization divides the channels into groups and computes within
    each group the mean and variance for normalization. GN's computation is
    independent of batch sizes, and its accuracy is stable in a
    wide range of batch sizes.

    Arguments
        groups: Integer, the number of groups for Group Normalization.
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `GroupNormalization`.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    Output shape
        Same shape as input.
    References
        - [Group Normalization](https://arxiv.org/abs/1803.08494)
    """

    def __init__(
        self,
        groups=4,
        axis=-1,
        epsilon=1e-5,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        **kwargs
    ):
        super(GroupNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError(
                "Axis " + str(self.axis) + " of "
                "input tensor should have a defined dimension "
                "but the layer received an input with shape " + str(input_shape) + "."
            )

        if dim < self.groups:
            raise ValueError(
                "Number of groups (" + str(self.groups) + ") "
                "cannot be more than the number of channels (" + str(dim) + ")."
            )

        if dim % self.groups != 0:
            raise ValueError(
                "Number of channels (" + str(dim) + ") "
                "must be a multiple of the number of groups ("
                + str(self.groups) + ")."
            )

        self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                name="gamma",
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
            )
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                name="beta",
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
            )
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, **kwargs):
        input_shape = K.int_shape(inputs)
        tensor_input_shape = K.shape(inputs)

        # Prepare broadcasting shape.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
        broadcast_shape.insert(1, self.groups)

        reshape_group_shape = K.shape(inputs)
        group_axes = [reshape_group_shape[i] for i in range(len(input_shape))]
        group_axes[self.axis] = input_shape[self.axis] // self.groups
        group_axes.insert(1, self.groups)

        # reshape inputs to new group shape
        group_shape = [group_axes[0], self.groups] + group_axes[2:]
        group_shape = K.stack(group_shape)
        inputs = K.reshape(inputs, group_shape)

        group_reduction_axes = list(range(len(group_axes)))
        group_reduction_axes = group_reduction_axes[2:]

        mean = K.mean(inputs, axis=group_reduction_axes, keepdims=True)
        variance = K.var(inputs, axis=group_reduction_axes, keepdims=True)

        inputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

        # prepare broadcast shape
        inputs = K.reshape(inputs, group_shape)
        outputs = inputs

        # In this case we must explicitly broadcast all parameters.
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            outputs = outputs * broadcast_gamma

        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            outputs = outputs + broadcast_beta

        outputs = K.reshape(outputs, tensor_input_shape)

        return outputs

    def get_config(self):
        config = {
            "groups": self.groups,
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(self.gamma_initializer),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(self.gamma_regularizer),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint),
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape


get_custom_objects().update({"GroupNormalization": GroupNormalization})


if __name__ == "__main__":
    from tensorflow.keras.layers import Input
    from tensorflow.keras.models import Model

    ip = Input(shape=(None, None, 4))
    x = GroupNormalization(groups=2, axis=-1, epsilon=0.1)(ip)
    model = Model(inputs=ip, outputs=x)
    model.summary()
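
Note (not part of this diff): the built-in `__main__` example is 2D. A short sketch of how the same layer applies to the 5D volumetric tensors used by the 3D models in this PR; the batch size, volume shape, and group count below are illustrative assumptions.

import tensorflow as tf
from nobrainer.layers.groupnorm import GroupNormalization

x = tf.random.normal((2, 16, 16, 16, 8))    # (batch, depth, height, width, channels)
gn = GroupNormalization(groups=4, axis=-1)  # 8 channels split into 4 groups of 2
y = gn(x)                                   # normalization is independent of batch size
print(y.shape)                              # (2, 16, 16, 16, 8), same shape as the input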