llama : add support for GritLM (#5959)
* add gritlm example
* gritlm results match
* tabs to spaces
* comment out debug printing
* rebase to new embed
* gritlm embeddings are back babeee
* add to gitignore
* allow to toggle embedding mode
* Clean-up GritLM sample code.
* Fix types.
* Flush stdout and output ending newline if streaming.
* mostly style fixes; correct KQ_mask comment
* add causal_attn flag to llama_cparams
* gritml : minor
* llama : minor

---------

Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent 2960eae847
commit bcebd7dbf6
7 changed files with 267 additions and 4 deletions
llama.cpp (25 changes)
@@ -1744,6 +1744,7 @@ struct llama_cparams {
     float defrag_thold;
 
     bool embeddings;
+    bool causal_attn;
     bool offload_kqv;
 
     enum llama_pooling_type pooling_type;
@@ -3939,6 +3940,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
     LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
+    LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
     LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
     LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
     LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
@@ -8532,7 +8534,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (hparams.causal_attn) {
+    GGML_ASSERT(
+        (hparams.causal_attn || !cparams.causal_attn) &&
+        "non-causal attention with generative models is not supported"
+    );
+
+    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+    if (cparams.causal_attn) {
         const int64_t n_kv     = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;
 
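Note: the new GGML_ASSERT separates what the model supports (hparams.causal_attn, i.e. a generative model with a KV cache) from what the caller requested for this context (cparams.causal_attn). Below is a minimal standalone sketch of the truth table the assert encodes; the helper name is illustrative and not part of the commit.

    #include <cassert>

    // hparams.causal_attn -> the model can generate (causal, uses the KV cache)
    // cparams.causal_attn -> the caller wants this context to run causally
    static bool attn_mode_valid(bool model_causal, bool ctx_causal) {
        return model_causal || !ctx_causal; // same condition as the GGML_ASSERT
    }

    int main() {
        assert( attn_mode_valid(true,  true));  // generative model, generation pass
        assert( attn_mode_valid(true,  false)); // generative model, embedding pass (GritLM)
        assert( attn_mode_valid(false, false)); // embedding-only model, embedding pass
        assert(!attn_mode_valid(false, true));  // embedding-only model cannot generate
    }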
@@ -8560,8 +8568,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     } else {
-        // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
+        // when using kv cache, the mask needs to match the kv cache size
         const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
 
         assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
@@ -8580,7 +8589,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     }
                 }
 
-                data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
+                data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
             }
+
+            for (int i = n_tokens; i < n_stride; ++i) {
+                data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+            }
         }
     }
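Note: n_stride exists because a non-causal pass can now run on a model that does have a KV cache (hparams.causal_attn == true), in which case each mask row must span kv_self.n rather than just the batch; the new loop fills the extra columns with -INFINITY so unused cache slots are never attended to. A simplified sketch of the resulting mask layout follows; the function name and the 0.0f stand-in for the per-pair value f are illustrative, not from the commit.

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Build one n_tokens x n_stride non-causal mask block.
    // 0.0f stands in for the per-pair value f computed in the loop above
    // (e.g. -INFINITY for tokens belonging to different sequences).
    std::vector<float> make_noncausal_mask(int64_t n_tokens, int64_t n_stride) {
        std::vector<float> data(n_tokens * n_stride);
        for (int64_t j = 0; j < n_tokens; ++j) {
            for (int64_t i = 0; i < n_tokens; ++i) {
                data[j*n_stride + i] = 0.0f;      // batch token pair: may attend
            }
            for (int64_t i = n_tokens; i < n_stride; ++i) {
                data[j*n_stride + i] = -INFINITY; // KV-cache padding: never attend
            }
        }
        return data;
    }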
@@ -12733,6 +12746,8 @@ struct llama_context * llama_new_context_with_model(
             cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
         }
 
+        cparams.causal_attn = hparams.causal_attn;
+
         if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
             if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
                 cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -13767,6 +13782,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+    ctx->cparams.causal_attn = causal_attn;
+}
+
 struct llama_batch llama_batch_get_one(
              llama_token * tokens,
                   int32_t   n_tokens,
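Note: llama_set_causal_attn is the piece that lets GritLM serve both roles from a single context: non-causal attention for the embedding pass, causal attention for generation. A sketch of the intended call pattern, modeled on the gritlm example added by this PR; the model path, the embeddings flag setup, and the elided tokenize/decode steps are illustrative, not verbatim from the commit.

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);

        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings = true; // embeddings must be enabled for the embedding pass
        llama_context * ctx = llama_new_context_with_model(model, cparams);

        // Embedding pass: GritLM embeds with bidirectional (non-causal) attention.
        llama_set_causal_attn(ctx, false);
        // ... tokenize, llama_decode(), read/pool the embeddings ...

        // Generation pass: flip the same context back to causal attention.
        llama_set_causal_attn(ctx, true);
        // ... tokenize the prompt, llama_decode(), sample tokens ...

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
    }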