From 4c59b04ac8ba623210f19ce08d80f395026b36ea Mon Sep 17 00:00:00 2001
From: Siddartha Naidu
Date: Fri, 31 Jan 2025 18:48:48 +0000
Subject: [PATCH] Add support for Deepseek-R1 flash attention

---
 ggml/src/ggml-cuda/fattn.cu |  3 ++
 ggml/src/ggml-cuda/pad.cu   | 55 ++++++++++++++++++++++++++++++++-----
 src/llama.cpp               | 24 ++++++++++++++--
 3 files changed, 72 insertions(+), 10 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 0b26b0f8e..07163b37c 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -34,6 +34,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
                 case 128:
                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                     break;
+                case 192:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<192, cols_per_block, float>(ctx, dst);
+                    break;
                 case 256:
                     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                     break;
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
index aba539e8d..353a89589 100644
--- a/ggml/src/ggml-cuda/pad.cu
+++ b/ggml/src/ggml-cuda/pad.cu
@@ -25,6 +25,31 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
     }
 }
 
+static __global__ void pad_f16(const half * x, half * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
 static void pad_f32_cuda(const float * x, float * dst,
     const int ne00, const int ne01, const int ne02, const int ne03,
     const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
@@ -33,17 +58,33 @@ static void pad_f32_cuda(const float * x, float * dst,
     pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
 }
 
+static void pad_f16_cuda(const half * x, half * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f16<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
-    pad_f32_cuda(src0_d, dst_d,
-        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    if (src0->type == GGML_TYPE_F32) {
+        const float * src0_d = (const float *)src0->data;
+        float * dst_d = (float *)dst->data;
+        pad_f32_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    } else {
+        const half * src0_d = (const half *)src0->data;
+        half * dst_d = (half *)dst->data;
+        pad_f16_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    }
 }
diff --git a/src/llama.cpp b/src/llama.cpp
index 192b20a27..c081ba1b2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -589,12 +589,30 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+        struct ggml_tensor * padded_v = v;
+        int64_t n_embd_head_v_out = n_embd_head_v;
+        if (n_embd_head_v < n_embd_head_k) {
+            padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
+            cb(padded_v, "padded_v", il);
+            n_embd_head_v_out = n_embd_head_k;
+        }
+
+        cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
+        if (n_embd_head_v < n_embd_head_k) {
+            cur = ggml_reshape_3d(ctx, cur, n_embd_head_v_out, n_head, n_tokens);
+            cur = ggml_view_3d(ctx, cur, n_embd_head_v, n_head, n_tokens,
+                    ggml_element_size(cur) * n_embd_head_v_out,
+                    ggml_element_size(cur) * n_embd_head_v_out * n_head,
+                    0);
+            cur = ggml_cont(ctx, cur);
+        }
+
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
@@ -9551,8 +9569,8 @@ struct llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+    if (params.flash_attn && model->hparams.n_embd_head_k < model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k >= n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
     }
 
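
Note (not part of the patch): the V-padding in llm_build_kqv is lossless because flash attention
computes softmax(Q*K^T)*V, so appending zero columns to V only appends zero columns to the
per-head output; the original n_embd_head_v values are untouched and the extra columns are
sliced off again by the ggml_view_3d/ggml_cont step above. The standalone C++ sketch below is
a toy check of that identity; the head sizes (d_k = 3, d_v = 2) and all data are made up for
illustration and are not Deepseek-R1's real dimensions.

// Toy check of the V-padding identity used in llm_build_kqv:
//     softmax(Q*K^T) * [V | 0]  ==  [softmax(Q*K^T) * V | 0]
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 4, d_k = 3, d_v = 2;

    // row-major [n_tokens x d_k] Q/K and [n_tokens x d_v] V (arbitrary toy values)
    std::vector<float> Q = {0.1f, 0.2f, 0.3f,   0.4f, 0.5f, 0.6f,
                            0.7f, 0.8f, 0.9f,   1.0f, 1.1f, 1.2f};
    std::vector<float> K = {0.2f, 0.1f, 0.0f,   0.3f, 0.4f, 0.5f,
                            0.6f, 0.7f, 0.8f,   0.9f, 1.0f, 1.1f};
    std::vector<float> V = {1.0f, 2.0f,   3.0f, 4.0f,
                            5.0f, 6.0f,   7.0f, 8.0f};

    // zero-pad V's head dimension from d_v up to d_k (the role of ggml_pad in the patch)
    std::vector<float> Vp(n_tokens * d_k, 0.0f);
    for (int t = 0; t < n_tokens; ++t) {
        for (int c = 0; c < d_v; ++c) {
            Vp[t*d_k + c] = V[t*d_v + c];
        }
    }

    // attention output element [row][col] for a value matrix with dv columns
    auto attend = [&](const std::vector<float> & val, int dv, int row, int col) {
        const float scale = 1.0f / std::sqrt((float) d_k);
        std::vector<float> s(n_tokens);
        float smax = -1e30f;
        for (int t = 0; t < n_tokens; ++t) {
            float dot = 0.0f;
            for (int c = 0; c < d_k; ++c) {
                dot += Q[row*d_k + c] * K[t*d_k + c];
            }
            s[t] = dot * scale;
            smax = std::max(smax, s[t]);
        }
        float denom = 0.0f, out = 0.0f;
        for (int t = 0; t < n_tokens; ++t) {
            const float w = std::exp(s[t] - smax);
            denom += w;
            out   += w * val[t*dv + col];
        }
        return out / denom;
    };

    // columns 0..d_v-1 of the padded run match the unpadded run exactly;
    // columns d_v..d_k-1 are zero, so slicing them off loses nothing
    for (int row = 0; row < n_tokens; ++row) {
        for (int col = 0; col < d_k; ++col) {
            const float padded = attend(Vp, d_k, row, col);
            const float ref    = col < d_v ? attend(V, d_v, row, col) : 0.0f;
            std::printf("row %d col %d: padded = %.6f  reference = %.6f\n", row, col, padded, ref);
        }
    }
    return 0;
}

Because the padded tail is exactly zero, the narrow view taken after ggml_flash_attn_ext
recovers the same result an unpadded kernel would produce, which is why only the n_embd_head_k
>= n_embd_head_v case needs to be permitted in llama_init_from_model.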