diff --git a/tests/integration/test_match_huggingface.py b/tests/integration/test_match_huggingface.py index ef9d39e01..95ac80b35 100644 --- a/tests/integration/test_match_huggingface.py +++ b/tests/integration/test_match_huggingface.py @@ -42,4 +42,4 @@ def test_compare_huggingface_attention_match_local_implementation(self, model_na ) hf_out, _, _ = hf_model.transformer.h[layer_n].attn(hidden_states=input) - # assert torch.sum(tl_out == hf_out) == math.prod(tl_out.shape) + assert torch.sum(tl_out == hf_out) == math.prod(tl_out.shape)