Inference UTs check for triton support from accelerator #6782

Open · wants to merge 3 commits into master
Conversation

raza-sikander (Contributor) commented:
Instead of checking whether triton is installed, check whether the accelerator supports it. Skip the test if it is not supported.

@@ -27,8 +27,8 @@ def ref_torch_attention(q, k, v, mask, sm_scale):
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("use_flash", [True, False])
 def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float16):
-    if not deepspeed.HAS_TRITON:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
loadams (Contributor) commented on the diff above:


Thanks @raza-sikander - do you think we can extrapolate this to replace all instances of HAS_TRITON?

raza-sikander (Contributor, Author) replied:

@loadams Yes. Replacing them would cover the case where triton is installed on the system but not supported by the device: today such a test still runs (because triton is installed) and then fails. So the right check is whether triton is supported, not merely whether it is installed.
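
A minimal sketch of what that repo-wide replacement could look like, assuming a hypothetical shared test helper. Only deepspeed.HAS_TRITON and deepspeed.get_accelerator().is_triton_supported() are taken from this PR's diff; the helper name and its placement are illustrative, not part of the PR.

import pytest
import deepspeed


def skip_on_unsupported_triton():
    """Skip the calling test unless the active accelerator supports triton.

    Unlike checking deepspeed.HAS_TRITON alone, which only reflects whether
    the triton package is installed, this also skips on devices that cannot
    run triton kernels even when the package is present.
    """
    if not deepspeed.HAS_TRITON:
        pytest.skip("triton is not installed")
    if not deepspeed.get_accelerator().is_triton_supported():
        pytest.skip("triton is not supported by this accelerator")

A test such as test_attention would then call skip_on_unsupported_triton() as its first statement instead of repeating the two-part check inline.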
