add causal_attn flag to llama_cparams
parent 2df2834df3
commit d3085deb2a

3 changed files with 18 additions and 16 deletions

@@ -65,6 +65,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_kv_cache_clear(ctx);
+        llama_set_causal_attn(ctx, false);
 
         // run model
         llama_decode(ctx, batch);
@@ -131,6 +132,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
 
     const llama_model * mdl = llama_get_model(ctx);
     llama_token eos_token = llama_token_eos(mdl);
+
+    llama_kv_cache_clear(ctx);
+    llama_set_causal_attn(ctx, true);
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -197,11 +201,8 @@ int main(int argc, char * argv[])
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
     // create new context - set to embedding mode
-    llama_context * embd_ctx = llama_new_context_with_model(mdl, cparams);
-    llama_set_embeddings(embd_ctx, true);
-
-    // create new context - default mode is causal
-    llama_context * causal_ctx = llama_new_context_with_model(mdl, cparams);
+    cparams.embeddings = true;
+    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -219,8 +220,8 @@ int main(int argc, char * argv[])
         };
 
         // No need to add instruction for retrieval documents
-        std::vector<std::vector<float>> d_rep = encode(embd_ctx, documents, gritlm_instruction(""));
-        std::vector<std::vector<float>> q_rep = encode(embd_ctx, queries, gritlm_instruction(instruction));
+        std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
+        std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
 
         float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
         float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
@@ -237,12 +238,10 @@ int main(int argc, char * argv[])
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(causal_ctx, prompt, true);
+        std::string response = generate(ctx, prompt, true);
     }
 
-    llama_free(embd_ctx);
-    llama_free(causal_ctx);
-
+    llama_free(ctx);
     llama_free_model(mdl);
     llama_backend_free();
 
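
Read together, the hunks above (from the GritLM example, judging by the gritlm_instruction calls and the ContextualAI link) drop the separate embedding and causal contexts: one context is created with embeddings enabled, and llama_set_causal_attn() switches the attention mode per call. A minimal sketch of that flow, using only the API that appears in this diff; the function name and the commented-out decode calls are illustrative, not code from the example:

// Sketch only: condensed from the hunks above, not a drop-in replacement for the example.
#include "llama.h"

static void embed_then_generate(llama_model * mdl, llama_context_params cparams) {
    // one context serves both modes: embeddings are enabled at creation ...
    cparams.embeddings = true;
    llama_context * ctx = llama_new_context_with_model(mdl, cparams);

    // ... encoding runs with causal attention switched off ...
    llama_kv_cache_clear(ctx);
    llama_set_causal_attn(ctx, false);
    // llama_decode(ctx, batch);   // encode() builds the batch and reads the embeddings back

    // ... and generation switches it back on over a cleared kv cache
    llama_kv_cache_clear(ctx);
    llama_set_causal_attn(ctx, true);
    // llama_decode(ctx, batch);   // generate() samples tokens until eos

    llama_free(ctx);
}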

llama.cpp (11 changes)

@@ -1683,7 +1683,9 @@ struct llama_cparams {
     float defrag_thold;
 
     bool embeddings;
+    bool causal_attn;
     bool offload_kqv;
+
     enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
@@ -8030,13 +8032,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     GGML_ASSERT(
-        (hparams.causal_attn || cparams.embeddings) &&
+        (hparams.causal_attn || !cparams.causal_attn) &&
         "non-causal attention with generative models is not supported"
     );
 
     // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
     // But if cparams.embeddings is set, the attention will be non-causal nonetheless.
-    if (!cparams.embeddings) {
+    if (cparams.causal_attn) {
         const int64_t n_kv = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;
 
@@ -12181,6 +12183,7 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    cparams.causal_attn = hparams.causal_attn;
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
             cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -13169,8 +13172,8 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
-void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
-    ctx->cparams.embeddings = embeddings;
+void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+    ctx->cparams.causal_attn = causal_attn;
 }
 
 struct llama_batch llama_batch_get_one(
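
The updated GGML_ASSERT pairs with the default set in llama_new_context_with_model: cparams.causal_attn starts out equal to hparams.causal_attn, and a context may request causal attention only when the model itself is causal. A standalone restatement of that predicate (the function and parameter names below are illustrative, not llama.cpp's):

// Sketch of the invariant asserted in llama_set_inputs after this change;
// only the boolean expression mirrors the diff.
static bool causal_attn_allowed(bool hparams_causal_attn, bool cparams_causal_attn) {
    // false only when a non-causal (embedding-only) model is asked to attend
    // causally; a causal model may run with causal attention on or off.
    return hparams_causal_attn || !cparams_causal_attn;
}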

llama.h (2 changes)

@@ -643,7 +643,7 @@ extern "C" {
 
     // Set whether to use causal attention or not
    // If set to true, the model will only attend to the past tokens
-    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+    LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);