
ggml : add ALiBi support for ggml_soft_max_ext #5488

Merged: 18 commits merged into master from gg/refactor-alibi on Feb 17, 2024

Conversation

@ggerganov (Owner) commented Feb 14, 2024:

ref: #3470

Backends that are missing support currently generate compile warnings. For some time we will fall back to the old implementation of ggml_alibi, but this fallback will be removed at some point, so support should be added in all backends when possible.

Additional changes:

  • Support multi-sequence ALiBi via a KQ_pos tensor, which provides the token position for each KV cache cell (see the sketch after this list)
  • Remove F16 CUDA soft_max kernels
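
As an illustration of what this positions tensor buys us, here is a minimal C sketch of how a per-head ALiBi bias could be derived from it; the function and array names are hypothetical, only the slope-times-position term reflects the actual change:

```c
// Hypothetical sketch (not llama.cpp code): given the token position stored in
// each KV cache cell and the ALiBi slope of one attention head, compute the
// bias added to that head's attention scores before the softmax.
void alibi_bias_row(const float * kq_pos,  // [n_kv] token position of each KV cell
                    float         slope,   // ALiBi slope of this head
                    float       * bias,    // [n_kv] output bias for this head
                    int           n_kv) {
    for (int j = 0; j < n_kv; ++j) {
        // uses the stored position, not the cell index j, so multiple
        // sequences in the cache get correct biases; any constant per-row
        // offset cancels in the softmax anyway
        bias[j] = slope * kq_pos[j];
    }
}
```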

@slaren (Collaborator) commented Feb 14, 2024:

Isn't this softmax part of flash attention? Are we giving up on that?

@ggerganov (Owner, Author):

ggml_soft_max_ext is needed even with ggml_flash_attention_ext because not all backends will implement FA. This way they can fall back to ggml_soft_max_ext, which fuses scaling, masking and the ALiBi slope.
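
For reference, here is a minimal C sketch of the per-row computation the fused op performs, written against plain arrays rather than the actual ggml kernels (the real implementations are vectorized and parallel):

```c
#include <math.h>

// One row of the fused soft_max_ext: softmax(scale*x + mask + slope*pos),
// i.e. scaling, masking and the ALiBi bias folded into a single pass.
void soft_max_ext_row(const float * x,     // [n] raw attention scores (K*Q)
                      const float * mask,  // [n] additive mask, may be NULL
                      const float * pos,   // [n] KV cell positions, may be NULL
                      float         scale, // KQ scale
                      float         slope, // ALiBi slope of this head (0 disables)
                      float       * dst,   // [n] output probabilities
                      int           n) {
    float max_val = -INFINITY;
    for (int i = 0; i < n; ++i) {
        float v = scale*x[i];
        if (mask) v += mask[i];
        if (pos)  v += slope*pos[i];
        dst[i] = v;
        if (v > max_val) max_val = v;
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        dst[i] = expf(dst[i] - max_val); // subtract the max for numerical stability
        sum += dst[i];
    }
    for (int i = 0; i < n; ++i) {
        dst[i] /= sum;
    }
}
```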

Regarding the FA branch - I've hit a roadblock with the CUDA implementation and don't want to merge just the Metal support. But once this is resolved, ggml_flash_attention_ext will be extended with similar ALiBi support.

@ggerganov (Owner, Author) commented Feb 14, 2024:

Is there significant benefit from the F16 CUDA soft_max implementation?

llama.cpp/ggml-cuda.cu, lines 9132 to 9148 at a0f8a93:

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
#ifdef GGML_CUDA_F16
const bool use_f16_soft_max = true;
#else
const bool use_f16_soft_max = false;
#endif // GGML_CUDA_F16
#else
const bool use_f16_soft_max = false;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
if (use_f16_soft_max) {
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
} else {
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
}

I think it is currently broken on master - enable the mask in the test-backend-ops and build with LLAMA_CUDA_F16=1.

Also, while on this topic, I see we have custom template specializations for a set of sizes in the CUDA soft_max:

llama.cpp/ggml-cuda.cu, lines 7673 to 7706 at a0f8a93:

if (shmem < g_device_caps[g_main_device].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
}

How much do we gain from these? I assume there is some benefit for prompt processing, because for TG we almost never hit these branches. cc @JohannesGaessler Do you have some numbers handy?

@JohannesGaessler (Collaborator):

> ggml_soft_max_ext is needed even with ggml_flash_attention_ext because not all backends will implement FA. This way they can fall back to ggml_soft_max_ext, which fuses scaling, masking and the ALiBi slope.

Do consider that if you ever want to implement sampling on the GPU softmax will also be needed.

> Regarding the FA branch - I've hit a roadblock with the CUDA implementation and don't want to merge just the Metal support. But once this is resolved, ggml_flash_attention_ext will be extended with similar ALiBi support.

As I said before, my intention with FA was to let Steward/FSSRepo handle it (with me advising) since long term I would like to have more devs. But recently there hasn't really been progress from his side so maybe I'll just do it myself in the next few weeks.

> Is there significant benefit from the F16 CUDA soft_max implementation?

A quick low sample size test where I enabled FP16 soft max and nothing else via a simple patch:

| GPU | Model | Batch size | Test | t/s master | t/s patch | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 | llama 7B Q4_0 | 1 | pp4096 | 110.83 | 111.15 | 1.00 |
| RTX 3090 | llama 7B Q4_0 | 2 | pp4096 | 209.82 | 210.48 | 1.00 |
| RTX 3090 | llama 7B Q4_0 | 4 | pp4096 | 343.02 | 346.16 | 1.01 |
| RTX 3090 | llama 7B Q4_0 | 8 | pp4096 | 454.09 | 457.24 | 1.01 |
| RTX 3090 | llama 7B Q4_0 | 16 | pp4096 | 462.47 | 465.90 | 1.01 |
| RTX 3090 | llama 7B Q4_0 | 32 | pp4096 | 584.72 | 590.74 | 1.01 |
| RTX 3090 | llama 7B Q4_0 | 64 | pp4096 | 1182.61 | 1192.27 | 1.01 |
| RTX 3090 | llama 7B Q4_0 | 128 | pp4096 | 1906.49 | 1914.28 | 1.00 |
| RTX 3090 | llama 7B Q4_0 | 256 | pp4096 | 2619.82 | 2670.42 | 1.02 |
| RTX 3090 | llama 7B Q4_0 | 512 | pp4096 | 3017.57 | 3073.57 | 1.02 |
| RTX 3090 | llama 7B Q4_0 | 1024 | pp4096 | 3334.44 | 3383.46 | 1.01 |
| RTX 3090 | llama 7B Q4_0 | 2048 | pp4096 | 3303.83 | 3317.45 | 1.00 |
| RTX 3090 | llama 7B Q4_0 | 4096 | pp4096 | 3023.11 | 3037.88 | 1.00 |

The FP16 kernel itself is ~1.4x faster than the FP32 kernel.

> I think it is currently broken on master - enable the mask in the test-backend-ops and build with LLAMA_CUDA_F16=1.

This is an issue with how the tests are set up. The worst precision I get for FP16 softmax is NMSE = 4e-6. This is worse than the threshold of 1e-7 but still completely negligible. In the first place, NMSE does not accurately reflect the effects of numerical precision for softmax: it is dominated by the effects on values that are ~0 after softmax, but those values are going to be ignored anyway, and even large relative differences like 1e-9 <-> 1e-10 on those values are completely meaningless.

Also notice how NMSE increases with the input size. I didn't analyze the code in detail, but what I think is happening is that the number of values that are ~0 after softmax increases, and because the numerical precision for those values is worse, NMSE increases. But this does not make a practical difference.
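
To make this concrete, this is roughly how I understand the NMSE metric in test-backend-ops (sketch only; the exact normalization in the harness may differ):

```c
#include <stddef.h>

// Normalized mean squared error between the backend result a and the
// reference b: NMSE = sum((a[i]-b[i])^2) / sum(b[i]^2).
// For softmax outputs the denominator is dominated by the few large
// probabilities, while the numerator accumulates noise from the many ~0
// entries, so NMSE tends to grow with the row size even when the result is
// practically identical.
double nmse(const float * a, const float * b, size_t n) {
    double err = 0.0, ref = 0.0;
    for (size_t i = 0; i < n; ++i) {
        const double d = (double) a[i] - (double) b[i];
        err += d*d;
        ref += (double) b[i] * (double) b[i];
    }
    return ref > 0.0 ? err/ref : 0.0;
}
```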

@JohannesGaessler (Collaborator):

> How much do we gain from these?

From the templating in particular I think the difference was ~10% in terms of the kernel runtime.

@ggerganov (Owner, Author):

On master if I make the following patch:

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 9af8517d..92adf4fb 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2091,7 +2091,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         for (int n = 0; n < 10; ++n) {
             int64_t ne0 = dist_ne0(rng);
             int64_t ne1 = dist_ne1(rng);
-            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, 0.1f));
         }
 
         exponent <<= 1;

Some of the tests fail with inf NMSE:

  SOFT_MAX(type=f32,ne=[55701,6,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = 0.000001874 > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[57757,27,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = 0.000003197 > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[50919,30,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = 0.000002142 > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[43571,36,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = 0.000000810 > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[37455,9,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = 0.000001378 > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[48671,44,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = 0.000001405 > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[118835,28,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[113968,16,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[74405,27,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[85901,30,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[99490,22,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[82500,19,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[91293,23,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[96725,20,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[83840,4,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL
  SOFT_MAX(type=f32,ne=[89768,13,1,1],scale=0.100000,mask=0): [SOFT_MAX] NMSE = inf > 0.000000100 FAIL

They don't fail with the default scale factor of 1.0f (i.e. the NMSE is still larger than the threshold, but it's not inf).

I'm thinking we should probably drop the F16 variant of soft_max for now - it does not really seem worth it to me. Long-term, when ggml supports F16 activations, we can revisit.

@JohannesGaessler (Collaborator):

> I'm thinking we should probably drop the F16 variant of soft_max for now - it does not really seem worth it to me. Long-term, when ggml supports F16 activations, we can revisit.

Would be fine with me.

ggerganov marked this pull request as ready for review on February 15, 2024 12:41
@slaren (Collaborator) commented Feb 15, 2024:

This seems to cause a significant performance regression in alibi models:

| GPU | Model | Test | t/s master | t/s gg/refactor-alibi | Speedup |
| --- | --- | --- | --- | --- | --- |
| RTX 3090 Ti | refact 1B Q4_0 | pp512 | 9836.03 | 4157.76 | 0.42 |
| RTX 3090 Ti | refact 1B Q4_0 | tg128 | 211.99 | 219.80 | 1.04 |

This can probably be fixed, but I don't really understand what the goal of this change is. It doesn't seem to improve performance, it will require changes in every backend, it will become obsolete after the introduction of flash attention, and in my opinion the code is not really any simpler: it replaces a very simple ggml_alibi operation with a much more complex fused operation.

@ggerganov (Owner, Author):

I suppose I did something wrong in the CUDA implementation. Here are the results with Metal:

| model | params | backend | ngl | test | master t/s | PR t/s | speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| refact 1B F16 | 1.59 B | Metal | 99 | pp 512 | 3564.96 ± 57.21 | 4873.12 ± 110.9 | 1.37 |
| refact 1B F16 | 1.59 B | Metal | 99 | tg 128 | 108.03 ± 0.03 | 113.72 ± 0.06 | 1.05 |
| mpt 7B F16 (guessed) | 6.86 B | Metal | 99 | pp 512 | 1307.48 ± 9.86 | 1454.83 ± 2.39 | 1.11 |
| mpt 7B F16 (guessed) | 6.86 B | Metal | 99 | tg 128 | 40.10 ± 0.03 | 40.84 ± 0.02 | 1.02 |

The ggml_alibi() on master still needs a similar update across all backends because it currently assumes the KV cache consists of a single sequence (i.e. the token position is mapped to the index i of the KV cell). Basically, all ALiBi models on master work only for single-sequence generation.

It has to improve the performance since we avoid extra read/writes of the entire KV cache.

I don't think Flash Attention will be implemented for every backend. At the very least, on the CPU, FA is slower, so most likely we will fall back to ggml_soft_max_ext in that case - and likely for some of the other backends too.
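
As a rough, illustrative estimate (the shapes below are assumptions, not measurements): every extra elementwise pass over the intermediate KQ scores tensor, whose size scales with the KV cache, costs one additional read and write of

$$
2 \cdot 4\,\mathrm{bytes} \cdot n_{\mathrm{head}} \cdot n_{\mathrm{tokens}} \cdot n_{\mathrm{kv}}
$$

per layer, e.g. roughly 0.5 GiB per layer for $n_{\mathrm{head}} = 32$, $n_{\mathrm{tokens}} = 512$, $n_{\mathrm{kv}} = 4096$. This is the traffic that fusing into ggml_soft_max_ext avoids.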

@JohannesGaessler (Collaborator) left a review comment:

I am not observing a significant performance difference for LLaMA:

| GPU | Model | Test | t/s master | t/s gg/refactor-alibi | Speedup |
| --- | --- | --- | --- | --- | --- |
| RTX 3090 | llama 7B Q4_0 | pp512 | 4105.91 | 4088.42 | 1.00 |
| RTX 3090 | llama 7B Q4_0 | tg128 | 136.69 | 136.11 | 1.00 |

There are currently no ALiBi models whose performance I personally care about, but presumably someone else will complain if there is a regression. I don't have an ALiBi model ready for profiling, but just by looking at the code, the biggest difference seems to be that m0 and m1 are being calculated thousands of times on the GPU instead of once on the CPU.

More generally, the kernel would probably be a little faster if there was a template parameter to tell the compiler whether the mask and/or ALiBi is being used. Also the shared memory limit for the kernel can be manually raised to allow for better performance at large input sizes (I didn't know this at the time).

Review comment on llama.cpp (outdated):
// temporary branch until we figure out how to handle ggml_alibi through ggml_add
#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_SYCL)
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, Kompute, and SYCL")
#pragma message(" Falling back to ggml_alibi(). Will become and error in Mar 2024")
Collaborator review comment. Suggested change:

-#pragma message(" Falling back to ggml_alibi(). Will become and error in Mar 2024")
+#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024")

@slaren (Collaborator) commented Feb 15, 2024:

I also see a significant speedup on the M3 Max:

| Model | Model Size [GiB] | Num. of Parameters | Test | t/s master | t/s gg/refactor-alibi | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| refact 1B Q4_0 | 0.86 | 1585842176 | pp512 | 2143.80 | 2778.23 | 1.30 |
| refact 1B Q4_0 | 0.86 | 1585842176 | tg128 | 141.62 | 148.71 | 1.05 |

I can see that removing a read/write of the entire KV cache can lead to a large speedup; in that case it would be worth it for the performance improvement alone. We should probably figure out a way to make the implementation of fused ops optional in the backends before adding more, though.
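
A possible shape for that, purely as a sketch: the names and flow below are illustrative rather than actual llama.cpp code, ggml_backend_supports_op does exist, and the ggml_soft_max_ext call assumes the post-PR signature (a, mask, pos, scale, max_bias):

```c
// Sketch: build the fused soft_max_ext only when the target backend reports
// support for it, otherwise fall back to an unfused sequence of ops.
static struct ggml_tensor * build_soft_max(
        struct ggml_context * ctx,
        ggml_backend_t        backend,
        struct ggml_tensor  * kq,
        struct ggml_tensor  * kq_mask,  // may be NULL
        struct ggml_tensor  * kq_pos,   // per-cell positions for ALiBi, may be NULL
        float scale, float max_bias) {
    struct ggml_tensor * fused = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, scale, max_bias);
    if (ggml_backend_supports_op(backend, fused)) {
        return fused;
    }
    // unfused fallback: each op below is one more full pass over the KQ tensor
    struct ggml_tensor * cur = ggml_scale(ctx, kq, scale);
    if (kq_mask) {
        cur = ggml_add(ctx, cur, kq_mask);
    }
    // the ALiBi bias would be added here via the old ggml_alibi() path
    return ggml_soft_max(ctx, cur);
}
```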

@slaren (Collaborator) commented Feb 15, 2024:

The issue with CUDA seems to be the double precision pow:

| GPU | Model | Test | t/s master | t/s gg/refactor-alibi | Speedup |
| --- | --- | --- | --- | --- | --- |
| RTX 3090 Ti | refact 1B Q4_0 | pp512 | 9019.70 | 11988.70 | 1.33 |
| RTX 3090 Ti | refact 1B Q4_0 | tg128 | 211.86 | 224.65 | 1.06 |
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 7285d6de..5b6acf43 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5974,14 +5974,14 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     // ALiBi
     if (max_bias > 0.0f) {
         const uint32_t n_head_kv   = gridDim.x/nrows_y;
-        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
+        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2f((float) n_head_kv));

-        const float m0 = pow(2.0f, -(max_bias       ) / n_head_log2);
-        const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
+        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

         const int h = rowx/nrows_y; // head index

-        slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
+        slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
     }

     extern __shared__ float data_soft_max_f32[];

Review comment on ggml-cuda.cu (outdated), lines 5979 to 5980:
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
Collaborator review comment:
Here you are using pow, not powf. The consequence is that you are calculating the exponentiation at double precision (and then downcasting to float anyway). Notably, the only NVIDIA GPUs with usable FP64 performance are high-end datacenter ones like the A100. On consumer GPUs like slaren's RTX 3090 Ti, FP64 arithmetic is 64 times slower than FP32 arithmetic (presumably for market segmentation).

Review comment on ggml-cuda.cu (outdated):

dst[idst + WARP_SIZE] = result.y;
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
Collaborator review comment:
Same as above.

@ggerganov (Owner, Author) commented Feb 15, 2024:

> We should probably figure out a way to make the implementation of fused ops optional in the backends before adding more, though.

I don't think there is any point in fusing more ops for now because the major benefit from fusing in common use cases is in the ops that scale with the context (i.e. the attention) and ggml_soft_max_ext is basically everything we can fuse there (short of FA).

So there is no need to put a high priority on the optional fusing support, though we will do it eventually.

> The issue with CUDA seems to be the double precision pow:

Ah, will fix it right away

Edit: this is the performance after the fix on an RTX 2060:

Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| refact 1B F16 | 2.95 GiB | 1.59 B | CUDA | 99 | pp 512 | 4066.55 ± 47.95 |
| refact 1B F16 | 2.95 GiB | 1.59 B | CUDA | 99 | tg 128 | 93.47 ± 0.10 |

build: 9350a1c (2155) (master)

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| refact 1B F16 | 2.95 GiB | 1.59 B | CUDA | 99 | pp 512 | 4974.54 ± 71.78 |
| refact 1B F16 | 2.95 GiB | 1.59 B | CUDA | 99 | tg 128 | 95.10 ± 0.13 |

build: 113e0d5 (2161) (PR)

@slaren (Collaborator) commented Feb 15, 2024:

It may also be worth pre-computing m0 and m1 to avoid computing them in every thread of the kernel. The difference is small on the faster devices, but it is still free performance.

| GPU | Model | Model Size [GiB] | Test | t/s gg/refactor-alibi | t/s alibi-test | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 Ti | refact 1B F16 | 2.84 | pp512 | 13637.62 | 13988.27 | 1.03 |
| RTX 3090 Ti | refact 1B F16 | 2.84 | tg128 | 154.55 | 155.92 | 1.01 |
| RTX 3090 Ti | refact 1B Q4_0 | 0.86 | pp512 | 11966.69 | 12200.86 | 1.02 |
| RTX 3090 Ti | refact 1B Q4_0 | 0.86 | tg128 | 224.56 | 227.22 | 1.01 |

| GPU | Model | Model Size [GiB] | Test | t/s gg/refactor-alibi | t/s alibi-test | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3080 | refact 1B F16 | 2.84 | pp1024 | 7977.82 | 8300.09 | 1.04 |
| RTX 3080 | refact 1B F16 | 2.84 | tg128 | 133.89 | 134.18 | 1.00 |
| RTX 3080 | refact 1B Q4_0 | 0.86 | pp1024 | 6933.47 | 7388.41 | 1.07 |
| RTX 3080 | refact 1B Q4_0 | 0.86 | tg128 | 200.85 | 202.11 | 1.01 |
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 79487391..77da1e66 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5957,7 +5957,8 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 }

 template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias) {
+static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias,
+const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

     const int tid  = threadIdx.x;
@@ -5973,15 +5974,11 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f

     // ALiBi
     if (max_bias > 0.0f) {
-        const uint32_t n_head_kv   = gridDim.x/nrows_y;
-        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
-
-        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
         const int h = rowx/nrows_y; // head index

-        slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
+        float base = h < n_head_log2 ? m0 : m1;
+        int   exp =  h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        slope = powf(base, exp);
     }

     extern __shared__ float data_soft_max_f32[];
@@ -7482,39 +7479,46 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
     const dim3 block_nums(nrows_x, 1, 1);
     const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
     if (shmem < g_device_caps[g_main_device].smpb) {
         switch (ncols_x) {
             case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
         }
     } else {
         const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
     }
 }

@slaren (Collaborator) commented Feb 15, 2024:

I tried the same change in Metal, but thermal throttling makes it very hard to test these kinds of changes. Basically, whatever I run first is faster, but overall the difference seems even smaller than with CUDA.

| CPU | Model | Model Size [GiB] | Test | t/s alibi-test | t/s gg/refactor-alibi | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| M3 Max | refact 1B F16 | 2.84 | pp2048 | 2611.18 | 2387.61 | 0.91 |
| M3 Max | refact 1B Q4_0 | 0.86 | pp2048 | 2505.99 | 2317.44 | 0.92 |

| CPU | Model | Model Size [GiB] | Test | t/s gg/refactor-alibi | t/s alibi-test | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| M3 Max | refact 1B F16 | 2.84 | pp2048 | 2602.00 | 2369.88 | 0.91 |
| M3 Max | refact 1B Q4_0 | 0.86 | pp2048 | 2500.70 | 2285.08 | 0.91 |
diff --git a/ggml-metal.m b/ggml-metal.m
index 1c67334d..09df3a81 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1194,6 +1194,14 @@ static bool ggml_metal_graph_compute(
                         const float scale    = ((float *) dst->op_params)[0];
                         const float max_bias = ((float *) dst->op_params)[1];

+                        const int64_t nrows_x = ggml_nrows(src0);
+                        const int64_t nrows_y = src0->ne[1];
+                        const uint32_t n_head_kv   = nrows_x/nrows_y;
+                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+                        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                         if (id_src1) {
@@ -1212,6 +1220,9 @@ static bool ggml_metal_graph_compute(
                         [encoder setBytes:&ne02     length:sizeof(ne02)      atIndex:6];
                         [encoder setBytes:&scale    length:sizeof(scale)     atIndex:7];
                         [encoder setBytes:&max_bias length:sizeof(max_bias)  atIndex:8];
+                        [encoder setBytes:&m0       length:sizeof(m0)        atIndex:9];
+                        [encoder setBytes:&m1       length:sizeof(m1)        atIndex:10];
+                        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
                         [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

                         [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
diff --git a/ggml-metal.metal b/ggml-metal.metal
index c5da88e1..09ebcc9e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -358,6 +358,9 @@ kernel void kernel_soft_max(
         constant   int64_t & ne02,
         constant     float & scale,
         constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
         threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -377,15 +380,12 @@ kernel void kernel_soft_max(

     // ALiBi
     if (max_bias > 0.0f) {
-        const uint32_t n_head_kv   = ne02;
-        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
-
-        const float m0 = pow(2.0f, -(max_bias       ) / n_head_log2);
-        const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
         const int64_t h = i02;

-        slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
     }

     // parallel max
@@ -462,6 +462,9 @@ kernel void kernel_soft_max_4(
         constant   int64_t & ne02,
         constant     float & scale,
         constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
         threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -480,15 +483,12 @@ kernel void kernel_soft_max_4(
     float slope = 0.0f;

     if (max_bias > 0.0f) {
-        const uint32_t n_head_kv   = ne02;
-        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
-
-        const float m0 = pow(2.0f, -(max_bias       ) / n_head_log2);
-        const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
         const int64_t h = i02;

-        slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
     }

     // parallel max

Review comment on llama.cpp (outdated):
@@ -7501,6 +7523,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}

{
Collaborator review comment:

This array can be quite big with large contexts, and is only used with alibi models.

Suggested change
{
if (hparams.f_max_alibi_bias != 0.0f) {

@ggerganov (Owner, Author) commented Feb 16, 2024:

> thermal throttling makes it very hard to test these kinds of changes

Hm, it's strange behaviour on the M3 Max - I don't observe throttling on my M1 Pro MacBook or on my Mac Studio. I wonder what's different.

The Metal change does help a little bit, more noticeably for the smaller 1B model (-r 10):

| Model | Model Size [GiB] | Test | t/s 113e0d5 | t/s gg/refactor-alibi | Speedup |
| --- | --- | --- | --- | --- | --- |
| mpt 7B F16 (guessed) | 12.77 | pp512 | 1455.17 | 1457.46 | 1.00 |
| mpt 7B F16 (guessed) | 12.77 | tg128 | 40.91 | 40.91 | 1.00 |
| refact 1B F16 | 2.95 | pp512 | 4903.18 | 4935.32 | 1.01 |
| refact 1B F16 | 2.95 | tg128 | 113.61 | 113.79 | 1.00 |

ggerganov requested a review from slaren on February 17, 2024 16:19
Commit: test-backend-ops : replace soft_max tests (ggml-ci)
ggerganov merged commit 8f1be0d into master on Feb 17, 2024 (62 of 63 checks passed)
ggerganov deleted the gg/refactor-alibi branch on February 17, 2024 21:04
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Feb 19, 2024
* ggml : avoid recomputing alibi slopes (CPU)

* llama : reuse hparams.f_max_alibi_bias in all cases

ggml-ci

* ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal)

ggml-ci

* ggml : handle all SRCs (do not break on first null)

ggml-ci

* tests : do not use slope for large soft_max

accumulates too much error

ggml-ci

* ggml : alternative ALiBi without extra tensor

We compute the slopes in the kernel

ggml-ci

* cuda : add ALiBi support in ggml_soft_max_ext

ggml-ci

* ggml : deprecate ggml_alibi

* ggml : support multi-sequence ALiBi (Metal)

ggml-ci

* cuda : add multi-seq ALiBi + remove F16 soft_max

ggml-ci

* ggml : update deprecation message

* ggml : fix pos ptr when no ALiBi

ggml-ci

* cuda : fix performance (pow -> powf)

* cuda : precompute ALiBi constants

* metal : pre-compute ALiBi slopes

ggml-ci

* llama : init kq_pos only if needed

ggml-ci

* test-backend-ops : add null pos test to soft_max

test-backend-ops : replace soft_max tests

ggml-ci

---------

Co-authored-by: slaren <[email protected]>
airMeng mentioned this pull request on Feb 21, 2024
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024