CPU/CUDA: Gemma 2 FlashAttention support (#8542)

* CPU/CUDA: Gemma 2 FlashAttention support

* apply logit_softcap to scale in kernel

* disable logit softcapping tests on Metal

* remove metal check
Johannes Gäßler 2024-08-24 21:34:59 +02:00 committed by GitHub
parent 8f824ffe8e
commit e11bd856d5
12 changed files with 319 additions and 79 deletions
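
For context: logit softcapping, which Gemma 2 applies to its attention scores, squashes each score through a scaled tanh before the softmax: capped = logit_softcap * tanh(score / logit_softcap). The "apply logit_softcap to scale in kernel" bullet refers to folding the 1/logit_softcap factor into the existing attention scale once, up front, so the kernel's inner loop only pays for one extra tanh and multiply per score. A minimal sketch of that idea in plain C (the function name softcapped_score is illustrative, not from the patch; scale and logit_softcap mirror the op parameters read in the diff below):

    #include <math.h>

    // Reference formula: capped = logit_softcap * tanhf(scale * qk / logit_softcap).
    // Dividing scale by logit_softcap once up front leaves only a tanhf and a
    // multiply per attention score inside the kernel loop.
    static float softcapped_score(float qk, float scale, float logit_softcap) {
        if (logit_softcap == 0.0f) {
            return scale * qk; // 0.0f is the sentinel for "softcapping disabled"
        }
        scale /= logit_softcap;                   // hoisted out of the inner loop
        return logit_softcap * tanhf(scale * qk); // per-score work in the kernel
    }

A logit_softcap of 0.0f disables softcapping entirely, and that sentinel is exactly what the Metal support check below tests for: Metal has no softcap-aware FlashAttention kernels yet, so those cases are reported as unsupported and fall back to another backend.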

ggml/src/ggml-metal.m

@@ -802,6 +802,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
+            {
+                float logit_softcap;
+
+                memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    return false;
+                }
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
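
As a side note, the "+ 2" offset in the memcpy above comes from the float layout that ggml_flash_attn_ext packs into op_params. A hedged sketch of reading all three parameters, assuming the layout { scale, max_bias, logit_softcap } implied by that offset:

    // Assumed op_params layout for GGML_OP_FLASH_ATTN_EXT, inferred from the
    // "+ 2" offset above: three packed floats { scale, max_bias, logit_softcap }.
    // memcpy is used because op_params is backed by int32_t storage, so a plain
    // float dereference would be type punning.
    float scale;
    float max_bias;
    float logit_softcap;
    memcpy(&scale,         ((const float *) op->op_params) + 0, sizeof(scale));
    memcpy(&max_bias,      ((const float *) op->op_params) + 1, sizeof(max_bias));
    memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));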