# syntax=docker/dockerfile:1-labs
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
ARG URLREF_JAX_TRITON=https://github.com/jax-ml/jax-triton.git#main
ARG SRC_PATH_JAX=/opt/jax
ARG SRC_PATH_JAX_TRITON=/opt/jax-triton
ARG SRC_PATH_TRITON=/opt/triton
ARG SRC_PATH_XLA=/opt/xla
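# URLREF_* uses the url#ref form consumed by git-clone.sh in the mealkit stage
# below; the SRC_PATH_* arguments are the in-image checkout locations.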

FROM ${BASE_IMAGE} AS base
# Triton setup.py downloads and installs CUDA binaries at specific versions
# hardcoded in the script itself:
# https://github.com/openxla/triton/blob/84f9d9de158fb866fac67970f0f5d323999d9db1/python/setup.py#L373-L393
# Tell Triton to use the CUDA binaries already installed in this container
# instead. These variables must be set both during the build stage and in the
# final container.
ENV TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
ENV TRITON_CUOBJDUMP_PATH=/usr/local/cuda/bin/cuobjdump
ENV TRITON_NVDISASM_PATH=/usr/local/cuda/bin/nvdisasm
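# Fail fast at build time if any of these binaries is missing or not executable.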
RUN [ -x "${TRITON_PTXAS_PATH}" ] && [ -x "${TRITON_CUOBJDUMP_PATH}" ] && [ -x "${TRITON_NVDISASM_PATH}" ]

###############################################################################
## Check out LLVM and Triton sources that match XLA. This uses XLA's Bazel
## configuration to get the relevant tag from the openxla/triton fork's
## llvm-head branch and apply XLA's extra patches to it. Also fetches the
## compatible LLVM sources.
###############################################################################
FROM base AS builder
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
RUN <<"EOF" bash -ex
pushd "${SRC_PATH_XLA}"
bazel --output_base=/opt/checkout fetch @triton//:BUILD
rm -rf /root/.cache
EOF
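# The fetch leaves the pinned, patched sources under the Bazel output base:
# LLVM at /opt/checkout/external/llvm-raw and the openxla/triton fork at
# /opt/checkout/external/triton; the build stages below consume these paths.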

###############################################################################
## Build LLVM
###############################################################################
RUN <<"EOF" bash -ex
mkdir /opt/llvm-build
pushd /opt/llvm-build
pip install ninja && rm -rf ~/.cache/pip
cmake -G Ninja \
  -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_C_COMPILER=clang \
  -DCMAKE_CXX_COMPILER=clang++ \
  -DLLVM_ENABLE_ASSERTIONS=ON \
  -DLLVM_ENABLE_PROJECTS="mlir;llvm" \
  -DLLVM_TARGETS_TO_BUILD="host;NVPTX" \
  /opt/checkout/external/llvm-raw/llvm
ninja
EOF
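# Optional sanity check (illustrative sketch; assumes the default ninja
# targets place llvm-config in the build tree's bin/ directory):
# RUN /opt/llvm-build/bin/llvm-config --version --targets-built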

###############################################################################
## Build Triton
###############################################################################
RUN <<"EOF" bash -ex
pushd /opt/checkout/external/triton
mkdir dist
# Make sure the wheel is labelled 3.1, not 3.0: the checkout tracks ~HEAD, and
# upstream appears to have a 3.1 release branch that has not been merged back
# to main.
sed -i -e 's|version="3\.0\.0"|version="3.1.0"|g' python/setup.py
# Do not compile with -Werror
sed -i -e 's|-Werror||g' CMakeLists.txt
# The LLVM build above does not enable these libraries
sed -i -e 's|\(LLVMAMDGPU.*\)|# \1|g' CMakeLists.txt
# Do not build tests
sed -i -e 's|^\s*add_subdirectory(unittest)|# unit tests disabled|' CMakeLists.txt
# Avoid building GPUHello.cpp, which does not compile with cl693214397 of OpenXLA Triton
sed -i -e 's|^add_subdirectory(Instrumentation)|# Instrumentation test disabled|' test/lib/CMakeLists.txt
# Avoid error due to forward declaration of Module
sed -i -e '/#include "llvm\/IR\/IRBuilder.h"/a #include "llvm/IR/Module.h"' test/lib/Instrumentation/GPUHello.cpp
# Google has patches that mess with include paths in source files
sed -i -e '/include_directories(${PROJECT_SOURCE_DIR}\/third_party)/a include_directories(${PROJECT_SOURCE_DIR}/third_party/amd/include)' CMakeLists.txt
sed -i -e '/include_directories(${PROJECT_BINARY_DIR}\/third_party)/a include_directories(${PROJECT_BINARY_DIR}/third_party/amd/include)' CMakeLists.txt
sed -i -e '/include_directories(${PROJECT_SOURCE_DIR}\/third_party)/a include_directories(${PROJECT_SOURCE_DIR}/third_party/nvidia/include)' CMakeLists.txt
sed -i -e '/include_directories(${PROJECT_BINARY_DIR}\/third_party)/a include_directories(${PROJECT_BINARY_DIR}/third_party/nvidia/include)' CMakeLists.txt
# Extra patches to Triton maintained in XLA; these were already applied to this
# working directory by the Bazel fetch above.
XLA_TRITON_PATCHES="${SRC_PATH_XLA}/third_party/triton"
# Build with clang and lld, matching Google's internal toolchain.
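# Pointing LLVM_SYSPATH (and the include/lib paths) at the LLVM built above
# stops Triton's setup.py from fetching its own prebuilt LLVM.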
LLVM_INCLUDE_DIRS=/opt/llvm-build/include \
LLVM_LIBRARY_DIR=/opt/llvm-build/lib \
LLVM_SYSPATH=/opt/llvm-build \
TRITON_BUILD_WITH_CLANG_LLD=1 \
pip wheel --verbose --wheel-dir=dist/ python/
# Clean up the wheel build directory so it doesn't end up bloating the container
rm -rf python/build
EOF
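# The built wheel is left in /opt/checkout/external/triton/dist/, which the
# mealkit stage below copies out of this build stage.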

###############################################################################
## Copy Triton source/wheel from the builder, checkout JAX-Triton
###############################################################################
FROM base AS mealkit
ARG URLREF_JAX_TRITON
ARG SRC_PATH_JAX_TRITON
ARG SRC_PATH_TRITON
# Get the triton source + wheel from the build step
COPY --from=builder /opt/checkout/external/triton ${SRC_PATH_TRITON}
RUN echo "triton @ file://$(ls ${SRC_PATH_TRITON}/dist/triton-*.whl)" >> /opt/pip-tools.d/requirements-triton.in
# Check out jax-triton
RUN <<"EOF" bash -ex
git-clone.sh ${URLREF_JAX_TRITON} ${SRC_PATH_JAX_TRITON}
echo "-e file://${SRC_PATH_JAX_TRITON}" >> /opt/pip-tools.d/requirements-triton.in
sed -i 's|"jax @ [^"]\+"|"jax"|g;s|"triton-nightly @ [^"]\+"|"triton"|g' ${SRC_PATH_JAX_TRITON}/pyproject.toml
EOF

###############################################################################
## Install accumulated packages from the base image and the previous stage
###############################################################################
FROM mealkit AS final
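# pip-finalize.sh (provided by the base image) resolves and installs the
# accumulated /opt/pip-tools.d/requirements-*.in files, including the triton
# and jax-triton entries added above.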
RUN pip-finalize.sh