ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext

Georgi Gerganov 2024-01-31 19:17:16 +02:00
parent 2ddc9bbef1
commit 8ad92dc1ec
7 changed files with 79 additions and 62 deletions
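
Note: the "padded" in the title refers to rounding the mask's row count up to a multiple of 8, and the "F16" to requiring the mask tensor in half precision. A small, self-contained illustration of the padding arithmetic; the GGML_PAD macro body below is reproduced from ggml.h as I understand it, not taken from this diff:

    #include <stdio.h>

    // GGML_PAD rounds x up to the next multiple of n (n must be a power of two)
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main(void) {
        // e.g. a batch of 13 query tokens needs a mask with GGML_PAD(13, 8) = 16 rows
        printf("%d\n", GGML_PAD(13, 8)); // prints 16
        return 0;
    }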

ggml-metal.m

@@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_MAX:
{
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
int nth = 32; // SIMD width
id<MTLComputePipelineState> pipeline = nil;
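
The new assertion above means the KQ mask handed to the Metal soft_max kernel must already be an F16 tensor. A hedged sketch of how a caller could allocate such a mask with the public ggml API; the tensor names and the ggml_soft_max_ext signature (ctx, a, mask, scale) are assumptions based on the ggml API of this period, not part of this diff:

    #include "ggml.h"

    // sketch only: build a masked soft_max over KQ scores with an F16, row-padded mask
    struct ggml_tensor * build_kq_soft_max(
            struct ggml_context * ctx,
            struct ggml_tensor  * kq,      // [n_kv, n_batch, n_head] attention scores
            int64_t n_kv,                  // number of key/value positions (assumed name)
            int64_t n_batch,               // number of query tokens        (assumed name)
            float   kq_scale) {
        // mask in half precision, rows padded to a multiple of 8 as the kernel now expects
        struct ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_batch, 8));

        return ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
    }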
@@ -2213,6 +2215,10 @@ static bool ggml_metal_graph_compute(
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
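
The assertion on src3 states the contract for the Flash-Attention path: the mask is optional, but if present it must be F16 and have at least GGML_PAD(n_queries, 8) rows, where n_queries is src0->ne[1]. A minimal host-side check mirroring that contract; this is a sketch, and the helper name and parameters are made up for illustration:

    #include "ggml.h"

    // sketch: host-side mirror of the Metal-side mask assertions
    static void check_flash_attn_mask(
            const struct ggml_tensor * q,      // query tensor: q->ne[1] == n_queries
            const struct ggml_tensor * mask) { // optional KQ mask
        if (mask == NULL) {
            return; // running without a mask is allowed
        }
        GGML_ASSERT(mask->type == GGML_TYPE_F16);          // half-precision mask
        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], 8)); // padded to 8, covers n_queries
    }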