From 859e70899a25abcd08e190890da05c201d9bb72b Mon Sep 17 00:00:00 2001 From: strikingLoo Date: Wed, 22 Mar 2023 17:52:46 -0700 Subject: [PATCH] start doing the instructions but not finished. This probably doesnt compile --- llama.cpp | 16 ++++++++++++++-- llama.h | 5 +++++ main.cpp | 1 + 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 77680f46e..111801e89 100644 --- a/llama.cpp +++ b/llama.cpp @@ -101,6 +101,8 @@ struct llama_context { // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; + // input embedding (1-dimensional array: [n_embd]) + std::vector embedding; bool logits_all = false; }; @@ -112,6 +114,7 @@ struct llama_context_params llama_context_default_params() { /*.f16_kv =*/ false, /*.logits_all =*/ false, /*.vocab_only =*/ false, + /*.embedding =*/ false, }; return result; @@ -127,7 +130,8 @@ static bool llama_model_load( int n_ctx, int n_parts, ggml_type memory_type, - bool vocab_only) { + bool vocab_only, + bool embedding) { fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); const int64_t t_start_us = ggml_time_us(); @@ -594,6 +598,10 @@ static bool llama_model_load( lctx.logits.reserve(lctx.model.hparams.n_ctx); + if (embedding){ + lctx.embedding.reserve(lctx.model.hparams.n_embd); + } + lctx.t_load_us = ggml_time_us() - t_start_us; return true; @@ -1433,7 +1441,7 @@ struct llama_context * llama_init_from_file( ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) { + if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) { fprintf(stderr, "%s: failed to load model\n", __func__); delete ctx; return nullptr; @@ -1508,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } +float * llama_get_embeddings(struct llama_context * ctx) { + return ctx->embedding.data(); +} + const char * llama_token_to_str(struct llama_context * ctx, llama_token token) { if (token >= llama_n_vocab(ctx)) { return nullptr; diff --git a/llama.h b/llama.h index 0fc5438a8..393a896eb 100644 --- a/llama.h +++ b/llama.h @@ -53,6 +53,7 @@ extern "C" { bool f16_kv; // use fp16 for KV cache bool logits_all; // the llama_eval() call computes all logits, not just the last one bool vocab_only; // only load the vocabulary, no weights + bool embedding; // embedding mode only }; LLAMA_API struct llama_context_params llama_context_default_params(); @@ -109,6 +110,10 @@ extern "C" { // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); + // Get the embeddings for the input + // shape: [n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx) + // Token Id -> String. Uses the vocabulary in the provided context LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token); diff --git a/main.cpp b/main.cpp index 44b4cec28..8a639660c 100644 --- a/main.cpp +++ b/main.cpp @@ -199,6 +199,7 @@ int main(int argc, char ** argv) { lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; ctx = llama_init_from_file(params.model.c_str(), lparams);