This commit is contained in:
Siddartha Naidu 2025-02-02 10:21:59 -05:00 committed by GitHub
commit 635628a589
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 72 additions and 10 deletions

View file

@ -34,6 +34,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_wmma_f16_case<192, cols_per_block, float>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
break;

View file

@ -25,6 +25,31 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
}
}
static __global__ void pad_f16(const half * x, half * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
// blockIdx.y: idx of ne1
// blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
// operation
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * ne01;
dst[offset_dst] = x[offset_src];
} else {
dst[offset_dst] = 0.0f;
}
}
static void pad_f32_cuda(const float * x, float * dst,
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
@ -33,17 +58,33 @@ static void pad_f32_cuda(const float * x, float * dst,
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}
static void pad_f16_cuda(const half * x, half * dst,
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f16<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == src0->type);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
if (src0->type == GGML_TYPE_F32) {
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
} else {
const half * src0_d = (const half *)src0->data;
half * dst_d = (half *)dst->data;
pad_f16_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
}

View file

@ -589,12 +589,30 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
struct ggml_tensor * padded_v = v;
int64_t n_embd_head_v_out = n_embd_head_v;
if (n_embd_head_v < n_embd_head_k) {
padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
cb(padded_v, "padded_v", il);
n_embd_head_v_out = n_embd_head_k;
}
cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
if (n_embd_head_v < n_embd_head_k) {
cur = ggml_reshape_3d(ctx, cur, n_embd_head_v_out, n_head, n_tokens);
cur = ggml_view_3d(ctx, cur, n_embd_head_v, n_head, n_tokens,
ggml_element_size(cur) * n_embd_head_v_out,
ggml_element_size(cur) * n_embd_head_v_out * n_head,
0);
cur = ggml_cont(ctx, cur);
}
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
@ -9564,8 +9582,8 @@ struct llama_context * llama_init_from_model(
params.flash_attn = false;
}
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
if (params.flash_attn && model->hparams.n_embd_head_k < model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k >= n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;
}