Add embedding mode with arg flag. Currently working (#282)

* working but ugly

* add arg flag; embedding mode not working yet

* typo

* Working! Thanks to @nullhook

* make params argument instead of hardcoded boolean. remove useless time check

* start doing the instructions but not finished. This probably doesn't compile

* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Luciano 2023-03-24 08:05:13 -07:00 committed by GitHub
parent b6b268d441
commit 8d4a855c24
5 changed files with 82 additions and 10 deletions


@@ -53,6 +53,7 @@ extern "C" {
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
+        bool embedding;  // embedding mode only
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
@@ -108,6 +109,10 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
+    // Get the embeddings for the input
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
+
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
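
For context, a minimal sketch of how the new embedding mode would be used from the C API. It assumes the rest of the header as it stood at this commit (llama_init_from_file, llama_tokenize, llama_eval, llama_n_embd, llama_free); the model path, prompt, token buffer size, and thread count are placeholders, not values taken from this change.

// Minimal sketch: extract the embedding vector for a prompt.
// Assumes the llama.h API at this commit; path, prompt, and sizes are placeholders.
#include "llama.h"

#include <stdbool.h>
#include <stdio.h>

int main(void) {
    struct llama_context_params params = llama_context_default_params();
    params.embedding = true; // enable embedding mode (the flag added in this commit)

    struct llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (ctx == NULL) {
        return 1;
    }

    // Tokenize the prompt, prepending BOS as main.cpp does for a fresh prompt
    llama_token tokens[64];
    const int n_tokens = llama_tokenize(ctx, "Hello, world", tokens, 64, true);
    if (n_tokens < 0) {
        llama_free(ctx);
        return 1;
    }

    // Single forward pass over the whole prompt (n_past = 0, 4 threads)
    if (llama_eval(ctx, tokens, n_tokens, 0, 4) != 0) {
        llama_free(ctx);
        return 1;
    }

    // The embeddings for the input: n_embd floats, 1-dimensional
    const float * embd = llama_get_embeddings(ctx);
    const int n_embd = llama_n_embd(ctx);
    for (int i = 0; i < n_embd; i++) {
        printf("%f ", embd[i]);
    }
    printf("\n");

    llama_free(ctx);
    return 0;
}

Because the flag lives in llama_context_params, embedding extraction has to be requested when the context is created; the header shown here only adds a getter, with no way to toggle the mode on an existing context.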