Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed Apr 29, 2024
1 parent 11ed2e1 commit 9d21427
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3716,15 +3716,15 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
[dummy_mask, torch.zeros(self.model_tester.seq_length - dummy_mask.size(0))]
)
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
other_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)

if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int(
(self.model_tester.image_size // self.model_tester.patch_size) ** 2
)
noise = np.random.uniform(size=(batch_size, num_patches))
other_inputs["noise"] = torch.from_numpy(noise)
processed_inputs["noise"] = torch.from_numpy(noise)

# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
Expand Down

0 comments on commit 9d21427

Please sign in to comment.