From 58b30bb6157111b1a86c37f1937c66bad617e710 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Sat, 14 Sep 2024 07:54:34 +0800
Subject: [PATCH] [Hardware][intel GPU] bump up ipex version to 2.3 (#8365)

Co-authored-by: Yan Ma
---
 Dockerfile.xpu                           | 12 ++-
 requirements-xpu.txt                     |  9 ++-
 vllm/_ipex_ops.py                        | 98 +++++++-----------------
 vllm/attention/backends/ipex_attn.py     |  8 +-
 vllm/model_executor/layers/activation.py | 15 ++--
 vllm/model_executor/layers/layernorm.py  |  5 +-
 6 files changed, 60 insertions(+), 87 deletions(-)

diff --git a/Dockerfile.xpu b/Dockerfile.xpu
index 321da98cf6c89..50bbd8f7dad87 100644
--- a/Dockerfile.xpu
+++ b/Dockerfile.xpu
@@ -1,15 +1,23 @@
-FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04
+FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04
 
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
     echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
     chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
-    rm /etc/apt/sources.list.d/intel-graphics.list && \
     wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
     echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
     chmod 644 /usr/share/keyrings/intel-graphics.gpg
 
 RUN apt-get update -y \
 && apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1
+
+RUN git clone https://github.com/intel/pti-gpu && \
+    cd pti-gpu/sdk && \
+    mkdir build && \
+    cd build && \
+    cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \
+    make -j && \
+    cmake --install . --config Release --prefix "/usr/local"
+
 COPY ./ /workspace/vllm
 
 WORKDIR /workspace/vllm
diff --git a/requirements-xpu.txt b/requirements-xpu.txt
index 48d899ec70eda..f07211b48b68d 100644
--- a/requirements-xpu.txt
+++ b/requirements-xpu.txt
@@ -3,9 +3,10 @@ setuptools < 70.0.0 # IPEX's torch have some dependency. to be removed.
 
-torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl
-intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
-oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl
+torch == 2.3.1+cxx11.abi
+intel-extension-for-pytorch == 2.3.110+xpu
+oneccl_bind_pt == 2.3.100+xpu
 
-triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+triton-xpu == 3.0.0b2
+
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index 2156f6b18adb6..31fcc4c3256a8 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -27,29 +27,27 @@ def _reshape_activation_tensor(
 
     @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.silu_mul(x1, x2, out)
+        ipex.llm.functional.silu_and_mul(x, out)
 
     @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
     @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
     @staticmethod
-    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
     @staticmethod
-    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_new(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
-    # TODO add implementation of gelu_quick here
-    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    @staticmethod
+    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+        ipex.llm.functional.gelu_quick(x, out)
 
     @staticmethod
     def paged_attention_v1(
@@ -160,29 +158,10 @@ def rotary_embedding(
         cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
         is_neox: bool,
     ) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[positions.long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        rot_dim = cos_sin_cache.size(1)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim)
 
     @staticmethod
     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@@ -190,37 +169,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                  cos_sin_cache: torch.Tensor, is_neox: bool,
                                  rot_dim: int,
                                  cos_sin_cache_offsets: torch.Tensor) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-            cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[torch.add(positions,
-                                          cos_sin_cache_offsets).long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim,
+                                                     cos_sin_cache_offsets)
 
     @staticmethod
-    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-                 epsilon: float) -> None:
-        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
-        out.copy_(tmp)
+    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> torch.Tensor:
+        return ipex.llm.functional.rms_norm(input, weight, epsilon)
 
     @staticmethod
     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
@@ -246,11 +203,14 @@ def varlen_attention(
         return_softmax: bool,
         gen_: torch.Generator,
     ) -> None:
-        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
-                                             seqlen_k, max_seqlen_q,
-                                             max_seqlen_k, pdropout,
-                                             softmax_scale, zero_tensors,
-                                             is_causal, return_softmax, gen_)
+        ipex.llm.functional.varlen_attention(query.contiguous(),
+                                             key.contiguous(),
+                                             value.contiguous(), out,
+                                             seqlen_q.int(), seqlen_k.int(),
+                                             max_seqlen_q, max_seqlen_k,
+                                             pdropout, softmax_scale,
+                                             zero_tensors, is_causal,
+                                             return_softmax, gen_)
 
     @staticmethod
     def reshape_and_cache(
diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index 64d60e4e47e48..113a2788eacd3 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -49,14 +49,18 @@ def swap_blocks(
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        from vllm._ipex_ops import ipex_ops as ops
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        from vllm._ipex_ops import ipex_ops as ops
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
 
 @dataclass
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 4c14fe476ee4a..43056786d35c9 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -114,9 +114,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         from vllm._ipex_ops import ipex_ops as ops
 
-        out = torch.empty_like(x)
-        ops.gelu_new(out, x)
-        return out
+        return ops.gelu_new(x)
 
 
 class FastGELU(CustomOp):
@@ -136,9 +134,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         from vllm._ipex_ops import ipex_ops as ops
 
-        out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
-        return out
+        return ops.gelu_fast(x)
 
 
 class QuickGELU(CustomOp):
@@ -155,6 +151,13 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         ops.gelu_quick(out, x)
         return out
 
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        out = torch.empty_like(x)
+        ops.gelu_quick(out, x)
+        return out
+
     # TODO implement forward_xpu for QuickGELU
     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
 
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 7d342714f129e..1221a6518fea0 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -82,14 +82,11 @@ def forward_xpu(
                 self.variance_epsilon,
             )
             return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
+        return ops.rms_norm(
             x,
             self.weight.data,
             self.variance_epsilon,
         )
-        return out
 
     def forward_npu(
         self,