Add support for Deepseek-R1 flash attention

Siddartha Naidu 2025-01-31 18:48:48 +00:00
parent a83f528688
commit 4c59b04ac8
3 changed files with 72 additions and 10 deletions


@@ -34,6 +34,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
             case 128:
                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                 break;
+            case 192:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<192, cols_per_block, float>(ctx, dst);
+                break;
             case 256:
                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                 break;
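
For context: DeepSeek-R1 inherits the DeepSeek-V3 attention layout, in which the Q/K head size is larger than the V head size, so flash attention needs a wmma kernel instantiation for head size 192 alongside the existing ones (128 and 256 are visible in the hunk above). A rough sketch of where 192 comes from, using head dimensions taken from the published DeepSeek-V3 config (an assumption for illustration, not something stated in this commit):

    // Hypothetical illustration only; the dimension values are assumed from the
    // DeepSeek-V3/R1 model config and are not defined anywhere in this commit.
    constexpr int qk_nope_head_dim = 128;  // non-rotary part of each Q/K head
    constexpr int qk_rope_head_dim =  64;  // rotary part of each Q/K head
    constexpr int v_head_dim       = 128;  // V head size

    constexpr int n_embd_head_k = qk_nope_head_dim + qk_rope_head_dim;  // 192 -> handled by the new case
    constexpr int n_embd_head_v = v_head_dim;                           // 128 -> padded up to 192 in llm_build_kqv below
    static_assert(n_embd_head_k == 192, "the head size covered by the new wmma case");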


@@ -25,6 +25,31 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
     }
 }
 
+static __global__ void pad_f16(const half * x, half * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
 static void pad_f32_cuda(const float * x, float * dst,
     const int ne00, const int ne01, const int ne02, const int ne03,
     const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
@@ -33,17 +58,33 @@ static void pad_f32_cuda(const float * x, float * dst,
     pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
 }
 
+static void pad_f16_cuda(const half * x, half * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f16<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
-    pad_f32_cuda(src0_d, dst_d,
-        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    if (src0->type == GGML_TYPE_F32) {
+        const float * src0_d = (const float *)src0->data;
+        float * dst_d = (float *)dst->data;
+        pad_f32_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    } else {
+        const half * src0_d = (const half *)src0->data;
+        half * dst_d = (half *)dst->data;
+        pad_f16_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    }
 }
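
The new pad_f16 kernel and pad_f16_cuda wrapper mirror the existing pad_f32 path exactly, only with half-precision loads and stores; ggml_cuda_op_pad then dispatches on the source type. For reference, the indexing is equivalent to the following CPU sketch (a hypothetical reference written for this note, not part of the commit; T stands for float or half):

    // CPU reference for the padding index math used by pad_f32/pad_f16 above.
    // gridDim.y in the kernel equals ne1, so blockIdx.z * ne0 * gridDim.y maps to i2*ne0*ne1 here.
    template <typename T>
    static void pad_ref(const T * x, T * dst,
                        int ne00, int ne01, int ne02, int ne03,
                        int ne0,  int ne1,  int ne2,  int ne3) {
        for (int i2 = 0; i2 < ne2*ne3; ++i2) {          // corresponds to blockIdx.z
            for (int i1 = 0; i1 < ne1; ++i1) {          // corresponds to blockIdx.y
                for (int i0 = 0; i0 < ne0; ++i0) {      // corresponds to the x-dimension threads
                    const int offset_dst = i0 + i1*ne0 + i2*ne0*ne1;
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02*ne03) {
                        dst[offset_dst] = x[i0 + i1*ne00 + i2*ne00*ne01];  // copy the source element
                    } else {
                        dst[offset_dst] = T(0);                            // zero-fill the padded region
                    }
                }
            }
        }
    }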


@@ -589,12 +589,30 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+        struct ggml_tensor * padded_v = v;
+        int64_t n_embd_head_v_out = n_embd_head_v;
+
+        if (n_embd_head_v < n_embd_head_k) {
+            padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
+            cb(padded_v, "padded_v", il);
+            n_embd_head_v_out = n_embd_head_k;
+        }
+
+        cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
+        if (n_embd_head_v < n_embd_head_k) {
+            cur = ggml_reshape_3d(ctx, cur, n_embd_head_v_out, n_head, n_tokens);
+            cur = ggml_view_3d(ctx, cur, n_embd_head_v, n_head, n_tokens,
+                               ggml_element_size(cur) * n_embd_head_v_out,
+                               ggml_element_size(cur) * n_embd_head_v_out * n_head,
+                               0);
+            cur = ggml_cont(ctx, cur);
+        }
+
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
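
The llm_build_kqv change works around the flash-attention path's earlier requirement that K and V share a head size: when n_embd_head_v < n_embd_head_k (128 vs. 192 with the DeepSeek head sizes assumed above), V is zero-padded along its head dimension up to the K head size before ggml_flash_attn_ext, and the output is narrowed back afterwards. Because the padded part of V is all zeros, the attention weights are unchanged and the extra output features are exactly zero, so the strided view simply drops them. A self-contained sketch of just that narrowing step, with made-up sizes (it assumes the ggml.h from this repo; this is an illustration, not the commit's code):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        const int64_t d_v = 128, d_k = 192, n_head = 4, n_tokens = 8;

        // stand-in for the flash-attention output, which carries the padded head size d_k per head
        struct ggml_tensor * fa_out = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_k, n_head, n_tokens);

        // keep only the first d_v values of each head, as llm_build_kqv does above
        struct ggml_tensor * sliced = ggml_view_3d(ctx, fa_out, d_v, n_head, n_tokens,
                ggml_element_size(fa_out) * d_k,
                ggml_element_size(fa_out) * d_k * n_head,
                0);
        sliced = ggml_cont(ctx, sliced);  // make the narrowed tensor contiguous again

        ggml_free(ctx);
        return 0;
    }
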
@@ -9551,8 +9569,8 @@ struct llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+    if (params.flash_attn && model->hparams.n_embd_head_k < model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k >= n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
     }