diff --git a/llama.cpp b/llama.cpp index 111801e89..8a3f8b53f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -101,9 +101,10 @@ struct llama_context { // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; + bool logits_all = false; + // input embedding (1-dimensional array: [n_embd]) std::vector embedding; - bool logits_all = false; }; struct llama_context_params llama_context_default_params() { @@ -114,7 +115,7 @@ struct llama_context_params llama_context_default_params() { /*.f16_kv =*/ false, /*.logits_all =*/ false, /*.vocab_only =*/ false, - /*.embedding =*/ false, + /*.embedding =*/ false, }; return result; @@ -130,8 +131,7 @@ static bool llama_model_load( int n_ctx, int n_parts, ggml_type memory_type, - bool vocab_only, - bool embedding) { + bool vocab_only) { fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); const int64_t t_start_us = ggml_time_us(); @@ -596,29 +596,11 @@ static bool llama_model_load( fin.close(); } - lctx.logits.reserve(lctx.model.hparams.n_ctx); - - if (embedding){ - lctx.embedding.reserve(lctx.model.hparams.n_embd); - } - lctx.t_load_us = ggml_time_us() - t_start_us; return true; } -// Prints the provided embedding vector to stdout -// in a neat format -void display_embedding(const std::vector & embedding_representation){ - fprintf(stdout, "\n[\n"); - for (int j = 0; j < embedding_representation.size()-1 ; j++){ - fprintf(stdout, "%f, ", embedding_representation[j]); - } - fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]); - fprintf(stdout, "\n]\n"); -} - - // evaluate the transformer // // - lctx: llama context @@ -631,8 +613,7 @@ static bool llama_eval_internal( const llama_token * tokens, const int n_tokens, const int n_past, - const int n_threads, - const bool embedding_mode = false) { + const int n_threads) { const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -810,6 +791,9 @@ static bool llama_eval_internal( inpL = cur; } + // used at the end to optionally extract the embeddings + struct ggml_tensor * embeddings = NULL; + // norm { inpL = ggml_rms_norm(ctx0, inpL); @@ -818,18 +802,8 @@ static bool llama_eval_internal( inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm, inpL), inpL); - } - if(embedding_mode){ - // capture input sentence embedding - ggml_build_forward_expand(&gf, inpL); - ggml_graph_compute (ctx0, &gf); - std::vector embedding_representation; - embedding_representation.resize(n_embd); - memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd); - display_embedding(embedding_representation); - ggml_free(ctx0); - return true; + embeddings = inpL; } // lm_head @@ -852,15 +826,26 @@ static bool llama_eval_internal( //embd_w.resize(n_vocab*N); //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); - auto & logits_out = lctx.logits; + // extract logits + { + auto & logits_out = lctx.logits; - if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); - } else { - // return result for just the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + if (lctx.logits_all) { + logits_out.resize(n_vocab * N); + memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); + } else { + // return result for just the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + } + } + + // extract embeddings + if (lctx.embedding.size()) { + auto & embedding_out = lctx.embedding; + + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); } if (mem_per_token == 0) { @@ -1441,12 +1426,26 @@ struct llama_context * llama_init_from_file( ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) { + if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) { fprintf(stderr, "%s: failed to load model\n", __func__); delete ctx; return nullptr; } + // reserve memory for context buffers + { + const auto & hparams = ctx->model.hparams; + if (params.logits_all) { + ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + } else { + ctx->logits.reserve(hparams.n_ctx); + } + + if (params.embedding){ + ctx->embedding.reserve(hparams.n_embd); + } + } + return ctx; } @@ -1474,9 +1473,8 @@ int llama_eval( const llama_token * tokens, int n_tokens, int n_past, - int n_threads, - bool embedding_mode = false) { - if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, embedding_mode)) { + int n_threads) { + if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } diff --git a/llama.h b/llama.h index 393a896eb..209b4dbe8 100644 --- a/llama.h +++ b/llama.h @@ -53,7 +53,7 @@ extern "C" { bool f16_kv; // use fp16 for KV cache bool logits_all; // the llama_eval() call computes all logits, not just the last one bool vocab_only; // only load the vocabulary, no weights - bool embedding; // embedding mode only + bool embedding; // embedding mode only }; LLAMA_API struct llama_context_params llama_context_default_params(); @@ -85,8 +85,7 @@ extern "C" { const llama_token * tokens, int n_tokens, int n_past, - int n_threads, - bool embedding_mode); + int n_threads); // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. @@ -112,7 +111,7 @@ extern "C" { // Get the embeddings for the input // shape: [n_embd] (1-dimensional) - LLAMA_API float * llama_get_embeddings(struct llama_context * ctx) + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); // Token Id -> String. Uses the vocabulary in the provided context LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token); diff --git a/main.cpp b/main.cpp index 8a639660c..dd2709f54 100644 --- a/main.cpp +++ b/main.cpp @@ -98,7 +98,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) { int end = start + params.n_ctx - 1; std::vector embd(tokens.begin() + start, tokens.begin() + end); auto start_t = std::chrono::high_resolution_clock::now(); - if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads, false)) { + if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return; } @@ -220,7 +220,7 @@ int main(int argc, char ** argv) { // TODO: better way to do that { const std::vector tmp = { 0, 1, 2, 3 }; - llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, false); + llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); } if (params.perplexity) { @@ -302,7 +302,7 @@ int main(int argc, char ** argv) { #endif " - Press Return to return control to LLaMa.\n" " - If you want to submit another line, end your input in '\\'.\n\n"); - is_interacting = params.interactive_start; + is_interacting = params.interactive_start || params.instruct; } int input_consumed = 0; @@ -325,23 +325,29 @@ int main(int argc, char ** argv) { if (params.embedding){ embd = embd_inp; + if (embd.size() > 0) { - if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) { + if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } } + const auto embeddings = llama_get_embeddings(ctx); + + // TODO: print / use the embeddings + if (params.use_color) { printf(ANSI_COLOR_RESET); } + return 0; } while (remaining_tokens > 0 || params.interactive) { // predict if (embd.size() > 0) { - if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) { + if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; }