fix KV cache padding, NaN from INFINITY (#6438)
parent c63dfdf765
commit ee19a4ab7e
2 changed files with 11 additions and 8 deletions
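
For context on the "NaN from INFINITY" half of the fix: the flash-attention kernels keep a running maximum of the KQ values and rescale their partial sums by exp(max_old - max_new), the usual online-softmax trick. If that maximum is initialized to -INFINITY and a tile leaves it there (e.g. every position masked out), the rescale evaluates (-inf) - (-inf), which is NaN under IEEE 754 and poisons everything downstream. A minimal host-side sketch of the failure mode, using plain float instead of half and written purely for illustration:

    // Why a -INFINITY running maximum turns the softmax rescale factor into NaN.
    #include <cmath>
    #include <cstdio>

    int main() {
        const float inf_init    = -INFINITY;      // old initialization of the running KQ max
        const float finite_init = -65504.0f/2.0f; // negated HALF_MAX_HALF, as introduced by this commit

        // rescale factor exp(max_old - max_new) when the maximum was never updated:
        const float scale_from_inf    = expf(inf_init    - inf_init);    // (-inf) - (-inf) -> NaN, exp(NaN) -> NaN
        const float scale_from_finite = expf(finite_init - finite_init); // 0 -> exp(0) = 1

        printf("inf init: %f, finite init: %f\n", scale_from_inf, scale_from_finite); // nan vs. 1.000000
        return 0;
    }
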
@@ -4,6 +4,7 @@
 #include <mma.h>
 
 #define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 
 template<int D, int parallel_blocks> // D == head size
 __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
@@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16(
     KQ[tid] = -INFINITY;
     half2 * KQ2 = (half2 *) KQ;
 
-    half kqmax = -INFINITY;
+    half kqmax = -HALF_MAX_HALF;
     half kqsum = 0.0f;
 
     __shared__ half kqmax_shared[WARP_SIZE];
     __shared__ half kqsum_shared[WARP_SIZE];
     if (threadIdx.y == 0) {
-        kqmax_shared[threadIdx.x] = -INFINITY;
+        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
         kqsum_shared[threadIdx.x] = 0.0f;
     }
     __syncthreads();
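
The kqmax / kqmax_shared values changed above are the running maxima of that online softmax. Schematically (this is not the kernel's code, just the recurrence it follows, on plain float), each new score updates the maximum and rescales the running sum, and the initialization is exactly what the rescale subtracts:

    // Schematic online-softmax update showing where the running-maximum
    // initialization enters the computation.
    #include <algorithm>
    #include <cmath>

    struct OnlineSoftmax {
        float kqmax = -65504.0f/2.0f; // stand-in for -HALF_MAX_HALF; -INFINITY here reproduces the NaN
        float kqsum = 0.0f;

        // kq can itself be -INFINITY (a masked position); together with a
        // -INFINITY kqmax the subtractions below then produce NaN.
        void add(float kq) {
            const float kqmax_new = std::max(kqmax, kq);
            const float scale     = expf(kqmax - kqmax_new); // NaN if kqmax and kqmax_new are both -inf
            kqsum = kqsum*scale + expf(kq - kqmax_new);
            kqmax = kqmax_new;
        }
    };
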
@@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16(
     if (tid < D) {
 #pragma unroll
         for (int k0 = 0; k0 < D; k0 += 2) {
-            if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
                 break;
             }
 
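
The change above is value-neutral (FATTN_KQ_STRIDE is 256); it ties the guard to the macro instead of a hard-coded 256. Since the first operand is a compile-time constant, the runtime bound check against ne11 (the KV sequence length) only survives for head sizes D that do not divide FATTN_KQ_STRIDE. A throwaway snippet, with head-size values chosen purely for illustration, showing which cases keep the check:

    // For which head sizes D does "FATTN_KQ_STRIDE % D != 0" keep the bound check?
    #include <cstdio>

    #define FATTN_KQ_STRIDE 256 // same value as in the kernel

    int main() {
        const int head_sizes[] = {64, 80, 96, 112, 128, 256}; // illustrative head sizes
        for (const int D : head_sizes) {
            printf("D = %3d: FATTN_KQ_STRIDE %% D = %3d -> bound check %s\n",
                   D, FATTN_KQ_STRIDE % D, FATTN_KQ_STRIDE % D != 0 ? "kept" : "compiled out");
        }
        return 0;
    }
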
@@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16(
     __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
     half2 * KQ2 = (half2 *) KQ;
 
-    half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
-    half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
-    half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}};
+    half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}};
+    half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}};
 
     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
@@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
                                   "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
     ggml_cuda_set_device(ctx.device);
 
     const cudaStream_t main_stream = ctx.stream();
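
The new assertion spells out the requirement behind the "KV cache padding" half of the commit: K->ne[1] (the KV sequence length) must be a multiple of FATTN_KQ_STRIDE, i.e. the KV cache has to be padded to 256 along the sequence dimension before this path is used. A sketch of the rounding a caller would perform; the pad_to helper is made up for illustration and is not a ggml API:

    // Round a KV length up to the padding the new GGML_ASSERT checks for.
    #include <cassert>
    #include <cstdio>

    #define FATTN_KQ_STRIDE 256 // must match the kernel's KQ tile stride

    // Hypothetical helper: classic round-up-to-multiple arithmetic.
    static int pad_to(int n, int multiple) {
        return ((n + multiple - 1) / multiple) * multiple;
    }

    int main() {
        const int n_kv        = 1000;                          // tokens currently in the KV cache
        const int n_kv_padded = pad_to(n_kv, FATTN_KQ_STRIDE); // -> 1024
        assert(n_kv_padded % FATTN_KQ_STRIDE == 0);            // mirrors the new assertion
        printf("n_kv = %d, padded = %d\n", n_kv, n_kv_padded);
        return 0;
    }
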