diff --git a/README.md b/README.md
index 2f6e6ffeed098..866aa87b4ffc5 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++

 ### Hot topics

+- New SOTA quantized models, including pure 2-bits: https://huggingface.co/ikawrakow
 - Collecting Apple Silicon performance stats:
   - M-series: https://github.com/ggerganov/llama.cpp/discussions/4167
   - A-series: https://github.com/ggerganov/llama.cpp/discussions/4508
@@ -136,6 +137,7 @@ as the main playground for developing new features for the [ggml](https://github
 - [semperai/amica](https://github.com/semperai/amica)
 - [psugihara/FreeChat](https://github.com/psugihara/FreeChat)
 - [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal)
+- [iohub/collama](https://github.com/iohub/coLLaMA)

 ---

diff --git a/common/common.cpp b/common/common.cpp
index 6b4913a656573..4e89fe516e0a9 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -920,7 +920,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #endif
     printf("  -gan N, --grp-attn-n N\n");
     printf("                        group-attention factor (default: %d)\n", params.grp_attn_n);
-    printf("  -gat N, --grp-attn-w N\n");
+    printf("  -gaw N, --grp-attn-w N\n");
     printf("                        group-attention width (default: %.1f)\n", (double)params.grp_attn_w);
     printf("  --verbose-prompt      print prompt before generation\n");
     printf("  -dkvc, --dump-kv-cache\n");
diff --git a/examples/server/README.md b/examples/server/README.md
index 243e669912cf0..d85a14f891bc4 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -23,6 +23,7 @@ Command line options:
 - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
 - `--port`: Set the port to listen. Default: `8080`.
 - `--path`: path from which to serve static files (default examples/server/public)
+- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
 - `--embedding`: Enable embedding extraction, Default: disabled.
 - `-np N`, `--parallel N`: Set the number of slots for process requests (default: 1)
 - `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
@@ -174,35 +175,44 @@ node index.js

     `system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)

-    *Result JSON:*
+### Result JSON:

-    Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion.
+* Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion.

-    `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.

-    `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
+- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure:

-    `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`
-    - `model`: The path to the model loaded with `-m`
-    - `prompt`: The provided `prompt`
-    - `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token
-    - `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered
-    - `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided
-    - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
-    - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
-    - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
-    - `tokens_evaluated`: Number of tokens evaluated in total from the prompt
-    - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
+```
+{
+  "content": "",
+  "probs": [
+    {
+      "prob": float,
+      "tok_str": ""
+    },
+    {
+      "prob": float,
+      "tok_str": ""
+    },
+    ...
+  ]
+},
+```
+Notice that each `probs` is an array of length `n_probs`.
+
+- `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
+- `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
+- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`
+- `model`: The path to the model loaded with `-m`
+- `prompt`: The provided `prompt`
+- `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token
+- `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered
+- `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided
+- `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
+- `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
+- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
+- `tokens_evaluated`: Number of tokens evaluated in total from the prompt
+- `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)

 - **POST** `/tokenize`: Tokenize a given text.
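As a hedged aside (editorial illustration, not part of the patch above): the new `--api-key` option and the documented `completion_probabilities` field can be exercised together from the command line. A minimal sketch, assuming a locally built `./server`, an illustrative model path, and the default host/port:

```sh
# Start the server with an API key (model path is only an example).
./server -m models/llama-2-7b.Q4_K_M.gguf --api-key secret-key

# With --api-key set, requests need the matching Bearer token in the Authorization header.
# n_probs asks the server to include completion_probabilities in the response.
curl http://127.0.0.1:8080/completion \
    -H "Authorization: Bearer secret-key" \
    -H "Content-Type: application/json" \
    -d '{"prompt": "Building a website can be done in 10 simple steps:", "n_predict": 16, "n_probs": 3}'
```

Requests without the matching Bearer token are rejected once an API key is configured, which is the behavior the new option documents.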
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index e0ea890b1afd8..e26260a35bcbd 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -116,6 +116,7 @@
 #include "ggml.h"
 #include "ggml-backend-impl.h"

+#define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
 #define CC_OFFSET_AMD 1000000
@@ -556,11 +557,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};

 struct cuda_device_capabilities {
     int     cc;                 // compute capability
+    size_t  smpb;               // max. shared memory per block
     bool    vmm;                // virtual memory support
     size_t  vmm_granularity;    // granularity of virtual memory
 };

-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };

 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -593,6 +595,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }

+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) a;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -601,6 +616,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }

+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) x;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
@@ -5385,75 +5413,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }

-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
+static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
+    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    extern __shared__ half data_soft_max_f16[];
+    half  * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
+    // (shared memory) buffer to cache values between iterations:
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
+    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
+    // in that case col_smem == col_data must be enforced to avoid race conditions
+
+    half2 max_val = make_half2(-INFINITY, -INFINITY);
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int ix = rowx*ncols_data + col_data;
+        const int iy = rowy*ncols_data + col_data;
+
+        half2 val;
+        if (need_check && col_data + 0 >= ncols_data) {
+            val.x = -INFINITY;
+        } else {
+            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+        }
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            val.y = -INFINITY;
+        } else {
+            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+        }
+        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
+            vals[col_smem] = val;
+        }
+        max_val = __hmax2(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
+        }
+        __syncthreads();
+
+        max_val = __half2half2(buf_iw[lane_id]);
+        max_val = warp_reduce_max(max_val);
+    } else {
+        max_val = __half2half2(__hmax(max_val.x, max_val.y));
+    }
+
+    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
+
+        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
+            break;
+        }
+
+        const half2 val = h2exp(vals[col_smem] - max_val);
+
+        tmp += val;
+        vals[col_smem] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp.x + tmp.y;
+        }
+        __syncthreads();
+
+        tmp = __half2half2(buf_iw[lane_id]);
+        tmp = warp_reduce_sum(tmp);
+    } else {
+        tmp = __half2half2(tmp.x + tmp.y);
+    }
+
+    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int idst = rowx*ncols_data + col_data;
+        const half2 result = vals[col_smem] * inv_sum;
+
+        if (need_check && col_data + 0 >= ncols_data) {
+            return;
+        }
+        dst[idst] = result.x;
+
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            return;
+        }
+
+        dst[idst + WARP_SIZE] = result.y;
+    }
+#else
+    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

-    const int block_size = blockDim.x;
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;

-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;

     float max_val = -INFINITY;

-    for (int col = tid; col < ncols; col += block_size) {
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+
+        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        vals[col] = val;
+        max_val = max(max_val, val);
     }

     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
+            buf_iw[lane_id] = -INFINITY;
         }
         __syncthreads();

         if (lane_id == 0) {
-            buf[warp_id] = max_val;
+            buf_iw[warp_id] = max_val;
         }
         __syncthreads();

-        max_val = buf[lane_id];
+        max_val = buf_iw[lane_id];
         max_val = warp_reduce_max(max_val);
     }

-    float tmp = 0.f;
+    float tmp = 0.0f; // partial sum

-    for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
         tmp += val;
-        dst[ix] = val;
+        vals[col] = val;
     }

     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = 0.f;
+            buf_iw[lane_id] = 0.0f;
         }
         __syncthreads();

         if (lane_id == 0) {
-            buf[warp_id] = tmp;
+            buf_iw[warp_id] = tmp;
         }
         __syncthreads();

-        tmp = buf[lane_id];
+        tmp = buf_iw[lane_id];
         tmp = warp_reduce_sum(tmp);
     }

-    const float inv_tmp = 1.f / tmp;
+    const float inv_sum = 1.0f / tmp;

-    for (int col = tid; col < ncols; col += block_size) {
-        const int i = rowx*ncols + col;
-        dst[i] *= inv_tmp;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }

@@ -6752,12 +6938,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }

+static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem <= g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(half);
+        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
+}
+
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem < g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
 }

 static void im2col_f32_f16_cuda(const float* x, half* dst,
@@ -7072,6 +7336,7 @@ void ggml_init_cublas() {
 #else
             g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            g_device_caps[id].smpb = prop.sharedMemPerBlock;
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
         }
@@ -8087,7 +8352,21 @@ static void ggml_cuda_op_soft_max(
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));

-    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool use_f16_soft_max = false;
+#else
+#ifdef GGML_CUDA_F16
+    const bool use_f16_soft_max = true;
+#else
+    const bool use_f16_soft_max = false;
+#endif // GGML_CUDA_F16
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+    if (use_f16_soft_max) {
+        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    } else {
+        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    }

     (void) dst;
 }
diff --git a/ggml-quants.c b/ggml-quants.c
index fd127f2d1558a..d497e6de9ceb5 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -7250,9 +7250,9 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
     uint32_t aux32[4];
     const uint8_t * aux8 = (const uint8_t *)aux32;

-    int8x16x4_t q2u;
-    int8x16x4_t q2s;
-    int8x16x4_t q8b;
+    ggml_int8x16x4_t q2u;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;

     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
@@ -7261,7 +7261,7 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
         const int8_t * restrict q8 = y[i].qs;
         float sumf1 = 0, sumf2 = 0;
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
-            q8b = vld1q_s8_x4(q8); q8 += 64;
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
             memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
             q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
             q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
diff --git a/scripts/get-pg.sh b/scripts/get-pg.sh
new file mode 100755
index 0000000000000..d516db46cf01f
--- /dev/null
+++ b/scripts/get-pg.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+function usage {
+    echo "usage: $0"
+    exit 1
+}
+
+function has_cmd {
+    if ! [ -x "$(command -v $1)" ]; then
+        echo "error: $1 is not available" >&2
+        exit 1
+    fi
+}
+
+# check for: curl, html2text, tail, sed, fmt
+has_cmd curl
+has_cmd html2text
+has_cmd tail
+has_cmd sed
+
+if [ $# -ne 1 ]; then
+    usage
+fi
+
+n=$1
+
+# get urls
+urls="$(curl http://www.aaronsw.com/2002/feeds/pgessays.rss | grep html | sed -e "s/.*http/http/" | sed -e "s/html.*/html/" | head -n $n)"
+
+printf "urls:\n%s\n" "$urls"
+
+if [ -f pg.txt ]; then
+    rm pg.txt
+fi
+
+for url in $urls; do
+    echo "processing $url"
+
+    curl -L $url | html2text | tail -n +4 | sed -E "s/^[[:space:]]+//g" | fmt -w 80 >> pg.txt
+
+    # don't flood the server
+    sleep 1
+done
+
+echo "done. data in pg.txt"
+
+exit 0
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b79de7a7dd5cc..7a60d77431e30 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -450,7 +450,7 @@ struct test_case {
             double err = nmse(f1.data(), f2.data(), f1.size());
             if (err > ud->max_err) {
-                printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+                printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
                 //for (int i = 0; i < (int) f1.size(); i++) {
                 //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
                 //}
@@ -1449,6 +1449,7 @@ struct test_moe : public test_case {

 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
+    std::default_random_engine rng(0);

     const ggml_type all_types[] = {
         GGML_TYPE_F32, GGML_TYPE_F16,
@@ -1583,7 +1584,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));

-    test_cases.emplace_back(new test_soft_max());
+    std::uniform_int_distribution<> dist_ne1(1, 50);
+    int exponent = 1;
+    while (exponent < (1 << 17)) {
+        std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);
+
+        for (int n = 0; n < 10; ++n) {
+            int64_t ne0 = dist_ne0(rng);
+            int64_t ne1 = dist_ne1(rng);
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
+        }
+
+        exponent <<= 1;
+    }

     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
         test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
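Two hedged usage notes to close out (editorial illustration, not part of the patch): the new `scripts/get-pg.sh` helper takes a single argument, the number of essays to fetch, and the randomized `SOFT_MAX` shapes added above run through the existing `test-backend-ops` harness. A minimal sketch, assuming a CMake build in `build/` and that the harness' usual `-o` operation filter is available:

```sh
# Fetch the first 10 essays from the RSS feed into pg.txt (script added by this patch).
./scripts/get-pg.sh 10

# Re-run only the SOFT_MAX backend tests, which now cover the randomized ne0/ne1 shapes.
# (the -o filter and binary location are assumptions; adjust to your build layout)
./build/bin/test-backend-ops test -o SOFT_MAX
```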