llama : prep ALiBi support for BERT models

ggml-ci

parent 78d363b0d4
commit 19e8982f51

2 changed files with 9 additions and 2 deletions
ggml.c (1 change: 1 addition)

@@ -5476,6 +5476,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(ggml_is_matrix(mask));
+        GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
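The new assert pins the mask's width to the row size of the soft-max input (mask->ne[0] == a->ne[0], i.e. n_kv for a KQ matrix), while the mask may carry more rows than a->ne[1] to allow padding. As a minimal standalone sketch (illustrative names, not ggml internals) of the shape contract and of how the 2-D mask is broadcast over heads when it is added to the logits:

#include <assert.h>
#include <stdint.h>

// a    : logits laid out as [n_kv, n_tokens, n_head], contiguous rows
// mask : 2-D [n_kv, n_mask_rows] with n_mask_rows >= n_tokens (padding allowed)
// The same mask row is reused for every head, i.e. the mask broadcasts over heads.
static void apply_kq_mask(float * a, const float * mask,
                          int64_t n_kv, int64_t n_tokens, int64_t n_head,
                          int64_t n_mask_rows) {
    assert(n_mask_rows >= n_tokens); // mirrors GGML_ASSERT(mask->ne[1] >= a->ne[1])
    for (int64_t h = 0; h < n_head; ++h) {
        for (int64_t r = 0; r < n_tokens; ++r) {
            const float * mrow = mask + r*n_kv;           // valid since mask width == n_kv
            float       * arow = a + (h*n_tokens + r)*n_kv;
            for (int64_t c = 0; c < n_kv; ++c) {
                arow[c] += mrow[c];                       // -INFINITY entries mask out KV cells
            }
        }
    }
}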
llama.cpp (10 changes: 8 additions, 2 deletions)

@@ -6712,8 +6712,14 @@ struct llm_build_context {
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }

-    struct ggml_tensor * build_inp_KQ_pos() {
-        lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
+    struct ggml_tensor * build_inp_KQ_pos(bool causal = true) {
+        if (causal) {
+            lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
+        } else {
+            // TODO: this will be needed for ALiBi-based BERT models
+            // https://github.com/ggerganov/llama.cpp/pull/6826
+            lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens);
+        }
         cb(lctx.inp_KQ_pos, "KQ_pos", -1);
         ggml_set_input(lctx.inp_KQ_pos);
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos;
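build_inp_KQ_pos now distinguishes causal from non-causal graphs: with a causal KV cache the positions tensor covers all n_kv cache cells, whereas a non-causal (BERT-style) encoder has no cache and only needs positions for the n_tokens of the current batch. KQ_pos feeds the ALiBi bias applied in the soft-max: each head adds slope(h) * pos[c] to its logits. A sketch of the per-head slope schedule ggml derives from max_bias (the helper name here is illustrative):

#include <math.h>
#include <stdint.h>

// Per-head ALiBi slope, following ggml's schedule:
//   n_head_log2 = largest power of two <= n_head
//   m0 = 2^(-max_bias      / n_head_log2)  for heads below n_head_log2
//   m1 = 2^(-max_bias/2.0f / n_head_log2)  for the remaining heads
static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    return h < n_head_log2 ? powf(m0, (float) (h + 1))
                           : powf(m1, (float) (2*(h - n_head_log2) + 1));
}

// Before the soft-max, logit (h, r, c) receives the bias
//     alibi_slope(h, n_head, max_bias) * KQ_pos[c]
// where KQ_pos has length n_kv (causal) or n_tokens (non-causal).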