[FIX](layers) Super get_config methods for model serialization
wbenbihi committed Aug 25, 2022
1 parent 809aabb commit 5941db2
Showing 7 changed files with 120 additions and 7 deletions.
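The pattern is the same in every file below: each constructor argument is stored on self, and get_config merges the base config from super().get_config() with those stored values, so Keras can serialize a model containing the layer and later rebuild it from its config. A minimal sketch of the round-trip this enables; ToyLayer and its arguments are illustrative only and not part of this repository:

from tensorflow.keras import layers

class ToyLayer(layers.Layer):
    def __init__(self, filters: int = 16, momentum: float = 0.9, **kwargs):
        super().__init__(**kwargs)
        # Store config so get_config() can report it
        self.filters = filters
        self.momentum = momentum
        # Create layers
        self.conv = layers.Conv2D(filters=filters, kernel_size=1, padding="same")

    def get_config(self):
        # Base Layer config (name, trainable, dtype) merged with this layer's own arguments
        return {
            **super().get_config(),
            **{"filters": self.filters, "momentum": self.momentum},
        }

    def call(self, inputs):
        return self.conv(inputs)

# Round-trip: the default from_config simply calls cls(**config)
layer = ToyLayer(filters=32, name="toy")
rebuilt = ToyLayer.from_config(layer.get_config())
assert rebuilt.filters == 32
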
25 changes: 25 additions & 0 deletions hourglass_tensorflow/layers/batch_norm_relu_conv.py
@@ -20,6 +20,16 @@ def __init__(
trainable: bool = True,
) -> None:
super().__init__(name=name, dtype=dtype, dynamic=dynamic, trainable=trainable)
# Store Config
self.filters = filters
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
self.activation = activation
self.kernel_initializer = kernel_initializer
self.momentum = momentum
self.epsilon = epsilon
# Create Layers
self.batch_norm = layers.BatchNormalization(
axis=-1,
momentum=momentum,
@@ -40,6 +50,21 @@ def __init__(
name="ReLU",
)

def get_config(self):
return {
**super().get_config(),
**{
"filters": self.filters,
"kernel_size": self.kernel_size,
"strides": self.strides,
"padding": self.padding,
"activation": self.activation,
"kernel_initializer": self.kernel_initializer,
"momentum": self.momentum,
"epsilon": self.epsilon,
},
}

def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
x = self.batch_norm(inputs, training=training)
x = self.relu(x)
25 changes: 25 additions & 0 deletions hourglass_tensorflow/layers/conv_batch_norm_relu.py
@@ -20,6 +20,16 @@ def __init__(
trainable: bool = True,
) -> None:
super().__init__(name=name, dtype=dtype, dynamic=dynamic, trainable=trainable)
# Store config
self.filters = filters
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
self.activation = activation
self.kernel_initializer = kernel_initializer
self.momentum = momentum
self.epsilon = epsilon
# Create layers
self.batch_norm = layers.BatchNormalization(
axis=-1,
momentum=momentum,
@@ -40,6 +50,21 @@ def __init__(
name="ReLU",
)

def get_config(self):
return {
**super().get_config(),
**{
"filters": self.filters,
"kernel_size": self.kernel_size,
"strides": self.strides,
"padding": self.padding,
"activation": self.activation,
"kernel_initializer": self.kernel_initializer,
"momentum": self.momentum,
"epsilon": self.epsilon,
},
}

def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
x = self.conv(inputs)
x = self.batch_norm(x, training=training)
16 changes: 15 additions & 1 deletion hourglass_tensorflow/layers/conv_block.py
@@ -17,7 +17,11 @@ def __init__(
trainable: bool = True,
) -> None:
super().__init__(name=name, dtype=dtype, dynamic=dynamic, trainable=trainable)

# Store config
self.output_filters = output_filters
self.momentum = momentum
self.epsilon = epsilon
# Create layers
self.bnrc1 = BatchNormReluConvLayer(
filters=output_filters // 2,
kernel_size=1,
@@ -49,6 +53,16 @@ def __init__(
trainable=trainable,
)

def get_config(self):
return {
**super().get_config(),
**{
"output_filters": self.output_filters,
"momentum": self.momentum,
"epsilon": self.epsilon,
},
}

def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
x = self.bnrc1(inputs, training=training)
x = self.bnrc2(x, training=training)
18 changes: 17 additions & 1 deletion hourglass_tensorflow/layers/downsampling.py
@@ -21,10 +21,15 @@ def __init__(
trainable: bool = True,
) -> None:
super().__init__(name=name, dtype=dtype, dynamic=dynamic, trainable=trainable)
# Store config
self.input_size = input_size
self.output_size = output_size
self.kernel_size = kernel_size
self.output_filters = output_filters
# Init Computation
self.downsamplings = int(math.log2(input_size // output_size) + 1)
self.layers = []
# Layers
# Create Layers
for i in range(self.downsamplings):
if i == 0:
self.layers.append(
@@ -68,6 +73,17 @@ def __init__(
)
)

def get_config(self):
return {
**super().get_config(),
**{
"input_size": self.input_size,
"output_size": self.output_size,
"kernel_size": self.kernel_size,
"output_filters": self.output_filters,
},
}

def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
x = inputs
for layer in self.layers:
15 changes: 13 additions & 2 deletions hourglass_tensorflow/layers/hourglass.py
@@ -18,12 +18,13 @@ def __init__(
trainable: bool = True,
) -> None:
super().__init__(name=name, dtype=dtype, dynamic=dynamic, trainable=trainable)
# Init parameters
# Store Config
self.downsamplings = downsamplings
self.feature_filters = feature_filters
self.output_filters = output_filters
# Init parameters
self.layers = [{} for i in range(self.downsamplings)]
# region Layers
# Create Layers
self._hm_output = ConvBatchNormReluLayer(
filters=output_filters,
kernel_size=1,
@@ -95,6 +96,16 @@ def __init__(
)
# endregion

def get_config(self):
return {
**super().get_config(),
**{
"downsamplings": self.downsamplings,
"feature_filters": self.feature_filters,
"output_filters": self.output_filters,
},
}

def _recursive_call(self, input_tensor, step, training=True):
step_layers = self.layers[step]
up_1 = step_layers["up_1"](input_tensor, training=training)
16 changes: 15 additions & 1 deletion hourglass_tensorflow/layers/residual.py
@@ -18,7 +18,11 @@ def __init__(
trainable: bool = True,
) -> None:
super().__init__(name=name, dtype=dtype, dynamic=dynamic, trainable=trainable)
# Layers
# Store config
self.output_filters = output_filters
self.momentum = momentum
self.epsilon = epsilon
# Create Layers
self.conv_block = ConvBlockLayer(
output_filters=output_filters,
momentum=momentum,
@@ -37,6 +41,16 @@ def __init__(
)
self.add = layers.Add(name="Add")

def get_config(self):
return {
**super().get_config(),
**{
"output_filters": self.output_filters,
"momentum": self.momentum,
"epsilon": self.epsilon,
},
}

def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
return self.add(
[
12 changes: 10 additions & 2 deletions hourglass_tensorflow/layers/skip.py
@@ -13,9 +13,9 @@ def __init__(
trainable: bool = True,
) -> None:
super().__init__(name=name, dtype=dtype, dynamic=dynamic, trainable=trainable)

# Store config
self.output_filters = output_filters

# Create Layers
self.conv = layers.Conv2D(
filters=self.output_filters,
kernel_size=1,
@@ -26,6 +26,14 @@ def __init__(
kernel_initializer="glorot_uniform",
)

def get_config(self):
return {
**super().get_config(),
**{
"output_filters": self.output_filters,
},
}

def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
if inputs.get_shape()[-1] == self.output_filters:
return inputs

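With every layer now reporting its constructor arguments, a model built from these blocks can be saved and reloaded. A sketch of that round-trip, assuming ConvBlockLayer is the class defined in hourglass_tensorflow/layers/conv_block.py (the name residual.py references above) and that the import path follows the file layout; the file name, input shape, and argument values are arbitrary:

import tensorflow as tf
# Assumed import path based on the file layout shown in this commit
from hourglass_tensorflow.layers.conv_block import ConvBlockLayer

# Wrap the custom layer in a small functional model
inputs = tf.keras.Input(shape=(64, 64, 3))
outputs = ConvBlockLayer(output_filters=64, momentum=0.9, epsilon=1e-5, name="conv_block")(inputs)
model = tf.keras.Model(inputs, outputs)

# Saving to HDF5 requires get_config() on custom layers, which this commit adds
model.save("conv_block_model.h5")

# Loading rebuilds the layer from its config; the class must be supplied via custom_objects
reloaded = tf.keras.models.load_model(
    "conv_block_model.h5",
    custom_objects={"ConvBlockLayer": ConvBlockLayer},
)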