diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 1d29346c7..91ef5551e 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -4,6 +4,7 @@
 #include <mma.h>
 
 #define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 
 template<int D, int parallel_blocks> // D == head size
 __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
@@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16(
     KQ[tid] = -INFINITY;
     half2 * KQ2 = (half2 *) KQ;
 
-    half kqmax = -INFINITY;
+    half kqmax = -HALF_MAX_HALF;
     half kqsum = 0.0f;
 
     __shared__ half kqmax_shared[WARP_SIZE];
     __shared__ half kqsum_shared[WARP_SIZE];
     if (threadIdx.y == 0) {
-        kqmax_shared[threadIdx.x] = -INFINITY;
+        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
         kqsum_shared[threadIdx.x] = 0.0f;
     }
     __syncthreads();
@@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16(
         if (tid < D) {
 #pragma unroll
             for (int k0 = 0; k0 < D; k0 += 2) {
-                if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
                     break;
                 }
 
@@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16(
     __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
     half2 * KQ2 = (half2 *) KQ;
 
-    half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
-    half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
-    half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}};
+    half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}};
+    half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}};
 
     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
@@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
                 "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
     ggml_cuda_set_device(ctx.device);
     const cudaStream_t main_stream = ctx.stream();
 
diff --git a/llama.cpp b/llama.cpp
index 9ea9886fe..b50588e44 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9973,7 +9973,7 @@ static int llama_decode_internal(
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
             // if we start defragmenting the cache, the benefit from this will be more important
-            kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+            kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
             //kv_self.n = llama_kv_cache_cell_max(kv_self);
         }
     }
@@ -13909,7 +13909,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
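For context on the -HALF_MAX_HALF change: with a -INFINITY initial KQ maximum, a block that contributes no finite score leaves both the old and the new maximum at -inf, so the online-softmax rescale factor exp(old_max - new_max) evaluates to exp(-inf - (-inf)) = NaN and poisons the accumulators. A minimal host-side C++ sketch (not the kernel code; the 65504.0f/2 constant mirrors HALF_MAX_HALF) illustrating the difference:

// Minimal host-side sketch of the online-softmax rescale step.
// With a -INFINITY initial maximum, a fully masked block keeps both the old
// and the new maximum at -inf, and exp(-inf - (-inf)) is NaN.
// A large-but-finite initial maximum keeps the factor well defined: exp(0) = 1.
#include <cmath>
#include <cstdio>

int main() {
    const float finite_lowest = -65504.0f/2; // mirrors -HALF_MAX_HALF

    const float scale_inf    = expf(-INFINITY     - (-INFINITY));    // NaN
    const float scale_finite = expf(finite_lowest - finite_lowest);  // 1.0f

    printf("rescale with -INFINITY init: %f\n", scale_inf);
    printf("rescale with finite init:    %f\n", scale_finite);
    return 0;
}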
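The llama.cpp changes pad both the reported context size and kv_self.n to multiples of 256 so that the K tensor seen by the kernel always satisfies the new K->ne[1] % FATTN_KQ_STRIDE == 0 assert. A small sketch of that invariant, assuming GGML_PAD rounds its first argument up to the next multiple of its (power-of-two) second argument, as in ggml.h:

// Sketch of the padding invariant; GGML_PAD is assumed to round x up to the
// next multiple of n (n a power of two), matching the ggml.h macro.
#include <cassert>

#define FATTN_KQ_STRIDE 256
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    for (int n_ctx_requested = 1; n_ctx_requested <= 4096; ++n_ctx_requested) {
        const int n_ctx = GGML_PAD(n_ctx_requested, 256);

        // kv_self.n is later padded against this padded context size, so the
        // KV length handed to the FlashAttention kernel divides evenly.
        assert(n_ctx % FATTN_KQ_STRIDE == 0);
        assert(n_ctx >= n_ctx_requested && n_ctx - n_ctx_requested < 256);
    }
    return 0;
}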