From 801071ec4f8d275cfff858057754fe5b5e1c2140 Mon Sep 17 00:00:00 2001
From: strikingLoo
Date: Sat, 18 Mar 2023 23:34:20 -0700
Subject: [PATCH] add arg flag, not working on embedding mode

---
 main.cpp  | 122 ++++++++++++++++++++++++++++++++----------------------
 utils.cpp |   2 +
 utils.h   |   2 +-
 3 files changed, 75 insertions(+), 51 deletions(-)

diff --git a/main.cpp b/main.cpp
index 2adeb5045..4c2f85e23 100644
--- a/main.cpp
+++ b/main.cpp
@@ -519,6 +519,17 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     return true;
 }
 
+// Prints the provided embedding vector to stdout
+// in a neat format
+void display_embedding(const std::vector<float> & embedding_representation){
+    fprintf(stdout, "\n[\n");
+    for (int j = 0; j < embedding_representation.size()-1 ; j++){
+        fprintf(stdout, "%f, ", embedding_representation[j]);
+    }
+    fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
+    fprintf(stdout, "\n]\n");
+}
+
 // evaluate the transformer
 //
 //   - model:     the model
@@ -535,7 +546,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
               std::vector<float>         & embd_w,
-              size_t                     & mem_per_token) {
+              size_t                     & mem_per_token,
+              const bool embedding_mode) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -720,56 +732,52 @@ bool llama_eval(
                     ggml_repeat(ctx0, model.norm, inpL),
                     inpL);
     }
-
-    // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute       (ctx0, &gf);
-
-    // capture input sentence embedding
-    {
-        std::vector<float> embedding_representation;
-        embedding_representation.resize(n_embd);
-        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 2)), sizeof(float) * n_embd);
-        fprintf(stdout, "\n[\n");
-        for (int j = 0; j < embedding_representation.size()-1 ; j++){
-            fprintf(stdout, "%f, ", embedding_representation[j]);
+
+    if(!embedding_mode){
+        // lm_head
+        {
+            inpL = ggml_mul_mat(ctx0, model.output, inpL);
         }
-        fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
-        fprintf(stdout, "\n]\n");
+        // logits -> probs
+        //inpL = ggml_soft_max(ctx0, inpL);
+        // run the computation
+        ggml_build_forward_expand(&gf, inpL);
+        ggml_graph_compute       (ctx0, &gf);
+
+        //if (n_past%100 == 0) {
+        //    ggml_graph_print   (&gf);
+        //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+        //}
+
+        //embd_w.resize(n_vocab*N);
+        //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        if (mem_per_token == 0) {
+            mem_per_token = ggml_used_mem(ctx0)/N;
+        }
+        //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+
+        ggml_free(ctx0);
+        return true;
+    } else {
+        // capture input sentence embedding
+        ggml_build_forward_expand(&gf, inpL);
+        ggml_graph_compute       (ctx0, &gf);
+        printf("Compute went ok\n");
+        std::vector<float> embedding_representation;
+        embedding_representation.resize(n_embd);
+        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
+        printf("About to display\n");
+        display_embedding(embedding_representation);
+        printf("About to free\n");
+        ggml_free(ctx0);
+        return true;
     }
-
-    // lm_head
-    {
-        inpL = ggml_mul_mat(ctx0, model.output, inpL);
-    }
-
-    // logits -> probs
-    //inpL = ggml_soft_max(ctx0, inpL);
-
-
-
-    //if (n_past%100 == 0) {
-    //    ggml_graph_print   (&gf);
-    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
-    //}
-
-    //embd_w.resize(n_vocab*N);
-    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
-
-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
-
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
-    }
-    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
-
-    ggml_free(ctx0);
-
-    return true;
 }
 
 static bool is_interacting = false;
@@ -906,13 +914,12 @@ int main(int argc, char ** argv) {
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, false);
 
     int last_n_size = params.repeat_last_n;
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
-
     if (params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -936,12 +943,27 @@ int main(int argc, char ** argv) {
         printf(ANSI_COLOR_YELLOW);
     }
 
+    if (params.embedding){
+        printf("got right before second call.\n");
+        const int64_t t_start_us = ggml_time_us(); //HERE
+        if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
+            fprintf(stderr, "Failed to predict\n");
+            return 1;
+        }
+        //ggml_free(model.ctx);
+
+        if (params.use_color) {
+            printf(ANSI_COLOR_RESET);
+        }
+        return 0;
+    }
+
     while (remaining_tokens > 0) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, false)) {
                 fprintf(stderr, "Failed to predict\n");
                 return 1;
             }
diff --git a/utils.cpp b/utils.cpp
index aa3ad1053..20b6a86ce 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -53,6 +53,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.model = argv[++i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
+        } else if (arg == "--embedding") {
+            params.embedding = true;
         } else if (arg == "--interactive-start") {
             params.interactive = true;
             params.interactive_start = true;
diff --git a/utils.h b/utils.h
index 021120b05..dca497d06 100644
--- a/utils.h
+++ b/utils.h
@@ -31,7 +31,7 @@ struct gpt_params {
     std::string prompt;
 
     bool use_color = false; // use color to distinguish generations and inputs
-
+    bool embedding = false; // get only sentence embedding
     bool interactive = false; // interactive mode
     bool interactive_start = false; // reverse prompt immediately
     std::string antiprompt = ""; // string upon seeing which more user input is prompted
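
The embedding branch added in llama_eval skips the lm_head projection: after the final norm it runs the graph, copies the n_embd floats of the last token's row out of ggml_get_data(inpL), and prints them via display_embedding. The standalone sketch below illustrates only that extraction step on a fabricated buffer; the file name, the N and n_embd values, and the fill pattern are invented for illustration, while in the patch the buffer comes from ggml_get_data(inpL) and n_embd from the model hparams.

    // embedding_extract_sketch.cpp: illustrative only, not part of the patch.
    // Mirrors the memcpy in the embedding branch of llama_eval, using a plain
    // std::vector in place of the ggml tensor data.
    #include <cstdio>
    #include <cstring>
    #include <vector>

    // Same role as display_embedding in the patch: print "[ a, b, ..., z ]".
    static void display_embedding(const std::vector<float> & v) {
        std::printf("\n[\n");
        for (size_t j = 0; j + 1 < v.size(); j++) {
            std::printf("%f, ", v[j]);
        }
        if (!v.empty()) {
            std::printf("%f", v.back());
        }
        std::printf("\n]\n");
    }

    int main() {
        const int N      = 4; // tokens evaluated in this call
        const int n_embd = 8; // embedding width (4096 for LLaMA-7B; 8 keeps the output short)

        // Stand-in for ggml_get_data(inpL): N rows of n_embd floats, row-major.
        std::vector<float> inpL(N * n_embd);
        for (int i = 0; i < N * n_embd; i++) {
            inpL[i] = 0.01f * i;
        }

        // Copy the last token's row, as the embedding branch does:
        //   memcpy(dst, (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
        std::vector<float> embedding(n_embd);
        std::memcpy(embedding.data(), inpL.data() + n_embd * (N - 1), sizeof(float) * n_embd);

        display_embedding(embedding);
        return 0;
    }

With the patch applied, the new mode would be exercised with something along the lines of ./main -m <model> -p "your sentence" --embedding, where the model path and prompt are placeholders; per the commit subject, this path is still marked as not working.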