ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext

Author: Georgi Gerganov
Date:   2024-01-31 19:17:16 +02:00
Parent: 2ddc9bbef1
Commit: 8ad92dc1ec
7 changed files with 79 additions and 62 deletions

ggml.h (12 changed lines)

@@ -1646,11 +1646,13 @@ extern "C" {
     struct ggml_tensor  * v,
     bool                  masked);

-    // q:    [n_embd, n_batch,     n_head,    1]
-    // k:    [n_embd, n_kv,        n_head_kv, 1]
-    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,   n_batch,     1,         1]
-    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
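
After this change, callers must allocate the KQ mask as an F16 tensor whose second dimension is rounded up to a multiple of GGML_KQ_MASK_PAD. A minimal sketch of how a caller might size it, assuming a valid ggml_context (the helper name make_kq_mask is hypothetical, not part of this commit):

    #include "ggml.h"

    // Sketch only: allocate the KQ mask for ggml_flash_attn_ext with the
    // batch dimension padded as this commit requires.
    static struct ggml_tensor * make_kq_mask(struct ggml_context * ctx,
                                             int64_t n_kv, int64_t n_batch) {
        // round n_batch up to a multiple of GGML_KQ_MASK_PAD (32)
        const int64_t n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD);

        // mask: [n_kv, n_batch_pad], stored as F16 rather than F32
        return ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, n_batch_pad);
    }

Padding the batch dimension presumably lets the ggml_soft_max and flash-attention kernels read the mask in fixed-size blocks without per-row bounds checks, and the F16 storage halves the mask's memory traffic.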