Commit: ♻️ updated implementation

innat committed Mar 23, 2024
1 parent a37ee16 commit 79b114e
Showing 4 changed files with 26 additions and 40 deletions.
videoswin/backbone.py (2 changes: 1 addition & 1 deletion)
@@ -133,7 +133,7 @@ def __init__(
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
norm_layer=norm_layer,
- downsample=(VideoSwinPatchMerging if (i < num_layers - 1) else None),
+ downsampling_layer=(VideoSwinPatchMerging if (i < num_layers - 1) else None),
name=f"videoswin_basic_layer_{i + 1}",
)
x = layer(x)
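
For context on the rename above: in the backbone loop, each stage receives its own slice of the stochastic-depth schedule, and only the first num_layers - 1 stages are given VideoSwinPatchMerging as their downsampling_layer (the final stage gets None). The sketch below is illustrative only; the depths, the drop rate, and the use of np.linspace to build dpr are assumptions, not part of this commit.

import numpy as np

depths = [2, 2, 6, 2]        # assumed blocks per stage (Swin-T-like)
drop_path_rate = 0.2         # assumed top-level stochastic-depth rate
num_layers = len(depths)

# one linearly increasing rate per block across the whole network (assumed schedule)
dpr = np.linspace(0.0, drop_path_rate, sum(depths)).tolist()

for i in range(num_layers):
    stage_rates = dpr[sum(depths[:i]) : sum(depths[: i + 1])]
    has_downsampling = i < num_layers - 1
    print(i, [round(r, 3) for r in stage_rates], has_downsampling)
# 0 [0.0, 0.018] True
# 1 [0.036, 0.055] True
# 2 [0.073, 0.091, 0.109, 0.127, 0.145, 0.164] True
# 3 [0.182, 0.2] False
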
videoswin/blocks/basic.py (42 changes: 15 additions & 27 deletions)
@@ -21,7 +21,7 @@ class VideoSwinBasicLayer(keras.Model):
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (keras.layers, optional): Normalization layer. Default: LayerNormalization
- downsample (keras.layers | None, optional): Downsample layer at the end of the layer. Default: None
+ downsampling_layer (keras.layers | None, optional): Downsample layer at the end of the layer. Default: None
References:
- [Video Swin Transformer](https://arxiv.org/abs/2106.13230)
@@ -41,7 +41,7 @@ def __init__(
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_layer=None,
- downsample=None,
+ downsampling_layer=None,
**kwargs,
):
super().__init__(**kwargs)
@@ -57,26 +57,26 @@ def __init__(
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.norm_layer = norm_layer
- self.downsample = downsample
+ self.downsampling_layer = downsampling_layer

- def _compute_dim_padded(self, input_dim, window_dim_size):
+ def __compute_dim_padded(self, input_dim, window_dim_size):
input_dim = ops.cast(input_dim, dtype="float32")
window_dim_size = ops.cast(window_dim_size, dtype="float32")
return ops.cast(
- ops.ceil(input_dim / window_dim_size) * window_dim_size, "int32"
+ ops.ceil(input_dim / window_dim_size) * window_dim_size, dtype="int32"
)

def build(self, input_shape):
self.window_size, self.shift_size = get_window_size(
input_shape[1:-1], self.window_size, self.shift_size
)
- self.depth_pad = self._compute_dim_padded(input_shape[1], self.window_size[0])
- self.height_pad = self._compute_dim_padded(input_shape[2], self.window_size[1])
- self.width_pad = self._compute_dim_padded(input_shape[3], self.window_size[2])
+ depth_pad = self.__compute_dim_padded(input_shape[1], self.window_size[0])
+ height_pad = self.__compute_dim_padded(input_shape[2], self.window_size[1])
+ width_pad = self.__compute_dim_padded(input_shape[3], self.window_size[2])
self.attn_mask = compute_mask(
- self.depth_pad,
- self.height_pad,
- self.width_pad,
+ depth_pad,
+ height_pad,
+ width_pad,
self.window_size,
self.shift_size,
)
@@ -103,8 +103,8 @@ def build(self, input_shape):
for i in range(self.depth)
]

- if self.downsample is not None:
- self.downsample = self.downsample(
+ if self.downsampling_layer is not None:
+ self.downsample = self.downsampling_layer(
input_dim=self.input_dim, norm_layer=self.norm_layer
)
self.downsample.build(input_shape)
@@ -129,25 +129,13 @@ def call(self, x, training=None):

x = ops.reshape(x, [batch_size, depth, height, width, channel])

- if self.downsample is not None:
+ if self.downsampling_layer is not None:
x = self.downsample(x)

return x

def compute_output_shape(self, input_shape):
- if self.downsample is not None:
- # TODO: remove tensorflow dependencies.
- # GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501
- # output_shape = tf.TensorShape(
- # [
- # input_shape[0],
- # self.depth_pad,
- # self.height_pad // 2,
- # self.width_pad // 2,
- # 2 * self.input_dim,
- # ]
- # )
-
+ if self.downsampling_layer is not None:
output_shape = self.downsample.compute_output_shape(input_shape)
return output_shape

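
The padding helper renamed above rounds each of depth, height, and width up to the nearest multiple of the corresponding window size; with this commit the padded values stay local to build() and compute_output_shape defers to the downsampling layer instead of the previously commented-out TensorFlow shape code. A standalone sketch of the rounding rule, in plain Python rather than keras.ops, for illustration only:

import math

def compute_dim_padded(input_dim, window_dim_size):
    # round input_dim up to the nearest multiple of the window size
    return int(math.ceil(input_dim / window_dim_size) * window_dim_size)

print(compute_dim_padded(30, 8))   # 32 -> a 30-frame clip is padded to 32 for depth-8 windows
print(compute_dim_padded(56, 7))   # 56 -> already a multiple, no padding
print(compute_dim_padded(57, 7))   # 63
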
videoswin/layers/patch_embed.py (19 changes: 9 additions & 10 deletions)
@@ -26,38 +26,37 @@ def __init__(self, patch_size=(2, 4, 4), embed_dim=96, norm_layer=None, **kwargs
self.embed_dim = embed_dim
self.norm_layer = norm_layer

- def _compute_padding(self, dim, patch_size):
+ def __compute_padding(self, dim, patch_size):
pad_amount = patch_size - (dim % patch_size)
return [0, pad_amount if pad_amount != patch_size else 0]

def build(self, input_shape):
self.pads = [
[0, 0],
- self._compute_padding(input_shape[1], self.patch_size[0]),
- self._compute_padding(input_shape[2], self.patch_size[1]),
- self._compute_padding(input_shape[3], self.patch_size[2]),
+ self.__compute_padding(input_shape[1], self.patch_size[0]),
+ self.__compute_padding(input_shape[2], self.patch_size[1]),
+ self.__compute_padding(input_shape[3], self.patch_size[2]),
[0, 0],
]

+ if self.norm_layer is not None:
+ self.norm = self.norm_layer(axis=-1, epsilon=1e-5, name="embed_norm")
+ self.norm.build((None, None, None, None, self.embed_dim))

self.proj = layers.Conv3D(
self.embed_dim,
kernel_size=self.patch_size,
strides=self.patch_size,
name="embed_proj",
)
self.proj.build((None, None, None, None, input_shape[-1]))

- self.norm = None
- if self.norm_layer is not None:
- self.norm = self.norm_layer(axis=-1, epsilon=1e-5, name="embed_norm")
- self.norm.build((None, None, None, None, self.embed_dim))
self.built = True

def call(self, x):
x = ops.pad(x, self.pads)
x = self.proj(x)

- if self.norm is not None:
+ if self.norm_layer is not None:
x = self.norm(x)

return x
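
The __compute_padding rename follows the same private-name convention, and the call-time guard now keys on self.norm_layer, which is why the self.norm = None default could be dropped. The rule itself right-pads each temporal/spatial dimension so the strided Conv3D projection tiles it exactly; a small illustration with arbitrary example sizes:

def compute_padding(dim, patch_size):
    # right-pad so dim becomes divisible by patch_size; no-op if it already is
    pad_amount = patch_size - (dim % patch_size)
    return [0, pad_amount if pad_amount != patch_size else 0]

print(compute_padding(224, 4))  # [0, 0] -> 224 is already divisible by 4
print(compute_padding(30, 4))   # [0, 2] -> padded to 32
print(compute_padding(15, 2))   # [0, 1] -> padded to 16
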
videoswin/layers/patch_merging.py (3 changes: 1 addition & 2 deletions)
@@ -28,7 +28,6 @@ def build(self, input_shape):
self.reduction = layers.Dense(2 * self.input_dim, use_bias=False)
self.reduction.build((batch_size, depth, height // 2, width // 2, 4 * channel))

- self.norm = None
if self.norm_layer is not None:
self.norm = self.norm_layer(axis=-1, epsilon=1e-5)
self.norm.build((batch_size, depth, height // 2, width // 2, 4 * channel))
@@ -52,7 +51,7 @@ def call(self, x):
x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C
x = ops.concatenate([x0, x1, x2, x3], axis=-1) # B D H/2 W/2 4*C

- if self.norm is not None:
+ if self.norm_layer is not None:
x = self.norm(x)

x = self.reduction(x)
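
As in patch_embed.py, the guard in call() now checks self.norm_layer, so the self.norm = None default is no longer needed. For reference, a shape walk-through of the merge performed above, with assumed input sizes and NumPy standing in for keras.ops:

import numpy as np

B, D, H, W, C = 1, 8, 56, 56, 96          # assumed batch, depth, height, width, channels
x = np.zeros((B, D, H, W, C), dtype="float32")

x0 = x[:, :, 0::2, 0::2, :]               # (1, 8, 28, 28, 96)
x1 = x[:, :, 1::2, 0::2, :]
x2 = x[:, :, 0::2, 1::2, :]
x3 = x[:, :, 1::2, 1::2, :]
merged = np.concatenate([x0, x1, x2, x3], axis=-1)
print(merged.shape)                        # (1, 8, 28, 28, 384) = (B, D, H/2, W/2, 4*C)
# the Dense(2 * input_dim) reduction then maps the last axis 384 -> 192
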
