take out attention_type; add in llama_set_embeddings

Author: Douglas Hanley
Date:   2024-06-06 15:11:25 -05:00
Parent: d4e6972f60
Commit: 8093253b41
5 changed files with 19 additions and 39 deletions
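
In short: the construction-time `attention_type` context parameter is removed, and callers instead toggle embeddings and causal attention at runtime on a single context. A minimal sketch of the new calling pattern, mirroring the gritlm changes below (the model path is a placeholder, and batch building/decoding are elided):

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params   mparams = llama_model_default_params();
        llama_context_params cparams = llama_context_default_params();

        // "model.gguf" is a placeholder path
        llama_model   * mdl = llama_load_model_from_file("model.gguf", mparams);
        llama_context * ctx = llama_new_context_with_model(mdl, cparams);

        // embedding pass: return embeddings, attend non-causally
        llama_set_embeddings(ctx, true);
        llama_set_causal_attn(ctx, false);
        // ... build a batch, llama_decode(ctx, batch), read embeddings ...

        // generation pass: flip both switches back on the same context
        llama_set_embeddings(ctx, false);
        llama_set_causal_attn(ctx, true);
        // ... llama_decode(ctx, batch), sample from logits ...

        llama_free(ctx);
        llama_free_model(mdl);
        llama_backend_free();
        return 0;
    }

This is the design change visible in the gritlm example: the separate generation and embedding contexts collapse into one.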

common/common.cpp

@@ -546,17 +546,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
-    if (arg == "--attention") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
-        std::string value(argv[i]);
-        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
-        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
-        else { invalid_param = true; }
-        return true;
-    }
     if (arg == "--defrag-thold" || arg == "-dt") {
         if (++i >= argc) {
             invalid_param = true;
@@ -2460,7 +2449,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
-    cparams.attention_type    = params.attention_type;
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;

common/common.h

@@ -94,7 +94,6 @@ struct gpt_params {
     enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER;        // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
-    enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type
 
     // // sampling parameters
     struct llama_sampling_params sparams;

examples/gritlm/gritlm.cpp

@@ -44,6 +44,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_kv_cache_clear(ctx);
+        llama_set_embeddings(ctx, true);
+        llama_set_causal_attn(ctx, false);
 
         // run model
         llama_decode(ctx, batch);
@@ -97,6 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_token eos_token = llama_token_eos(mdl);
 
     llama_kv_cache_clear(ctx);
+    llama_set_embeddings(ctx, false);
+    llama_set_causal_attn(ctx, true);
+
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -165,13 +170,7 @@ int main(int argc, char * argv[]) {
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
     // create generation context
-    llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);
-
-    // create embedding context
-    cparams.embeddings = true;
-    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-    cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
-    llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);
+    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -189,8 +188,8 @@ int main(int argc, char * argv[]) {
     };
 
     // No need to add instruction for retrieval documents
-    const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
-    const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));
+    const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
+    const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
 
     const int n_embd = llama_n_embd(mdl);
 
@@ -209,11 +208,10 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx_gen, prompt, true);
+        std::string response = generate(ctx, prompt, true);
     }
 
-    llama_free(ctx_gen);
-    llama_free(ctx_emb);
+    llama_free(ctx);
     llama_free_model(mdl);
 
     llama_backend_free();

llama.cpp

@@ -15931,7 +15931,6 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch   =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type      =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
-        /*.attention_type    =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base    =*/ 0.0f,
         /*.rope_freq_scale   =*/ 0.0f,
         /*.yarn_ext_factor   =*/ -1.0f,
@@ -16173,12 +16172,7 @@ struct llama_context * llama_new_context_with_model(
     }
 
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
+    cparams.causal_attn = hparams.causal_attn;
 
-    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
-    } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
-    }
-
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -17914,6 +17908,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }

llama.h

@@ -177,12 +177,6 @@ extern "C" {
         LLAMA_POOLING_TYPE_LAST = 3,
     };
 
-    enum llama_attention_type {
-        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
-        LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
-        LLAMA_ATTENTION_TYPE_NONCAUSAL   = 1,
-    };
-
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -300,7 +294,6 @@ extern "C" {
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
        enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
-        enum llama_attention_type    attention_type;    // causal, non-causal, or unspecified
 
        // ref: https://github.com/ggerganov/llama.cpp/pull/2054
        float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -793,6 +786,10 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
 
+    // Set whether the model is in embeddings mode or not
+    // If true, embeddings will be returned but logits will not
+    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
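
For reference, a hedged sketch of reading results once embeddings mode is enabled. It assumes `ctx`, `mdl`, and a populated `batch` from the sketch near the top, and that a pooling type other than LLAMA_POOLING_TYPE_NONE was set at context creation:

    llama_set_embeddings(ctx, true);
    llama_set_causal_attn(ctx, false);

    if (llama_decode(ctx, batch) != 0) {
        // handle decode failure
    }

    const int     n_embd = llama_n_embd(mdl);               // size of each embedding vector
    const float * emb    = llama_get_embeddings_seq(ctx, 0); // pooled vector for sequence 0

    // with LLAMA_POOLING_TYPE_NONE (as in the gritlm example), read per-token
    // vectors with llama_get_embeddings_ith(ctx, i) instead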