diff --git a/ggml-metal.m b/ggml-metal.m index ddde60293..18ce5b88a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -772,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: - case GGML_OP_FLASH_ATTN_EXT: return true; + case GGML_OP_FLASH_ATTN_EXT: + return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return ctx->support_simdgroup_reduction &&