CPU/CUDA: Gemma 2 FlashAttention support (#8542)
* CPU/CUDA: Gemma 2 FlashAttention support
* apply logit_softcap to scale in kernel
* disable logit softcapping tests on Metal
* remove metal check
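The "apply logit_softcap to scale in kernel" bullet refers to how Gemma 2 caps attention logits: s' = logit_softcap * tanh(s / logit_softcap). Dividing the existing scale by logit_softcap once, before the inner loop, leaves only one tanh and one multiply per score. A minimal sketch of that arithmetic, not the actual CPU/CUDA kernel code (softcapped_score is a hypothetical helper):

    #include <math.h>

    // Hypothetical helper sketching the softcap arithmetic for one raw
    // attention score s = dot(q, k). Gemma 2 wants
    //     logit_softcap * tanhf((scale * s) / logit_softcap)
    // and folding the division into the scale means the per-score work
    // is just one tanhf and one multiply.
    static float softcapped_score(float s, float scale, float logit_softcap) {
        if (logit_softcap != 0.0f) {
            scale /= logit_softcap; // hoisted outside the loop in a real kernel
            return logit_softcap * tanhf(scale * s);
        }
        return scale * s; // softcap disabled: plain scaled dot product
    }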
This commit is contained in:
parent 8f824ffe8e
commit e11bd856d5
12 changed files with 319 additions and 79 deletions
@@ -802,6 +802,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
+            {
+                float logit_softcap;
+
+                memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    return false;
+                }
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
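For context on the ((const float *) op->op_params) + 2 read in the hunk above: a sketch of the producer side, assuming the FLASH_ATTN_EXT op packs its float parameters as { scale, max_bias, logit_softcap } in op_params, which is the layout the Metal check implies (set_flash_attn_params is a hypothetical helper, not a ggml API):

    #include <stdint.h>
    #include <string.h>

    // Hypothetical mirror of the graph-building side, assuming the op
    // packs { scale, max_bias, logit_softcap } as the first three floats
    // of op_params; ggml_metal_supports_op above reads the third float
    // back with the same memcpy pattern.
    static void set_flash_attn_params(int32_t * op_params,
            float scale, float max_bias, float logit_softcap) {
        memcpy((float *) op_params + 0, &scale,         sizeof(float));
        memcpy((float *) op_params + 1, &max_bias,      sizeof(float));
        memcpy((float *) op_params + 2, &logit_softcap, sizeof(float));
    }

With supports_op rejecting softcapped flash attention this way, backends that lack the feature are skipped automatically, which is presumably why the "remove metal check" bullet could drop the test suite's explicit Metal workaround.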