diff --git a/videoswin/blocks/basic.py b/videoswin/blocks/basic.py index de7d6ae..53272b9 100644 --- a/videoswin/blocks/basic.py +++ b/videoswin/blocks/basic.py @@ -135,19 +135,19 @@ def call(self, x, training=None): return x def compute_output_shape(self, input_shape): - if self.downsample is not None: - # TODO: remove tensorflow dependencies. - # GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501 - output_shape = tf.TensorShape( - [ - input_shape[0], - self.depth_pad, - self.height_pad // 2, - self.width_pad // 2, - 2 * self.input_dim, - ] - ) - return output_shape + # if self.downsample is not None: + # # TODO: remove tensorflow dependencies. + # # GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501 + # output_shape = tf.TensorShape( + # [ + # input_shape[0], + # self.depth_pad, + # self.height_pad // 2, + # self.width_pad // 2, + # 2 * self.input_dim, + # ] + # ) + # return output_shape return input_shape diff --git a/videoswin/layers/patch_embed.py b/videoswin/layers/patch_embed.py index 1f6f623..20e9a63 100644 --- a/videoswin/layers/patch_embed.py +++ b/videoswin/layers/patch_embed.py @@ -62,16 +62,6 @@ def call(self, x): return x - # def compute_output_shape(self, input_shape): - # spatial_dims = [ - # (dim - self.patch_size[i]) // self.patch_size[i] + 1 - # for i, dim in enumerate(input_shape[1:-1]) - # ] - # output_shape = ( - # (input_shape[0],) + tuple(spatial_dims) + (self.embed_dim,) - # ) - # return output_shape - def get_config(self): config = super().get_config() config.update(