Add exllama GPTQ CUDA kernel support #553

Closed
wants to merge 27 commits

Changes from 6 commits

Commits (27)
ee7ba48
add exllama gptq kernel
fxmarty Jul 5, 2023
c858d79
add attribution
fxmarty Jul 5, 2023
0ff8219
Merge branch 'main' into gptq-cuda-kernels
fxmarty Jul 5, 2023
2272b3a
some more cleanup
fxmarty Jul 5, 2023
620ed7d
Merge branch 'gptq-cuda-kernels' of https://github.com/fxmarty/text-g…
fxmarty Jul 5, 2023
a6e3874
try-catch to load the cuda extension, quite ugly practice tbh
fxmarty Jul 5, 2023
4462854
have a single gptq quantization type
fxmarty Jul 12, 2023
67a46b7
move exllama buffer init to the top level
fxmarty Jul 12, 2023
67d6876
cleanup
fxmarty Jul 12, 2023
f90c61a
support bits different than 4
fxmarty Jul 12, 2023
8645fd3
tests
fxmarty Jul 12, 2023
faa5b52
Merge branch 'main' into gptq-cuda-kernels
fxmarty Jul 12, 2023
38c2be5
fix test
fxmarty Jul 12, 2023
2ae65b4
fix tests
fxmarty Jul 13, 2023
0036084
support all, test llama
fxmarty Jul 13, 2023
9401e10
Merge branch 'main' into gptq-cuda-kernels
fxmarty Jul 13, 2023
74e6d6e
fix the usual merge mess
fxmarty Jul 13, 2023
edfbfdf
Merge branch 'main' into gptq-cuda-kernels
fxmarty Jul 19, 2023
6bf7090
fix per-column quantization
fxmarty Jul 19, 2023
0860394
Refactored a bit.
Narsil Jul 20, 2023
8cf7c89
Small polish.
Narsil Jul 20, 2023
7faef69
Give escape hatch to not use exllama kernels even if available.
Narsil Jul 20, 2023
900ac49
Fixing GTPQ device santacoder.
Narsil Jul 20, 2023
12191b7
Fix config.
Narsil Jul 20, 2023
c6e702f
Add kernel target.
Narsil Jul 20, 2023
3ec3add
Separate build process.
Narsil Jul 20, 2023
40be532
Update starcoder_gptq
Narsil Jul 21, 2023
3 changes: 3 additions & 0 deletions Makefile
@@ -56,3 +56,6 @@ run-bloom:

run-bloom-quantize:
	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080

clean:
	rm -rf target aml
4 changes: 4 additions & 0 deletions launcher/src/main.rs
@@ -19,6 +19,7 @@ mod env_runtime;
enum Quantization {
    Bitsandbytes,
    Gptq,
    Gptq_cuda,
}

impl std::fmt::Display for Quantization {
@@ -31,6 +32,9 @@ impl std::fmt::Display for Quantization {
            Quantization::Gptq => {
                write!(f, "gptq")
            }
            Quantization::Gptq_cuda => {
                write!(f, "gptq-cuda")
            }
        }
    }
}
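Note: this diff reflects the PR at commit 6; commit 4462854 ("have a single gptq quantization type") later folds Gptq_cuda back into a single gptq option. At this stage, assuming clap derives the kebab-case flag value from the variant name (matching the Display impl above), the exllama-backed path would be requested with:

text-generation-launcher --model-id <model-id> --quantize gptq-cuda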
71 changes: 71 additions & 0 deletions server/custom_kernels/custom_kernels/exllama/cuda_buffers.cu
@@ -0,0 +1,71 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#define _cuda_buffers_cu
#include "cuda_buffers.cuh"

CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;

CudaBuffers::CudaBuffers
(
    int _device,
    half* _temp_state,
    half* _temp_dq
) :
    device(_device),
    temp_state(_temp_state),
    temp_dq(_temp_dq)
{
    cudaSetDevice(_device);

    cudaStreamCreate(&alt_stream_1);
    cudaStreamCreate(&alt_stream_2);
    cudaStreamCreate(&alt_stream_3);
    cudaEventCreate(&alt_stream_1_done);
    cudaEventCreate(&alt_stream_2_done);
    cudaEventCreate(&alt_stream_3_done);
}

CudaBuffers::~CudaBuffers()
{
    cudaStreamDestroy(alt_stream_1);
    cudaStreamDestroy(alt_stream_2);
    cudaStreamDestroy(alt_stream_3);
    cudaEventDestroy(alt_stream_1_done);
    cudaEventDestroy(alt_stream_2_done);
    cudaEventDestroy(alt_stream_3_done);
}

CudaBuffers* get_buffers(const int device_index)
{
    return g_buffers[device_index];
}

void prepare_buffers_cuda
(
    int _device,
    half* _temp_state,
    half* _temp_dq
)
{
    CudaBuffers* buffers = new CudaBuffers
    (
        _device,
        _temp_state,
        _temp_dq
    );

    g_buffers[_device] = buffers;
}

void cleanup_buffers_cuda()
{
    for (int i = 0; i < CUDA_MAX_DEVICES; i++)
    {
        if (!g_buffers[i]) continue;
        delete g_buffers[i];
        g_buffers[i] = NULL;
    }
}
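For orientation, here is a minimal host-side sketch (not part of the PR; the helper name and element counts are illustrative) of how the scratch buffers might be allocated and registered so that later kernel launches can retrieve them:

// Illustrative only: allocate per-device fp16 scratch buffers and
// register them in g_buffers via prepare_buffers_cuda().
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "cuda_buffers.cuh"

void example_setup(int device, size_t state_elems, size_t dq_elems)
{
    half* temp_state;
    half* temp_dq;
    cudaSetDevice(device);
    cudaMalloc((void**)&temp_state, state_elems * sizeof(half));
    cudaMalloc((void**)&temp_dq, dq_elems * sizeof(half));
    prepare_buffers_cuda(device, temp_state, temp_dq);

    // Kernels targeting this device can now call get_buffers(device).
}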
52 changes: 52 additions & 0 deletions server/custom_kernels/custom_kernels/exllama/cuda_buffers.cuh
@@ -0,0 +1,52 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

const int CUDA_MAX_DEVICES = 16;

// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif

class CudaBuffers
{
public:
    int device;

    half* temp_state;   // [max_hidden_rows * intermediate_size]
    half* temp_dq;      // size of largest quant tensor * 8

    cudaStream_t alt_stream_1;
    cudaStream_t alt_stream_2;
    cudaStream_t alt_stream_3;
    cudaEvent_t alt_stream_1_done;
    cudaEvent_t alt_stream_2_done;
    cudaEvent_t alt_stream_3_done;

    CudaBuffers
    (
        int _device,
        half* _temp_state,
        half* _temp_dq
    );
    ~CudaBuffers();
};

CudaBuffers* get_buffers(const int device_index);

void prepare_buffers_cuda
(
    int _device,
    half* _temp_state,
    half* _temp_dq
);

void cleanup_buffers_cuda();

#endif
58 changes: 58 additions & 0 deletions server/custom_kernels/custom_kernels/exllama/cuda_compat.cuh
@@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    // No native 16-bit CAS exists: operate on the aligned 32-bit word
    // containing the half, then splice the updated 16 bits back in.
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        // Retry until no other thread modified the word in between.
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    // A half2 is a single 32-bit word, so a plain CAS loop suffices.
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ < 700

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

#endif
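To illustrate what the shim buys (hypothetical kernel, not in the PR): the same atomicAdd call compiles on every architecture, resolving to the CAS-based fallback above on CC < 7.0 and to the native half-precision instruction on newer GPUs.

#include <cuda_fp16.h>
#include "cuda_compat.cuh"

// Sums n fp16 values into *out; illustrative only.
__global__ void reduce_half(half* out, const half* in, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(out, in[i]);
}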
61 changes: 61 additions & 0 deletions server/custom_kernels/custom_kernels/exllama/cuda_func/column_remap.cu
@@ -0,0 +1,61 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include "column_remap.cuh"
#include "../util.cuh"

const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;

__global__ void column_remap_kernel
(
    const half* __restrict__ x,
    half* __restrict__ x_new,
    const int x_width,
    const int x_height,
    const uint32_t* x_map
)
{
    int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;

    int x_stride = x_width;
    int x_idx = x_row * x_stride + x_column;

    int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
    int x_idx_end = x_row_end * x_stride + x_column;

    int s_column = x_map[x_column];
    int s_idx = x_row * x_stride + s_column;

    while (x_idx < x_idx_end)
    {
        x_new[x_idx] = x[s_idx];
        x_idx += x_stride;
        s_idx += x_stride;
    }
}

// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
)
{
    dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);

    dim3 blocks
    (
        (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
        (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
        1
    );

    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
19 changes: 19 additions & 0 deletions server/custom_kernels/custom_kernels/exllama/cuda_func/column_remap.cuh
@@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _column_remap_cuh
#define _column_remap_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
);

#endif