ggml : add ALiBi support for ggml_soft_max_ext #5488
Conversation
Isn't this softmax part of flash attention? Are we giving up on that?
Regarding the FA branch - I've hit a roadblock with the CUDA implementation, and don't want to merge just Metal support. But once this is resolved, the
accumulates too much error ggml-ci
We compute the slopes in the kernel ggml-ci
Is there significant benefit from the F16 CUDA soft_max implementation? (Lines 9132 to 9148 in a0f8a93)
I think it is currently broken. Also, while on this topic, I see we have custom template specializations for a set of sizes in CUDA soft_max (lines 7673 to 7706 in a0f8a93).
How much do we gain from these? I assume there is some benefit for prompt processing, because for TG we almost never hit these branches. cc @JohannesGaessler Do you have some numbers handy?
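For context on what those specializations buy, here is a toy sketch of the same "0 means runtime size" template pattern (a hypothetical kernel, not the actual ggml-cuda code): when the column count is a compile-time constant, the per-thread loop bound is known and the loop can be fully unrolled.

#define WARP_SIZE 32

// Toy row-sum kernel: ncols_template == 0 means "use the runtime value",
// any other value bakes the column count into the generated code.
template <int ncols_template>
__global__ void row_sum_sketch(const float * x, float * dst, const int ncols_par) {
    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
    const int row   = blockIdx.x;

    float sum = 0.0f;
#pragma unroll
    for (int c = threadIdx.x; c < ncols; c += WARP_SIZE) {
        sum += x[row*ncols + c];  // fully unrollable when ncols is a compile-time constant
    }

    // warp-level reduction (assumes the block is a single warp of 32 threads)
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }
    if (threadIdx.x == 0) {
        dst[row] = sum;
    }
}

A caller would instantiate e.g. row_sum_sketch<256> for the common sizes and fall back to row_sum_sketch<0> otherwise, which is essentially what the switch over ncols_x in soft_max_f32_cuda does.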
Do consider that if you ever want to implement sampling on the GPU softmax will also be needed.
As I said before, my intention with FA was to let Steward/FSSRepo handle it (with me advising) since long term I would like to have more devs. But recently there hasn't really been progress from his side so maybe I'll just do it myself in the next few weeks.
A quick low sample size test where I enabled FP16 soft max and nothing else via a simple patch:
The FP16 kernel itself is ~1.4x faster than the FP32 kernel.
This is an issue with how the tests are set up. The worst precision I get for FP16 softmax is NMSE = 4e-6. This is worse than the threshold of 1e-7 but still completely negligible. In the first place, NMSE does not accurately reflect the effects of numerical precision for softmax. NMSE is dominated by the effects on values that are ~0 after softmax, but those values are going to be ignored anyways, and even large relative differences like 1e-9 <-> 1e-10 on those values are completely meaningless. Also notice how NMSE increases with the input size. I didn't analyze the code in detail, but what I think is happening is that the number of values that are ~0 after softmax increases, and because the numerical stability for those values is worse, NMSE increases. But this does not make a practical difference.
From the templating in particular I think the difference was ~10% in terms of the kernel runtime.
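For concreteness, the metric being discussed is (roughly) the squared error normalized by the squared magnitude of the reference output; a small illustrative helper, not the actual test-backend-ops implementation:

#include <cmath>
#include <cstddef>

// Illustrative only: normalized mean squared error between a reference (FP32)
// softmax output and an approximate (e.g. FP16-accumulated) one.
static double nmse(const float * ref, const float * approx, size_t n) {
    double err = 0.0, norm = 0.0;
    for (size_t i = 0; i < n; ++i) {
        const double d = (double) ref[i] - (double) approx[i];
        err  += d*d;
        norm += (double) ref[i] * (double) ref[i];
    }
    return err / norm;
}

With a metric of this shape, the many values that are ~0 after softmax contribute to the error sum but almost nothing to the normalization term, which is one way to read the observation above that NMSE grows with the input size without the result becoming meaningfully worse.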
With the following patch:
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 9af8517d..92adf4fb 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2091,7 +2091,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int n = 0; n < 10; ++n) {
int64_t ne0 = dist_ne0(rng);
int64_t ne1 = dist_ne1(rng);
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, 0.1f));
}
exponent <<= 1;
Some of the tests fail with the 0.1f scale, while they don't fail with the default scale factor. I'm thinking we should probably drop the F16 variant of soft_max for now - it does not seem really worth it to me. Long-term, when
Would be fine with me.
This seems to cause a significant performance regression in alibi models:
This can probably be fixed, but I don't really understand what the goal of this change is. It doesn't seem to improve performance, it will require changes in every backend, it will become obsolete after the introduction of flash attention, and in my opinion the code is not really any simpler - it replaces a very simple
I suppose I did something wrong in the CUDA implementation. Here are the results with Metal:
It has to improve the performance, since we avoid extra reads/writes of the entire KV cache. I don't think Flash Attention will be implemented for every backend. At the very least, on the CPU, FA is slower, so most likely we will fall back to
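A rough host-side sketch of why the fusion saves memory traffic (plain C++, not the ggml kernels; slope*c stands in for the per-position ALiBi bias): the unfused path streams the whole KQ matrix through memory an extra time just to add the bias, while the fused soft_max applies it while the data is already being read.

#include <algorithm>
#include <cmath>
#include <vector>

// Unfused: ggml_alibi-style pass reads and writes the full matrix, then soft_max reads it again.
static void alibi_then_softmax(std::vector<float> & kq, int n_rows, int n_cols, float slope) {
    for (int r = 0; r < n_rows; ++r)
        for (int c = 0; c < n_cols; ++c)
            kq[r*n_cols + c] += slope*c;                 // extra full read + write of KQ

    for (int r = 0; r < n_rows; ++r) {
        float maxv = -INFINITY, sum = 0.0f;
        for (int c = 0; c < n_cols; ++c) maxv = std::max(maxv, kq[r*n_cols + c]);
        for (int c = 0; c < n_cols; ++c) { kq[r*n_cols + c] = expf(kq[r*n_cols + c] - maxv); sum += kq[r*n_cols + c]; }
        for (int c = 0; c < n_cols; ++c) kq[r*n_cols + c] /= sum;
    }
}

// Fused: the bias is added on the fly while each row is read for the softmax.
static void fused_softmax_alibi(std::vector<float> & kq, int n_rows, int n_cols, float slope) {
    for (int r = 0; r < n_rows; ++r) {
        float maxv = -INFINITY, sum = 0.0f;
        for (int c = 0; c < n_cols; ++c) maxv = std::max(maxv, kq[r*n_cols + c] + slope*c);
        for (int c = 0; c < n_cols; ++c) { kq[r*n_cols + c] = expf(kq[r*n_cols + c] + slope*c - maxv); sum += kq[r*n_cols + c]; }
        for (int c = 0; c < n_cols; ++c) kq[r*n_cols + c] /= sum;
    }
}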
I am not observing a significant performance difference for LLaMA:
GPU | Model | Test | t/s master | t/s gg/refactor-alibi | Speedup |
---|---|---|---|---|---|
RTX 3090 | llama 7B Q4_0 | pp512 | 4105.91 | 4088.42 | 1.00 |
RTX 3090 | llama 7B Q4_0 | tg128 | 136.69 | 136.11 | 1.00 |
There are currently no models with ALiBi where I personally care about the performance, but presumably someone else will complain if there is a regression. I don't have an ALiBi model ready for profiling, but just by looking at the code the biggest difference seems to be that m0 and m1 are being calculated thousands of times on the GPU instead of once on the CPU.
More generally, the kernel would probably be a little faster if there was a template parameter to tell the compiler whether the mask and/or ALiBi is being used. Also the shared memory limit for the kernel can be manually raised to allow for better performance at large input sizes (I didn't know this at the time).
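For reference, raising that limit is a host-side call; a minimal sketch with a placeholder kernel (the names and launch configuration here are made up, not ggml-cuda code):

#include <cuda_runtime.h>

// Hypothetical standalone kernel, only to show the host-side call; by default a
// dynamic shared memory request above ~48 KiB per block fails to launch, and
// cudaFuncSetAttribute opts the kernel into the larger per-SM limit (Volta and newer).
__global__ void soft_max_smem_sketch(const float * x, float * dst, const int ncols) {
    extern __shared__ float buf[];
    // ... kernel body elided ...
}

static void launch_soft_max_smem_sketch(const float * x, float * dst,
                                        const int nrows, const int ncols, cudaStream_t stream) {
    const size_t shmem = (size_t) ncols * sizeof(float);
    // Without this call, launches requesting more than the default limit simply fail.
    cudaFuncSetAttribute(soft_max_smem_sketch, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) shmem);
    soft_max_smem_sketch<<<nrows, 1024, shmem, stream>>>(x, dst, ncols);
}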
llama.cpp
Outdated
// temporary branch until we figure out how to handle ggml_alibi through ggml_add
#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_SYCL)
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, Kompute, and SYCL")
#pragma message(" Falling back to ggml_alibi(). Will become and error in Mar 2024") |
#pragma message(" Falling back to ggml_alibi(). Will become and error in Mar 2024") | |
#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") |
I see a significant speedup on the M3 Max as well:
I can see that removing a read/write of the entire KV cache can lead to a large speedup; in that case it would be worth it for the performance improvement alone. We should probably figure out a way to make the implementation of fused ops optional in the backends before adding more, though.
The issue with CUDA seems to be the double-precision pow calls:
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 7285d6de..5b6acf43 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5974,14 +5974,14 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
// ALiBi
if (max_bias > 0.0f) {
const uint32_t n_head_kv = gridDim.x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2f((float) n_head_kv));
- const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int h = rowx/nrows_y; // head index
- slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
+ slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
}
extern __shared__ float data_soft_max_f32[];
ggml-cuda.cu
Outdated
const float m0 = pow(2.0f, -(max_bias       ) / n_head_log2);
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
Here you are using pow, not powf. The consequence is that you are calculating the exponentiation at double precision (and then downcasting to float anyways). Notably, the only NVIDIA GPUs with usable FP64 performance are high-end datacenter ones like the A100. On consumer GPUs like slaren's RTX 3090 Ti, FP64 arithmetic is 64 times slower than FP32 arithmetic (presumably for market segmentation).
ggml-cuda.cu
Outdated
dst[idst + WARP_SIZE] = result.y;
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1); |
Same as above.
I don't think there is any point in fusing more ops for now, because the major benefit from fusing in common use cases comes from the ops that scale with the context (i.e. the attention). So there is no need to put high priority on the optional fusing support, though we will do it eventually.
Ah, will fix it right away.
Edit: this is the performance after the fix on RTX 2060.
Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
build: 9350a1c (2155) (master)
build: 113e0d5 (2161) (PR)
It may also be worth pre-computing m0, m1 and n_head_log2 outside the kernel:
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 79487391..77da1e66 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5957,7 +5957,8 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
}
template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias) {
+static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias,
+const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@@ -5973,15 +5974,11 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
// ALiBi
if (max_bias > 0.0f) {
- const uint32_t n_head_kv = gridDim.x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
-
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
const int h = rowx/nrows_y; // head index
- slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
+ float base = h < n_head_log2 ? m0 : m1;
+ int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ slope = powf(base, exp);
}
extern __shared__ float data_soft_max_f32[];
@@ -7482,39 +7479,46 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+ const uint32_t n_head_kv = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
if (shmem < g_device_caps[g_main_device].smpb) {
switch (ncols_x) {
case 32:
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 64:
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 128:
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 256:
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 512:
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 1024:
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 2048:
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 4096:
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
default:
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
}
}
I tried the same change in Metal, but thermal throttling makes it very hard to test these kinds of changes. Basically whatever I run first is faster, but overall the difference seems even smaller than with CUDA.
diff --git a/ggml-metal.m b/ggml-metal.m
index 1c67334d..09df3a81 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1194,6 +1194,14 @@ static bool ggml_metal_graph_compute(
const float scale = ((float *) dst->op_params)[0];
const float max_bias = ((float *) dst->op_params)[1];
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+ const uint32_t n_head_kv = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
@@ -1212,6 +1220,9 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
diff --git a/ggml-metal.metal b/ggml-metal.metal
index c5da88e1..09ebcc9e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -358,6 +358,9 @@ kernel void kernel_soft_max(
constant int64_t & ne02,
constant float & scale,
constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@@ -377,15 +380,12 @@ kernel void kernel_soft_max(
// ALiBi
if (max_bias > 0.0f) {
- const uint32_t n_head_kv = ne02;
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
-
- const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
const int64_t h = i02;
- slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
}
// parallel max
@@ -462,6 +462,9 @@ kernel void kernel_soft_max_4(
constant int64_t & ne02,
constant float & scale,
constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@@ -480,15 +483,12 @@ kernel void kernel_soft_max_4(
float slope = 0.0f;
if (max_bias > 0.0f) {
- const uint32_t n_head_kv = ne02;
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
-
- const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
const int64_t h = i02;
- slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
}
// parallel max
llama.cpp
Outdated
@@ -7501,6 +7523,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}

{
This array can be quite big with large contexts, and it is only used with ALiBi models.
Suggested change:
- {
+ if (hparams.f_max_alibi_bias != 0.0f) {
Hm, it's strange behaviour on the M3 Max - I don't observe throttling on my M1 Pro MacBook nor on my Mac Studio. Wonder what's different. The Metal change does help a little bit, more noticeably for the smaller 1B model.
test-backend-ops : replace soft_max tests ggml-ci
* ggml : avoid recomputing alibi slopes (CPU)
* llama : reuse hparams.f_max_alibi_bias in all cases
* ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal)
* ggml : handle all SRCs (do not break on first null)
* tests : do not use slope for large soft_max (accumulates too much error)
* ggml : alternative ALiBi without extra tensor (we compute the slopes in the kernel)
* cuda : add ALiBi support in ggml_soft_max_ext
* ggml : deprecate ggml_alibi
* ggml : support multi-sequence ALiBi (Metal)
* cuda : add multi-seq ALiBi + remote F16 soft_max
* ggml : update deprecation message
* ggml : fix pos ptr when no ALiBi
* cuda : fix performance (pow -> powf)
* cuda : precompute ALiBi constants
* metal : pre-compute ALiBi slopes
* llama : init kq_pos only if needed
* test-backend-ops : add null pos test to soft_max
  test-backend-ops : replace soft_max tests

Co-authored-by: slaren <[email protected]>
ref: #3470
Missing backends currently generate compile warnings. For some time, we will fall back to the old implementation of ggml_alibi, but this will be removed at some point, so support should be added in all backends when possible.
Additional changes:
- KQ_pos tensor: provides the token position for each KV cache cell
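For reference, a scalar sketch of what the extended soft_max is described as computing per row (illustrative C++, not the actual ggml implementation; the argument layout is assumed): the input is scaled, the mask is added, and when a KQ_pos tensor is present an ALiBi bias of slope * pos is added before the softmax, with the slope derived from max_bias and the head index as in the kernels above.

#include <algorithm>
#include <cmath>
#include <cstdint>

// ALiBi slope for a given head, matching the m0/m1/n_head_log2 scheme used in the kernels.
static float alibi_slope(uint32_t head, uint32_t n_head, float max_bias) {
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    return head < n_head_log2 ? powf(m0, head + 1) : powf(m1, 2*(head - n_head_log2) + 1);
}

// One row of the extended soft_max: y = softmax(scale*x + mask + slope*pos).
// mask and pos may be null, mirroring the optional SRCs of the op.
static void soft_max_ext_row(const float * x, const float * mask, const int32_t * pos,
                             float * y, int n, float scale, float slope) {
    float maxv = -INFINITY;
    for (int c = 0; c < n; ++c) {
        y[c] = scale*x[c] + (mask ? mask[c] : 0.0f) + (pos ? slope*pos[c] : 0.0f);
        maxv = std::max(maxv, y[c]);
    }
    float sum = 0.0f;
    for (int c = 0; c < n; ++c) { y[c] = expf(y[c] - maxv); sum += y[c]; }
    for (int c = 0; c < n; ++c) y[c] /= sum;
}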