Skip to content

Commit

Permalink
test bug [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Mar 22, 2024
1 parent 1d66ceb commit 7ea8589
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 23 deletions.
26 changes: 13 additions & 13 deletions videoswin/blocks/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,19 @@ def call(self, x, training=None):
return x

def compute_output_shape(self, input_shape):
    """Compute the output shape of this block for a given input shape.

    Args:
        input_shape: Shape of the input as a sequence
            ``(batch, depth, height, width, channels)``; entries may be
            ``None`` for unknown dimensions.

    Returns:
        A plain tuple with the downsampled shape
        ``(batch, depth_pad, height_pad // 2, width_pad // 2, 2 * input_dim)``
        when a downsample layer is configured, otherwise ``input_shape``
        unchanged. Returning a tuple (rather than ``tf.TensorShape``)
        removes the TensorFlow dependency noted in the original TODO
        (https://github.com/keras-team/keras/issues/19259); Keras accepts
        plain tuples/lists from ``compute_output_shape``.
    """
    if self.downsample is not None:
        # Spatial dims are halved by the patch-merging downsample and the
        # channel dim is doubled; depth keeps its padded size.
        return (
            input_shape[0],
            self.depth_pad,
            self.height_pad // 2,
            self.width_pad // 2,
            2 * self.input_dim,
        )

    # No downsampling: the block is shape-preserving.
    return input_shape

Expand Down
10 changes: 0 additions & 10 deletions videoswin/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,6 @@ def call(self, x):

return x

# def compute_output_shape(self, input_shape):
# spatial_dims = [
# (dim - self.patch_size[i]) // self.patch_size[i] + 1
# for i, dim in enumerate(input_shape[1:-1])
# ]
# output_shape = (
# (input_shape[0],) + tuple(spatial_dims) + (self.embed_dim,)
# )
# return output_shape

def get_config(self):
config = super().get_config()
config.update(
Expand Down

0 comments on commit 7ea8589

Please sign in to comment.