Skip to content

Commit

Permalink
test temp [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Mar 22, 2024
1 parent b4dede3 commit 1d66ceb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
32 changes: 16 additions & 16 deletions videoswin/blocks/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,22 +134,22 @@ def call(self, x, training=None):

return x

# def compute_output_shape(self, input_shape):
# if self.downsample is not None:
# # TODO: remove tensorflow dependencies.
# # GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501
# output_shape = tf.TensorShape(
# [
# input_shape[0],
# self.depth_pad,
# self.height_pad // 2,
# self.width_pad // 2,
# 2 * self.input_dim,
# ]
# )
# return output_shape

# return input_shape
def compute_output_shape(self, input_shape):
if self.downsample is not None:
# TODO: remove tensorflow dependencies.
# GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501
output_shape = tf.TensorShape(
[
input_shape[0],
self.depth_pad,
self.height_pad // 2,
self.width_pad // 2,
2 * self.input_dim,
]
)
return output_shape

return input_shape

def get_config(self):
config = super().get_config()
Expand Down
18 changes: 9 additions & 9 deletions videoswin/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ def call(self, x):

return x

def compute_output_shape(self, input_shape):
spatial_dims = [
(dim - self.patch_size[i]) // self.patch_size[i] + 1
for i, dim in enumerate(input_shape[1:-1])
]
output_shape = (
(input_shape[0],) + tuple(spatial_dims) + (self.embed_dim,)
)
return output_shape
# def compute_output_shape(self, input_shape):
# spatial_dims = [
# (dim - self.patch_size[i]) // self.patch_size[i] + 1
# for i, dim in enumerate(input_shape[1:-1])
# ]
# output_shape = (
# (input_shape[0],) + tuple(spatial_dims) + (self.embed_dim,)
# )
# return output_shape

def get_config(self):
config = super().get_config()
Expand Down

0 comments on commit 1d66ceb

Please sign in to comment.