diff --git a/test/test_layers.py b/test/test_layers.py index 19e9204..28ced6d 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -1,4 +1,5 @@ from base import TestCase +import keras from keras import ops from videoswin.layers import ( @@ -13,9 +14,9 @@ def test_patch_embedding_compute_output_shape(self): patch_embedding_model = VideoSwinPatchingAndEmbedding( patch_size=(2, 4, 4), embed_dim=96, norm_layer=None ) - input_shape = (None, 16, 32, 32, 3) - output_shape = patch_embedding_model.compute_output_shape(input_shape) - expected_output_shape = (None, 8, 8, 8, 96) + input_array = keras.random.normal(shape=(1, 16, 32, 32, 3)) + output_shape = patch_embedding_model(input_array).shape + expected_output_shape = (1, 8, 8, 8, 96) self.assertEqual(output_shape, expected_output_shape) def test_patch_embedding_get_config(self):