Skip to content

Commit

Permalink
[ADD][FIX] call argument training defaults to True
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 21, 2022
1 parent 79a05ec commit 1252888
Show file tree
Hide file tree
Showing 8 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions hourglass_tensorflow/layers/batch_norm_relu_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(
self,
filters: int,
kernel_size: int,
strides: int,
strides: int = 1,
padding: str = "same",
activation: str = None,
kernel_initializer: str = "glorot_uniform",
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
name="ReLU",
)

def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
x = self.batch_norm(inputs, training=training)
x = self.relu(x)
x = self.conv(x)
Expand Down
4 changes: 2 additions & 2 deletions hourglass_tensorflow/layers/conv_batch_norm_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(
self,
filters: int,
kernel_size: int,
strides: int,
strides: int = 1,
padding: str = "same",
activation: str = None,
kernel_initializer: str = "glorot_uniform",
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
name="ReLU",
)

def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
x = self.conv(inputs)
x = self.batch_norm(x, training=training)
x = self.relu(x)
Expand Down
2 changes: 1 addition & 1 deletion hourglass_tensorflow/layers/conv_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
trainable=trainable,
)

def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
x = self.bnrc1(inputs, training=training)
x = self.bnrc2(x, training=training)
x = self.bnrc3(x, training=training)
Expand Down
2 changes: 1 addition & 1 deletion hourglass_tensorflow/layers/downsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
)
)

def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
x = inputs
for layer in self.layers:
x = layer(x, training=training)
Expand Down
2 changes: 1 addition & 1 deletion hourglass_tensorflow/layers/hourglass.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
)
# endregion

def _recursive_call(self, input_tensor, step, training=False):
def _recursive_call(self, input_tensor, step, training=True):
step_layers = self.layers[step]
up_1 = step_layers["up_1"](input_tensor, training=training)
low_ = step_layers["low_"](input_tensor, training=training)
Expand Down
2 changes: 1 addition & 1 deletion hourglass_tensorflow/layers/residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
)
self.add = layers.Add(name="Add")

def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
return self.add(
[
self.conv_block(inputs, training=training),
Expand Down
2 changes: 1 addition & 1 deletion hourglass_tensorflow/layers/skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
kernel_initializer="glorot_uniform",
)

def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
if inputs.get_shape()[-1] == self.output_filters:
return inputs
else:
Expand Down
2 changes: 1 addition & 1 deletion hourglass_tensorflow/models/hourglass.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
for i in range(stages)
]

def call(self, inputs: tf.Tensor, training=False):
def call(self, inputs: tf.Tensor, training=True):
x = self.downsampling(inputs)
outputs = []
for layer in self.hourglasses:
Expand Down

0 comments on commit 1252888

Please sign in to comment.