llama : prep ALiBi support for BERT models

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-23 17:24:28 +03:00
parent 78d363b0d4
commit 19e8982f51
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 9 additions and 2 deletions

1
ggml.c
View file

@ -5476,6 +5476,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask)); GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]); GGML_ASSERT(mask->ne[1] >= a->ne[1]);
} }

View file

@ -6712,8 +6712,14 @@ struct llm_build_context {
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
} }
struct ggml_tensor * build_inp_KQ_pos() { struct ggml_tensor * build_inp_KQ_pos(bool causal = true) {
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); if (causal) {
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
} else {
// TODO: this will be needed for ALiBi-based BERT models
// https://github.com/ggerganov/llama.cpp/pull/6826
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens);
}
cb(lctx.inp_KQ_pos, "KQ_pos", -1); cb(lctx.inp_KQ_pos, "KQ_pos", -1);
ggml_set_input(lctx.inp_KQ_pos); ggml_set_input(lctx.inp_KQ_pos);
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos;