Don't install Triton nightly separately in CI
Instead, count on PyTorch nightlies pulling in a recent enough (and well-tested!) version of Triton.

We also disable Triton in some components on V100, because we started seeing failures with recent Triton versions. This is what those failures looked like (the first 20 of them): P967812823.
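
For reference, a minimal, self-contained sketch of the capability gate this diff tightens from SM70 to SM80 (the helper name `oldest_capability` and the flag are illustrative, not part of xformers):

    import torch

    def oldest_capability() -> tuple:
        # Smallest (major, minor) compute capability across visible GPUs;
        # (0, 0) when no CUDA device is available.
        return min(
            (torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())),
            default=(0, 0),
        )

    # V100 is SM70 and A100 is SM80, so this gate now excludes V100.
    use_triton_blocksparse = torch.cuda.is_available() and oldest_capability() >= (8, 0)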

ghstack-source-id: 612d5ea3406be3729f4be98b0a583759b3048ae3
Pull Request resolved: fairinternal/xformers#985

__original_commit__ = fairinternal/xformers@d654838
lw authored and xFormers Bot committed Jan 5, 2024
1 parent 042abc8 commit 6600003
Showing 6 changed files with 28 additions and 41 deletions.
2 changes: 0 additions & 2 deletions requirements-test.txt
@@ -29,5 +29,3 @@ scipy

# Dependency for fused layers, optional
cmake
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
triton-nightly<=2.1.0.post20231125000000
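
With the explicit pin gone, a quick sanity check one could run in CI (illustrative, not part of this diff; assumes the PyTorch nightly is already installed) is to print which Triton actually got resolved:

    # Confirm which Triton the PyTorch nightly pulled in.
    import torch
    import triton

    print("torch:", torch.__version__)
    print("triton:", triton.__version__)
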
7 changes: 2 additions & 5 deletions tests/test_core_attention.py
@@ -15,10 +15,10 @@
from xformers.components.attention.core import scaled_dot_product_attention

if _is_triton_available():
from xformers.triton.utils import gpu_capabilities_older_than_70
from xformers.triton.utils import gpu_capabilities_older_than_80

_is_blocksparse_available = (
    _is_triton_available() and not gpu_capabilities_older_than_70()
    _is_triton_available() and not gpu_capabilities_older_than_80()
)


@@ -166,9 +166,6 @@ def test_switch_blocksparse(device, data_type):
# Mask with causal flag
m_att_mask = AttentionMask.make_causal(s, s, device, dtype=a.dtype)

def kernel():
    return scaled_dot_product_attention(a, a, a, m_att_mask)

# Check that a switch to blocksparse is only triggered by causal flag
with torch.cuda.amp.autocast():
    r_custom = scaled_dot_product_attention(a, a, a, m_custom)
9 changes: 2 additions & 7 deletions tests/test_triton_blocksparse.py
@@ -12,7 +12,6 @@
from xformers.components import MultiHeadDispatch
from xformers.components.attention import build_attention
from xformers.components.attention.attention_patterns import block_sparsify_tensor
from xformers.triton.utils import get_current_cuda_device


def catch_oor(fn):
@@ -45,9 +44,9 @@ def fn_and_catch_oor(*args, **kwargs):
from triton.ops.blocksparse import softmax as blocksparse_softmax

from xformers.components.attention import BlockSparseAttention
from xformers.triton.utils import gpu_capabilities_older_than_70
from xformers.triton.utils import gpu_capabilities_older_than_80

_triton_available = not gpu_capabilities_older_than_70()
_triton_available = not gpu_capabilities_older_than_80()
_matmul_types = ["sdd", "dsd", "dds"]
except (ImportError, ModuleNotFoundError) as e:
import logging
@@ -64,10 +63,6 @@ def mask_tensor(x, mask, block, value=0):


@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
@pytest.mark.skipif(
    not _triton_available or get_current_cuda_device() == "T4",
    reason="FIXME - blocksparse matmuls are slightly off on T4s",
)
@pytest.mark.parametrize("MODE", _matmul_types)
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
6 changes: 3 additions & 3 deletions xformers/components/attention/blocksparse.py
@@ -23,10 +23,10 @@
from triton.ops.blocksparse import matmul as blocksparse_matmul # type: ignore
from triton.ops.blocksparse import softmax as blocksparse_softmax # type: ignore

from xformers.triton.utils import gpu_capabilities_older_than_70
from xformers.triton.utils import gpu_capabilities_older_than_80

# Blocksparse requires Tensor cores
if gpu_capabilities_older_than_70():
# Blocksparse requires Tensor cores, but we also disable it on V100 because of Triton issues
if gpu_capabilities_older_than_80():
    logger.warning(
        "Blocksparse is not available: the current GPU does not expose Tensor cores"
    )
4 changes: 2 additions & 2 deletions xformers/components/attention/core.py
@@ -20,10 +20,10 @@

if _is_triton_available():
from xformers.triton.softmax import softmax as triton_softmax
from xformers.triton.utils import gpu_capabilities_older_than_70
from xformers.triton.utils import gpu_capabilities_older_than_80

_is_blocksparse_available = (
    _is_triton_available() and not gpu_capabilities_older_than_70()
    _is_triton_available() and not gpu_capabilities_older_than_80()
)

if _is_blocksparse_available:
41 changes: 19 additions & 22 deletions xformers/triton/utils.py
@@ -4,37 +4,34 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional
from typing import Optional, Tuple

import torch

logger = logging.getLogger("xformers")


_gpu_is_old: Optional[bool] = None
_oldest_gpu: Optional[Tuple[int, int]] = None


def gpu_capabilities_older_than_70() -> bool:
"""Return True if the GPU's compute capability is older than SM70."""
global _gpu_is_old
if _gpu_is_old is None:
for i in range(torch.cuda.device_count()):
major, _ = torch.cuda.get_device_capability(f"cuda:{i}")
if major < 7:
_gpu_is_old = True
if _gpu_is_old is None:
_gpu_is_old = False
return _gpu_is_old

def _get_oldest_gpu() -> Tuple[int, int]:
    global _oldest_gpu
    if _oldest_gpu is None:
        _oldest_gpu = min(
            (
                torch.cuda.get_device_capability(f"cuda:{i}")
                for i in range(torch.cuda.device_count())
            ),
            default=(0, 0),
        )
    return _oldest_gpu

SUPPORTED_CUDA_DEVICES = ["V100", "A100", "T4"]

def gpu_capabilities_older_than_70() -> bool:
"""Return True if the GPU's compute capability is older than SM70."""
return _get_oldest_gpu() < (7, 0)

def get_current_cuda_device():
    current_device = str(torch.cuda.get_device_properties(torch.cuda.current_device()))
    for device_str in SUPPORTED_CUDA_DEVICES:
        if current_device.find(device_str) > 0:
            return device_str

    logger.warning("Unsupported device, Triton code generation may fail")
    return "P100"  # default to an old GPU
def gpu_capabilities_older_than_80() -> bool:
"""Return True if the GPU's compute capability is older than SM80."""
return _get_oldest_gpu() < (8, 0)
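
Illustrative usage of the refreshed helper at a call site (mirrors the gating pattern in the files above; the `blocksparse_enabled` flag is just for illustration):

    from xformers.triton.utils import gpu_capabilities_older_than_80

    # Skip Triton blocksparse paths on V100 (SM70) and older GPUs.
    blocksparse_enabled = not gpu_capabilities_older_than_80()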
