fp8 backward #558

Workflow file for this run

name: AMD Perf Kernel Tests
on:
  workflow_dispatch:
  pull_request:
    branches: [main_perf]
  push:
    branches: [main_perf, rebase/main_perf]
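# Cancel any in-progress run of this workflow for the same ref when a new one starts.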
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true
jobs:
  Integration-Tests-AMD:
    runs-on: ${{ matrix.runner }}
    strategy:
      matrix:
        runner: [linux-mi300-gpu-1]
      fail-fast: false # do not cancel the remaining matrix jobs when one entry fails
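    # /dev/kfd and /dev/dri are the ROCm GPU device nodes; they must be passed through for the container to see the MI300 GPU.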
    container:
      image: rocm/pytorch:rocm6.3.2_ubuntu22.04_py3.10_pytorch_release_2.4.0
      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Show Device Info
        run: |
          rocminfo | grep gfx
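      # Remove any Triton shipped with the container image, plus its cache and build
      # directories, so the pinned version installed below is the one actually used.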
      - name: Uninstall Triton
        run: |
          pip uninstall -y triton
          rm -rf ~/.triton
          rm -rf ./triton/python/build
      - name: Install Triton
        run: |
          pip install triton==3.2.0
      - name: Show Triton version
        run: |
          pip show triton
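      # FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" enables the Triton-based AMD backend
      # when building and installing flash-attention.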
      - name: Build
        run: |
          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
          python setup.py install
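      # FLASH_ATTENTION_TRITON_AMD_REF=1 runs the test suite against the PyTorch
      # reference implementation rather than the Triton kernels.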
      - name: Flash Attention Tests using PyTorch reference implementation
        run: |
          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
          FLASH_ATTENTION_TRITON_AMD_REF=1 pytest tests/test_flash_attn_triton_amd.py
      - name: Flash Attention Tests
        run: |
          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
          pytest tests/test_flash_attn_triton_amd.py
      - name: Install dependencies for bench and misc
        run: |
          pip install matplotlib pandas pytest
      # FIXME: run the full suite
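      # Only the fp8 prefill tests are exercised for now; the full suite referenced by the
      # FIXME above would presumably be run as `pytest flash_attn/flash_attn_triton_amd/test.py`
      # (exact invocation assumed, path taken from the commands below).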
      - name: AMD Internal Tests
        run: |
          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
          pytest flash_attn/flash_attn_triton_amd/test.py::test_op_prefill_fp8 flash_attn/flash_attn_triton_amd/test.py::test_op_prefill_varlen_fp8
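      # FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 presumably turns on Triton autotuning for the benchmarked kernels.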
      - name: AMD Bench
        run: |
          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
          FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 python flash_attn/flash_attn_triton_amd/bench.py
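# The workflow_dispatch trigger above also allows manual runs, e.g. via the GitHub CLI
# (illustrative command, not part of this workflow):
#   gh workflow run "AMD Perf Kernel Tests" --ref main_perf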