From 1d66cebcab50d55f1235476061e19fdd323fb9cb Mon Sep 17 00:00:00 2001 From: innat Date: Sat, 23 Mar 2024 00:39:46 +0600 Subject: [PATCH] test temp [skip ci] --- videoswin/blocks/basic.py | 32 ++++++++++++++++---------------- videoswin/layers/patch_embed.py | 18 +++++++++--------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/videoswin/blocks/basic.py b/videoswin/blocks/basic.py index 536f8ab..de7d6ae 100644 --- a/videoswin/blocks/basic.py +++ b/videoswin/blocks/basic.py @@ -134,22 +134,22 @@ def call(self, x, training=None): return x - # def compute_output_shape(self, input_shape): - # if self.downsample is not None: - # # TODO: remove tensorflow dependencies. - # # GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501 - # output_shape = tf.TensorShape( - # [ - # input_shape[0], - # self.depth_pad, - # self.height_pad // 2, - # self.width_pad // 2, - # 2 * self.input_dim, - # ] - # ) - # return output_shape - - # return input_shape + def compute_output_shape(self, input_shape): + if self.downsample is not None: + # TODO: remove tensorflow dependencies. + # GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501 + output_shape = tf.TensorShape( + [ + input_shape[0], + self.depth_pad, + self.height_pad // 2, + self.width_pad // 2, + 2 * self.input_dim, + ] + ) + return output_shape + + return input_shape def get_config(self): config = super().get_config() diff --git a/videoswin/layers/patch_embed.py b/videoswin/layers/patch_embed.py index df974a2..1f6f623 100644 --- a/videoswin/layers/patch_embed.py +++ b/videoswin/layers/patch_embed.py @@ -62,15 +62,15 @@ def call(self, x): return x - def compute_output_shape(self, input_shape): - spatial_dims = [ - (dim - self.patch_size[i]) // self.patch_size[i] + 1 - for i, dim in enumerate(input_shape[1:-1]) - ] - output_shape = ( - (input_shape[0],) + tuple(spatial_dims) + (self.embed_dim,) - ) - return output_shape + # def compute_output_shape(self, input_shape): + # spatial_dims = [ + # (dim - self.patch_size[i]) // self.patch_size[i] + 1 + # for i, dim in enumerate(input_shape[1:-1]) + # ] + # output_shape = ( + # (input_shape[0],) + tuple(spatial_dims) + (self.embed_dim,) + # ) + # return output_shape def get_config(self): config = super().get_config()