
RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/ passes
Sebastien Ehrhardt committed Apr 30, 2024
1 parent 61a97df commit 547f6c4
Showing 4 changed files with 50 additions and 8 deletions.
4 changes: 4 additions & 0 deletions tests/models/deit/test_modeling_deit.py
@@ -80,6 +80,7 @@ def __init__(
         num_labels=3,
         scope=None,
         encoder_stride=2,
+        mask_ratio=0.5,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -103,6 +104,9 @@ def __init__(
         # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
         num_patches = (image_size // patch_size) ** 2
         self.seq_length = num_patches + 2
+        self.mask_ratio = mask_ratio
+        self.num_masks = int(mask_ratio * self.seq_length)
+        self.mask_length = num_patches

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
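For intuition about the new tester attributes, here is a minimal sketch of the bookkeeping, using hypothetical values (image_size=30, patch_size=2, mask_ratio=0.5 are illustrative, not read from the diff): `num_masks` is derived from the full sequence length (patches plus the [CLS] and distillation tokens), while the boolean mask itself is laid out over `mask_length = num_patches` positions.

import torch

# hypothetical tester values, for illustration only
image_size, patch_size, mask_ratio = 30, 2, 0.5

num_patches = (image_size // patch_size) ** 2  # 225
seq_length = num_patches + 2                   # +2 for [CLS] and distillation tokens -> 227
num_masks = int(mask_ratio * seq_length)       # 113
mask_length = num_patches                      # bool_masked_pos covers patch positions only

# mask the first `num_masks` patch positions, leave the rest unmasked
mask = torch.cat([torch.ones(num_masks), torch.zeros(mask_length - num_masks)])
bool_masked_pos = mask.unsqueeze(0).bool()     # shape (1, 225)
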
1 change: 1 addition & 0 deletions tests/models/deit/test_modeling_tf_deit.py
@@ -194,6 +194,7 @@ class TFDeiTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
     test_resize_embeddings = False
     test_head_masking = False
     test_onnx = False
+    has_attentions = False

     def setUp(self):
         self.model_tester = TFDeiTModelTester(self)
7 changes: 5 additions & 2 deletions tests/models/videomae/test_modeling_videomae.py
@@ -197,8 +197,11 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
             # important: each video needs to have the same number of masked patches
             # hence we define a single mask, which we then repeat for each example in the batch
             mask = torch.ones((self.model_tester.num_masks,))
-            mask = torch.cat([mask, torch.zeros(self.model_tester.seq_length - mask.size(0))])
-            bool_masked_pos = mask.expand(self.model_tester.batch_size, -1).bool()
+            mask = torch.cat(
+                [mask, torch.zeros(self.model_tester.seq_length - mask.size(0))]
+            )
+            batch_size = inputs_dict["pixel_values"].shape[0]
+            bool_masked_pos = mask.expand(batch_size, -1).bool()
             inputs_dict["bool_masked_pos"] = bool_masked_pos.to(torch_device)

         if return_labels:
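As a standalone sketch of what the hunk above does (the shapes and counts below are hypothetical, not taken from the VideoMAE tester): a single mask is built once and repeated for every video, and the batch size is now read from the actual `pixel_values` tensor instead of the tester attribute.

import torch

# hypothetical inputs, for illustration only: 2 videos, 16 frames, 3 channels, 224x224
pixel_values = torch.randn(2, 16, 3, 224, 224)
seq_length, num_masks = 1568, 784  # illustrative patch count and number of masked patches

# one mask shared by every video, so each example has the same number of masked patches
mask = torch.ones((num_masks,))
mask = torch.cat([mask, torch.zeros(seq_length - mask.size(0))])

# derive the batch size from the inputs rather than from the tester
batch_size = pixel_values.shape[0]
bool_masked_pos = mask.expand(batch_size, -1).bool()
print(bool_masked_pos.shape)  # torch.Size([2, 1568])
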
46 changes: 40 additions & 6 deletions tests/test_modeling_common.py
@@ -3583,6 +3583,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
         for model_class in self.all_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
+            # FIXME: we deactivate the boolean mask because pretrained models
+            # will not load the mask token
+            if "use_mask_token" in inspect.signature(model_class).parameters:
+                deactivate_mask = True
+            else:
+                deactivate_mask = False

             is_encoder_decoder = model.config.is_encoder_decoder

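A small, self-contained illustration of the constructor-signature check added above; the two toy classes are made up and only stand in for real model classes, and the one-line form is equivalent to the if/else in the diff.

import inspect

# toy stand-ins for model classes, for illustration only
class ModelWithMaskToken:
    def __init__(self, config, use_mask_token=False):
        self.use_mask_token = use_mask_token

class PlainModel:
    def __init__(self, config):
        self.config = config

for model_class in (ModelWithMaskToken, PlainModel):
    # inspect.signature(cls) exposes the constructor parameters (without self)
    deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
    print(model_class.__name__, deactivate_mask)  # True for the first class, False for the second
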
@@ -3710,13 +3716,41 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask

if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
dummy_mask = torch.cat(
[dummy_mask, torch.zeros(self.model_tester.seq_length - dummy_mask.size(0))]
if (
"bool_masked_pos"
in inspect.signature(model_eager.forward).parameters
) and not deactivate_mask:
dummy_mask = torch.ones(
(self.model_tester.num_masks,)
)

# In case of additional token (like class) we define a custome `mask_length`
if hasattr(self.model_tester, "mask_length"):
dummy_mask = torch.cat(
[
dummy_mask,
torch.zeros(
self.model_tester.mask_length
- dummy_mask.size(0)
),
]
)
else:
dummy_mask = torch.cat(
[
dummy_mask,
torch.zeros(
self.model_tester.seq_length
- dummy_mask.size(0)
),
]
)
dummy_bool_masked_pos = dummy_mask.expand(
batch_size, -1
).bool()
processed_inputs["bool_masked_pos"] = (
dummy_bool_masked_pos.to(torch_device)
)
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)

if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
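Condensed into a standalone helper, the added logic amounts to the sketch below; the function name and the numbers in the usage lines are made up for illustration (one tester that defines `mask_length`, DeiT-style, and one that only has `seq_length`).

import torch
from types import SimpleNamespace

def build_dummy_bool_masked_pos(model_tester, batch_size, torch_device="cpu"):
    # mask the first `num_masks` positions, pad with zeros up to the mask length
    dummy_mask = torch.ones((model_tester.num_masks,))
    # testers with extra tokens (e.g. [CLS]/distillation) expose `mask_length`;
    # otherwise fall back to the full sequence length
    total = getattr(model_tester, "mask_length", model_tester.seq_length)
    dummy_mask = torch.cat([dummy_mask, torch.zeros(total - dummy_mask.size(0))])
    return dummy_mask.expand(batch_size, -1).bool().to(torch_device)

deit_like = SimpleNamespace(num_masks=113, seq_length=227, mask_length=225)
videomae_like = SimpleNamespace(num_masks=784, seq_length=1568)

print(build_dummy_bool_masked_pos(deit_like, batch_size=2).shape)      # torch.Size([2, 225])
print(build_dummy_bool_masked_pos(videomae_like, batch_size=2).shape)  # torch.Size([2, 1568])
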
