CPU/CUDA: Gemma 2 FlashAttention support (#8542)
* CPU/CUDA: Gemma 2 FlashAttention support
* apply logit_softcap to scale in kernel
* disable logit softcapping tests on Metal
* remove metal check
This commit is contained in:
parent 8f824ffe8e
commit e11bd856d5
12 changed files with 319 additions and 79 deletions
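Gemma 2 applies logit softcapping to its attention scores, softcap(x) = c * tanh(x / c), with c = 50 for the attention layers. The second bullet above ("apply logit_softcap to scale in kernel") refers to folding the 1/c factor into the existing KQ scale once per op, so the per-element work in the kernel is just one tanhf() and one multiply. A minimal standalone sketch of that equivalence (illustrative names and values, not code from this commit):

// Sketch only: shows that pre-dividing the scale by logit_softcap and then
// applying cap*tanhf(s) per score reproduces cap*tanh(kq*scale/cap).
#include <math.h>
#include <stdio.h>

// Reference formulation: softcap applied to the already-scaled KQ value.
static float softcap_reference(float kq, float scale, float logit_softcap) {
    return logit_softcap * tanhf((kq * scale) / logit_softcap);
}

// Kernel formulation from this commit: the division by the cap is folded
// into the scale once, then each score needs only a tanhf() and a multiply.
static float softcap_kernel(float kq, float scale, float logit_softcap) {
    const float scale_over_cap = scale / logit_softcap; // done once per op
    const float s = kq * scale_over_cap;                // s = kq*scale/cap
    return logit_softcap * tanhf(s);                    // cap*tanh(kq*scale/cap)
}

int main(void) {
    const float scale         = 1.0f / sqrtf(256.0f); // illustrative 1/sqrt(head_dim)
    const float logit_softcap = 50.0f;                 // Gemma 2 attention cap

    for (float kq = -4000.0f; kq <= 4000.0f; kq += 1000.0f) {
        printf("kq = %8.1f  ref = %9.4f  kernel = %9.4f\n",
               kq,
               softcap_reference(kq, scale, logit_softcap),
               softcap_kernel   (kq, scale, logit_softcap));
    }
    return 0;
}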
@@ -7095,7 +7095,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
         float                 scale,
-        float                 max_bias) {
+        float                 max_bias,
+        float                 logit_softcap) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)

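The public ggml_flash_attn_ext() signature gains the logit_softcap parameter after max_bias. A hedged call-site sketch, assuming the usual leading arguments (ctx, q, k, v, mask) and illustrative values; passing 0.0f keeps softcapping disabled, as non-Gemma-2 models would:

// Hypothetical call site (not from this commit); q, k, v, mask and n_embd_head
// are assumed to exist in the surrounding graph-building code.
const float scale         = 1.0f/sqrtf((float) n_embd_head); // usual 1/sqrt(d_head)
const float max_bias      = 0.0f;                            // no ALiBi
const float logit_softcap = 50.0f;                           // Gemma 2 attention cap; 0.0f disables

struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap);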
@@ -7122,7 +7123,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));

     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -7142,7 +7143,7 @@ void ggml_flash_attn_ext_set_prec(
     const int32_t prec_i32 = (int32_t) prec;

-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }

 // ggml_flash_attn_back

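With logit_softcap appended to the float params, the precision flag written by ggml_flash_attn_ext_set_prec moves from slot 2 to slot 3, and the backend dispatch in the last hunk below switches on the same index. The resulting op_params layout for GGML_OP_FLASH_ATTN_EXT, as a sketch inferred from the hunks in this diff:

// op_params layout after this change (indices as used in the hunks below):
//   [0] scale          (float)
//   [1] max_bias       (float)
//   [2] logit_softcap  (float)   <- new
//   [3] precision      (int32_t, via ggml_flash_attn_ext_set_prec)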
@@ -15271,11 +15272,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);

-    float scale    = 1.0f;
-    float max_bias = 0.0f;
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;

-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }

     const uint32_t n_head      = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -15339,7 +15346,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
             kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

-            s = s*scale + mv; // scale KQ value and apply mask
+            s = s*scale; // scale KQ value
+
+            if (logit_softcap != 0.0f) {
+                s = logit_softcap*tanhf(s);
+            }
+
+            s += mv; // apply mask

             const float Mold = M;

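Note the ordering in this hunk: because scale was already divided by logit_softcap above, s*scale here equals (KQ * original scale)/cap, so logit_softcap*tanhf(s) reproduces cap*tanh(scale*KQ/cap); the mask bias mv is added only afterwards, which matches the usual Gemma 2 formulation where softcapping is applied to the raw attention logits and the mask is not passed through the tanh.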
@@ -15348,7 +15361,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(

             const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

-            if (v->type== GGML_TYPE_F16) {
+            if (v->type == GGML_TYPE_F16) {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                     M = s;
@@ -15415,7 +15428,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {