Add support for Deepseek-R1 flash attention
parent a83f528688
commit 4c59b04ac8

3 changed files with 72 additions and 10 deletions
@@ -34,6 +34,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
         case 128:
             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
             break;
+        case 192:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<192, cols_per_block, float>(ctx, dst);
+            break;
         case 256:
             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
             break;
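Note: DeepSeek-R1 attention (MLA) uses a wider K/Q head than V head, which is why 192 shows up here. Assuming the usual DeepSeek per-head split of 128 non-RoPE plus 64 RoPE channels (an assumption, not read from this diff), the K/Q head size works out to the 192 handled by the new case:

    // Assumed DeepSeek-R1 MLA head layout (illustrative, not part of this commit).
    constexpr int qk_nope_head_dim = 128;
    constexpr int qk_rope_head_dim = 64;
    constexpr int n_embd_head_k    = qk_nope_head_dim + qk_rope_head_dim; // 192 -> the new WMMA case
    constexpr int n_embd_head_v    = 128;                                 // smaller than K, handled further down
    static_assert(n_embd_head_k == 192, "head size covered by the added case");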
@@ -25,6 +25,31 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
     }
 }
 
+static __global__ void pad_f16(const half * x, half * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
 static void pad_f32_cuda(const float * x, float * dst,
     const int ne00, const int ne01, const int ne02, const int ne03,
     const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
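The index math mirrors pad_f32: the destination is addressed in the padded geometry (ne0 x gridDim.y x gridDim.z) and the source in the original geometry (ne00 x ne01 x ne02*ne03); positions outside the source are zero-filled. A plain CPU sketch of the same index mapping (illustrative helper, not part of this commit):

    // CPU reference for the index mapping used by pad_f32/pad_f16 (illustrative only).
    // src has shape ne00 x ne01 x ne0203, dst has shape ne0 x ne1 x ne23.
    static void pad_reference(const float * src, float * dst,
                              int ne00, int ne01, int ne0203,
                              int ne0,  int ne1,  int ne23) {
        for (int i2 = 0; i2 < ne23; ++i2) {
            for (int i1 = 0; i1 < ne1; ++i1) {
                for (int i0 = 0; i0 < ne0; ++i0) {
                    const int offset_dst = i0 + i1*ne0 + i2*ne0*ne1;
                    if (i0 < ne00 && i1 < ne01 && i2 < ne0203) {
                        dst[offset_dst] = src[i0 + i1*ne00 + i2*ne00*ne01];
                    } else {
                        dst[offset_dst] = 0.0f; // padded region
                    }
                }
            }
        }
    }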
@@ -33,17 +58,33 @@ static void pad_f32_cuda(const float * x, float * dst,
     pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
 }
 
+static void pad_f16_cuda(const half * x, half * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f16<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
-    pad_f32_cuda(src0_d, dst_d,
-        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    if (src0->type == GGML_TYPE_F32) {
+        const float * src0_d = (const float *)src0->data;
+        float * dst_d = (float *)dst->data;
+        pad_f32_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    } else {
+        const half * src0_d = (const half *)src0->data;
+        half * dst_d = (half *)dst->data;
+        pad_f16_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    }
 }
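With the F16 path added, ggml_cuda_op_pad dispatches on the source type instead of asserting F32, so a half-precision tensor can be padded on the CUDA backend. A rough graph-level usage sketch with illustrative sizes (the 128 to 192 head padding this commit needs; names and numbers are assumptions, not taken from the diff):

    #include "ggml.h"

    // Illustrative only: pad an F16 tensor's first dimension from 128 to 192.
    static struct ggml_tensor * pad_head_dim_example(struct ggml_context * ctx) {
        const int64_t n_head   = 16;  // illustrative
        const int64_t n_tokens = 8;   // illustrative
        struct ggml_tensor * v   = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 128, n_head, n_tokens);
        struct ggml_tensor * v_p = ggml_pad(ctx, v, 192 - 128, 0, 0, 0); // 192 x n_head x n_tokens, zero-filled tail
        return v_p;
    }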
@@ -589,12 +589,30 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+        struct ggml_tensor * padded_v = v;
+        int64_t n_embd_head_v_out = n_embd_head_v;
+        if (n_embd_head_v < n_embd_head_k) {
+            padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
+            cb(padded_v, "padded_v", il);
+            n_embd_head_v_out = n_embd_head_k;
+        }
+
+        cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
+        if (n_embd_head_v < n_embd_head_k) {
+            cur = ggml_reshape_3d(ctx, cur, n_embd_head_v_out, n_head, n_tokens);
+            cur = ggml_view_3d(ctx, cur, n_embd_head_v, n_head, n_tokens,
+                    ggml_element_size(cur) * n_embd_head_v_out,
+                    ggml_element_size(cur) * n_embd_head_v_out * n_head,
+                    0);
+            cur = ggml_cont(ctx, cur);
+        }
+
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
 
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
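The approach here: ggml_flash_attn_ext expects K and V to share a head size, so when n_embd_head_v < n_embd_head_k the V tensor is zero-padded up to the K head size before the call, and the result is afterwards viewed back down to the true V head size and made contiguous. A small sketch of the shape bookkeeping with assumed DeepSeek-R1 sizes (192 for K/Q heads, 128 for V heads; other numbers illustrative, not from this commit):

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t n_embd_head_k = 192;  // assumed DeepSeek-R1 K/Q head size
        const int64_t n_embd_head_v = 128;  // assumed DeepSeek-R1 V head size
        const int64_t n_head        = 16;   // illustrative

        // The kernel sees a uniform head size; V rows are zero-padded up to it.
        const int64_t n_embd_head_v_out = n_embd_head_v < n_embd_head_k ? n_embd_head_k : n_embd_head_v;

        // Attention output per token: n_embd_head_v_out values per head, of which
        // only the first n_embd_head_v are kept by the narrowing view.
        const int64_t row_stride  = n_embd_head_v_out;           // elements between heads
        const int64_t head_stride = n_embd_head_v_out * n_head;  // elements between tokens

        assert(row_stride == 192 && head_stride == 192 * n_head);
        // After ggml_cont and the final reshape the tensor is (n_embd_head_v * n_head) x n_tokens wide.
        assert(n_embd_head_v * n_head == 2048);
        return 0;
    }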
@@ -9551,8 +9569,8 @@ struct llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+    if (params.flash_attn && model->hparams.n_embd_head_k < model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k >= n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
     }
 
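Relaxing the check from != to < means models whose K/Q head size exceeds the V head size keep flash attention enabled, while the genuinely unsupported case (V wider than K) is still forced off. A quick sanity check with assumed DeepSeek-R1 head sizes (192 and 128; assumptions, not read from the model):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t n_embd_head_k = 192;  // assumed DeepSeek-R1 value
        const uint32_t n_embd_head_v = 128;  // assumed DeepSeek-R1 value

        const bool forced_off_before = n_embd_head_k != n_embd_head_v;  // old check: flash_attn disabled
        const bool forced_off_after  = n_embd_head_k <  n_embd_head_v;  // new check: flash_attn stays on

        assert(forced_off_before && !forced_off_after);
        return 0;
    }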