fix KV cache padding, NaN from INFINITY (#6438)

This commit is contained in:
Johannes Gäßler 2024-04-02 17:26:22 +02:00 committed by GitHub
parent c63dfdf765
commit ee19a4ab7e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 8 deletions

View file

@ -4,6 +4,7 @@
#include <mma.h>
#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
template<int D, int parallel_blocks> // D == head size
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16(
KQ[tid] = -INFINITY;
half2 * KQ2 = (half2 *) KQ;
half kqmax = -INFINITY;
half kqmax = -HALF_MAX_HALF;
half kqsum = 0.0f;
__shared__ half kqmax_shared[WARP_SIZE];
__shared__ half kqsum_shared[WARP_SIZE];
if (threadIdx.y == 0) {
kqmax_shared[threadIdx.x] = -INFINITY;
kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
kqsum_shared[threadIdx.x] = 0.0f;
}
__syncthreads();
@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16(
if (tid < D) {
#pragma unroll
for (int k0 = 0; k0 < D; k0 += 2) {
if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
break;
}
@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16(
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
half2 * KQ2 = (half2 *) KQ;
half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};
half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}};
half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}};
half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}};
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
ggml_cuda_set_device(ctx.device);
const cudaStream_t main_stream = ctx.stream();

View file

@ -9973,7 +9973,7 @@ static int llama_decode_internal(
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
@ -13909,7 +13909,7 @@ struct llama_context * llama_new_context_with_model(
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
// this is necessary due to kv_self.n being padded later during inference
cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
// with causal attention, the batch size is limited by the context size
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;