From d3085deb2ac33a9d9132295a45d46d377bdc8da9 Mon Sep 17 00:00:00 2001
From: Douglas Hanley
Date: Sat, 9 Mar 2024 22:59:30 -0600
Subject: [PATCH] add causal_attn flag to llama_cparams

---
 examples/gritlm/gritlm.cpp | 21 ++++++++++-----------
 llama.cpp                  | 11 +++++++----
 llama.h                    |  2 +-
 3 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 4abd869ba..13aae5472 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -65,6 +65,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_kv_cache_clear(ctx);
+        llama_set_causal_attn(ctx, false);
 
         // run model
         llama_decode(ctx, batch);
@@ -131,6 +132,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
 
     const llama_model * mdl = llama_get_model(ctx);
     llama_token eos_token = llama_token_eos(mdl);
+
+    llama_kv_cache_clear(ctx);
+    llama_set_causal_attn(ctx, true);
 
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -197,11 +201,8 @@ int main(int argc, char * argv[])
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
     // create new context - set to embedding mode
-    llama_context * embd_ctx = llama_new_context_with_model(mdl, cparams);
-    llama_set_embeddings(embd_ctx, true);
-
-    // create new context - default mode is causal
-    llama_context * causal_ctx = llama_new_context_with_model(mdl, cparams);
+    cparams.embeddings = true;
+    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -219,8 +220,8 @@
     };
 
     // No need to add instruction for retrieval documents
-    std::vector<std::vector<float>> d_rep = encode(embd_ctx, documents, gritlm_instruction(""));
-    std::vector<std::vector<float>> q_rep = encode(embd_ctx, queries, gritlm_instruction(instruction));
+    std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
+    std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
 
     float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
     float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
@@ -237,12 +238,10 @@
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(causal_ctx, prompt, true);
+        std::string response = generate(ctx, prompt, true);
     }
 
-    llama_free(embd_ctx);
-    llama_free(causal_ctx);
-
+    llama_free(ctx);
     llama_free_model(mdl);
 
     llama_backend_free();
diff --git a/llama.cpp b/llama.cpp
index e183a09c0..6cff89fbb 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1683,7 +1683,9 @@ struct llama_cparams {
     float defrag_thold;
 
     bool embeddings;
+    bool causal_attn;
     bool offload_kqv;
+
     enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
@@ -8030,13 +8032,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     GGML_ASSERT(
-        (hparams.causal_attn || cparams.embeddings) &&
+        (hparams.causal_attn || !cparams.causal_attn) &&
         "non-causal attention with generative models is not supported"
     );
 
     // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
     //       But if cparams.embeddings is set, the attention will be non-causal nonetheless.
-    if (!cparams.embeddings) {
+    if (cparams.causal_attn) {
         const int64_t n_kv = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;
 
@@ -12181,6 +12183,7 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    cparams.causal_attn = hparams.causal_attn;
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
             cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -13169,8 +13172,8 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
-void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
-    ctx->cparams.embeddings = embeddings;
+void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+    ctx->cparams.causal_attn = causal_attn;
 }
 
 struct llama_batch llama_batch_get_one(
diff --git a/llama.h b/llama.h
index 0fe7b0105..5377e7f19 100644
--- a/llama.h
+++ b/llama.h
@@ -643,7 +643,7 @@ extern "C" {
 
     // Set whether to use causal attention or not
    // If set to true, the model will only attend to the past tokens
-    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+    LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
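Usage note: the sketch below shows how a caller might drive the new flag, mirroring the patched gritlm example: a single context created in embedding mode, switched to non-causal attention for the encoding pass and back to causal attention for generation. This is a minimal sketch assuming the patch is applied; the model path is a placeholder and the batch/decode steps are elided.

// sketch: toggling causal attention on one llama_context (assumes this patch is applied;
// "model.gguf" is a placeholder path and the encode/generate steps are elided)
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * mdl = llama_load_model_from_file("model.gguf", mparams);

    // one context, created in embedding mode (as the patched gritlm example does)
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings = true;
    llama_context * ctx = llama_new_context_with_model(mdl, cparams);

    // embedding pass: fresh kv cache, non-causal attention
    llama_kv_cache_clear(ctx);
    llama_set_causal_attn(ctx, false);
    // ... build a batch, call llama_decode(), read the embeddings ...

    // generation pass on the same context: switch back to causal attention
    llama_kv_cache_clear(ctx);
    llama_set_causal_attn(ctx, true);
    // ... decode the prompt and sample tokens ...

    llama_free(ctx);
    llama_free_model(mdl);
    llama_backend_free();
    return 0;
}

Keeping one context and flipping cparams.causal_attn at runtime avoids allocating a second kv cache just to alternate between embedding and generation, which is what the gritlm example previously did with separate embd_ctx/causal_ctx contexts.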