From 3839a08ceed7f36093e7fdf50d4dd780c0c4c667 Mon Sep 17 00:00:00 2001
From: Thomas Antony
Date: Thu, 16 Mar 2023 21:03:23 -0700
Subject: [PATCH] Refactor llama.cpp and llama.h

---
 llama.cpp | 410 ++++++++++++++++++++++++++++++++----------------------
 llama.h   |  32 ++++-
 2 files changed, 267 insertions(+), 175 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 65973eb46..bab9da93c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -72,33 +72,36 @@ struct llama_model {
 };

 struct llama_state {
-    int64_t t_sample_us = 0;
-    int64_t t_predict_us = 0;
+    // Timers
+    struct timing {
+        int64_t t_load_us = 0;

-    std::vector<float> logits;
+        int64_t t_sample_us = 0;
+        int64_t t_predict_us = 0;
+    } timing;

-    mutable std::mt19937 rng;
+    // Random number generator
+    std::mt19937 rng{};

+    // Tokens
     std::vector<gpt_vocab::id> embd{};
+    std::vector<gpt_vocab::id> embd_inp{};
+    std::vector<gpt_vocab::id> last_n_tokens{};

+    // Logits from inference
+    std::vector<float> logits{};
+
+    // Counters
     int input_consumed = 0;
-    std::vector<gpt_vocab::id> embd_inp;
-    std::vector<gpt_vocab::id> last_n_tokens;
     int remaining_tokens = 0;
     int n_past = 0;
     size_t mem_per_token = 0;
-    bool is_initialized = false;
-    llama_state() {}
-    bool has_more_input() const {
-        return input_consumed < embd_inp.size();
-    }
+    // Flag set after initialization
+    bool is_initialized = false;
 };

 struct llama_context
 {
-    int64_t t_load_us = 0;
-    int64_t t_start_us = 0;
-
     ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)

     llama_model model{};
@@ -111,8 +114,6 @@ struct llama_context
     llama_context() = default;
     // constructor
     llama_context(llama_model&& model, gpt_vocab&& vocab, const gpt_params& params):
-        t_load_us(0),
-        t_start_us(0),
         wtype(ggml_type::GGML_TYPE_F16),
         model(std::move(model)),
         vocab(std::move(vocab)),
@@ -125,6 +126,7 @@ struct llama_context
     }
 };

+/* Original code by @ggerganov */
 // load the model's weights from a file
 bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
@@ -806,93 +808,6 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
     return true;
 }
-
-
-/* External API */
-
-const std::vector<gpt_vocab::id>& llama_context_get_embedding(const llama_context& ctx) {
-    return ctx.state->embd;
-}
-gpt_vocab& llama_context_get_vocab(llama_context& ctx) {
-    return ctx.vocab;
-}
-bool llama_context_is_finished(const llama_context& ctx)
-{
-    return ctx.state->remaining_tokens <= 0;
-}
-const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text) {
-    return llama_tokenize(ctx.vocab, text, true);
-}
-const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx) {
-    return ctx.state->last_n_tokens;
-}
-llama_context* llama_init_from_params(const gpt_params& params) {
-    llama_model model{};
-    gpt_vocab vocab{};
-
-    // Compute time taken to load model
-    const int64_t t_start = ggml_time_us();
-    bool ret = llama_model_load(params.model, model, vocab, 1024);
-    const int64_t t_end = ggml_time_us();
-    if(!ret)
-    {
-        return nullptr;
-    }
-    llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params);
-    ctx->t_load_us = t_end - t_start;
-    return ctx;
-}
-void llama_free_context(llama_context* ctx) {
-    delete ctx;
-}
-
-const char * llama_print_system_info(void) {
-    static std::string s;
-
-    s = "";
-    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
-    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
-    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
"FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; - s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; - s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; - s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; - s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; - - return s.c_str(); -} - -void llama_print_context_info(const llama_context& ctx) -{ - const gpt_params& params = ctx.params; - const std::vector& embd_inp = ctx.state->embd_inp; - { - fprintf(stderr, "\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); - } - fprintf(stderr, "\n"); - fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); - for (int i = 0; i < (int) embd_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], ctx.vocab.id_to_token.at(embd_inp[i]).c_str()); - } - fprintf(stderr, "\n"); -} - -void llama_print_end_stats(const llama_context& ctx) -{ - const llama_state& state = *ctx.state; - fprintf(stderr, "\n\n"); - fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx.t_load_us/1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, state.t_sample_us/1000.0f); - fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, state.t_predict_us/1000.0f, state.t_predict_us/1000.0f/state.n_past); -} // evaluate the transformer // // - model: the model @@ -1129,25 +1044,56 @@ bool llama_eval( return true; } -bool llama_update_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing) { +const char * llama_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} + +/* External API */ +/// @brief Initialize the context from a set of parameters +/// @param params +llama_context* llama_init_from_params(const gpt_params& params) { + llama_model model{}; + gpt_vocab vocab{}; + + // Compute time taken to load model + const int64_t t_start = ggml_time_us(); + bool ret = llama_model_load(params.model, model, vocab, 1024); + const int64_t t_end = ggml_time_us(); + if(!ret) + { + return nullptr; + } + llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params); + ctx->state->timing.t_load_us = t_end - t_start; + ctx->state->rng = 
+    return ctx;
+}
+/// @brief Prepare the context for inference
+/// @param ctx
+bool llama_prepare_context(llama_context& ctx)
+{
     llama_state& state = *ctx.state;
     llama_model& model = ctx.model;
-    const gpt_params& params = ctx.params;
-
-    if (clear_existing) {
-        state.embd.clear();
-        state.input_consumed = 0;
-        state.embd_inp.clear();
-        state.last_n_tokens.clear();
-        state.remaining_tokens = 0;
-        state.n_past = 0;
-    }
-
-    std::vector<gpt_vocab::id> line_inp = llama_tokenize_text(ctx, text);
-    state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end());
+    gpt_params& params = ctx.params;

     int n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) state.embd_inp.size());
-    state.remaining_tokens = n_predict;
+    params.n_predict = n_predict;

     // determine the required inference memory per token:
     state.mem_per_token = 0;
@@ -1160,10 +1106,76 @@ bool llama_update_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing) {
     int last_n_size = params.repeat_last_n;
     state.last_n_tokens = std::vector<gpt_vocab::id>(last_n_size);
     std::fill(state.last_n_tokens.begin(), state.last_n_tokens.end(), 0);
-
     state.is_initialized = true;
+    state.remaining_tokens = params.n_predict;
+    state.input_consumed = 0;
     return true;
 }
+/// @brief Free the context
+void llama_free_context(llama_context* ctx) {
+    delete ctx;
+}
+
+/* Getters and setters */
+/// @brief Get the embedding vector for the last token
+const std::vector<gpt_vocab::id>& llama_context_get_embedding(const llama_context& ctx) {
+    return ctx.state->embd;
+}
+/// @brief Get the vocabulary used by the context
+/// @param ctx
+gpt_vocab& llama_context_get_vocab(llama_context& ctx) {
+    return ctx.vocab;
+}
+/// @brief Is the context finished?
+/// @param ctx
+bool llama_context_is_finished(const llama_context& ctx)
+{
+    return ctx.state->remaining_tokens <= 0;
+}
+/// @brief Reset the remaining_tokens counter to the value specified in the gpt_params
+/// @param ctx
+void llama_reset_remaining_tokens(const llama_context& ctx)
+{
+    ctx.state->remaining_tokens = ctx.params.n_predict;
+}
+
+/// @brief Tokenize a text into a vector of ids
+/// @param ctx
+/// @param text
+const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text) {
+    // Make sure that the "beginning of string" token is not prefixed to the text
+    return llama_tokenize(ctx.vocab, text, false);
+}
+const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx) {
+    return ctx.state->last_n_tokens;
+}
+
+/// @brief Adds the "beginning of string" token to the model input
+/// @param ctx
+void llama_add_bos(llama_context& ctx){
+    // Add the "bos" token into the model input
+    llama_state& state = *ctx.state;
+    llama_model& model = ctx.model;
+    const gpt_params& params = ctx.params;
+
+    const gpt_vocab::id bos_token = 1;
+    state.embd_inp.push_back(bos_token);
+}
+
+/// @brief Updates the context and appends new input text
+/// @param ctx
+/// @param text
+void llama_update_input(llama_context& ctx, const std::string& text)
+{
+    llama_state& state = *ctx.state;
+    llama_model& model = ctx.model;
+    const gpt_params& params = ctx.params;
+
+    std::vector<gpt_vocab::id> line_inp = llama_tokenize_text(ctx, text);
+
+    state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end());
+    state.remaining_tokens -= line_inp.size();
+}

 /// @brief Ingests a batch of input tokens into the context
 /// @param ctx
@@ -1173,7 +1185,7 @@ void llama_ingest_input_batch(llama_context& ctx)
     const gpt_params& params = ctx.params;

     // Copy at most n_batch elements from embd_inp to embd
-    size_t num_copied = std::min((size_t) params.n_batch, state.embd_inp.size() - state.input_consumed);
+    size_t num_copied = std::min((size_t) params.n_batch+1, state.embd_inp.size() - state.input_consumed);
     std::copy(state.embd_inp.begin() + state.input_consumed,
               state.embd_inp.begin() + state.input_consumed + num_copied,
               std::back_inserter(state.embd));
@@ -1185,21 +1197,60 @@ void llama_ingest_input_batch(llama_context& ctx)
     state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end());
 }

-/// @brief Run the prediction step on ctx.embd and store result in ctx.state.logits
+/// @brief Returns true if there is unconsumed input in the context
 /// @param ctx
-/// @return
-bool llama_predict(llama_context& ctx){
-    const int64_t t_start_us = ggml_time_us();
+bool llama_has_unconsumed_input(llama_context& ctx)
+{
+    llama_state& state = *ctx.state;
+    return state.input_consumed < state.embd_inp.size();
+}
+
+/// @brief Ingest all input (in multiple batches) into the model and call predict()
+/// @param ctx
+bool llama_ingest_all_pending_input(llama_context& ctx, bool print_tokens)
+{
+    llama_state& state = *ctx.state;
+    const std::vector<gpt_vocab::id>& embd = state.embd;
+    gpt_vocab& vocab = ctx.vocab;
+
+    if(!state.is_initialized)
+    {
+        fprintf(stderr, "Context must be initialized before ingesting input");
+        return false;
+    }
+
+    // ingest the tokens into the model one batch at a time
+    while (llama_has_unconsumed_input(ctx))
+    {
+        llama_ingest_input_batch(ctx);
+        if (print_tokens) {
+            std::string s = llama_tokens_to_string(vocab, embd);
+            printf("%s", s.c_str());
+            fflush(stdout);
+        }
+        llama_eval_model(ctx);
+    }
+    return true;
+}
+/// @brief Evaluate the model with the current input batch
+/// @param ctx
+bool llama_eval_model(llama_context& ctx)
+{
     llama_state& state = *ctx.state;
     llama_model& model = ctx.model;
     const gpt_params& params = ctx.params;
-    if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
-        fprintf(stderr, "Failed to predict\n");
-        return false;
-    }
+    if (state.embd.size() > 0) {
+        const int64_t t_start_us = ggml_time_us();

-    state.t_predict_us += ggml_time_us() - t_start_us;
+        if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
+            fprintf(stderr, "Failed to predict\n");
+            return false;
+        }
+        state.timing.t_predict_us += ggml_time_us() - t_start_us;
+    }
+    state.n_past += state.embd.size();
+    state.embd.clear();
     return true;
 }

 /// @brief Sample a token from the logits
@@ -1229,37 +1280,13 @@ gpt_vocab::id llama_sample_token(llama_context& ctx)
         state.last_n_tokens.erase(state.last_n_tokens.begin());
         state.last_n_tokens.push_back(id);

-        state.t_sample_us += ggml_time_us() - t_start_sample_us;
+        state.timing.t_sample_us += ggml_time_us() - t_start_sample_us;
     }

     return id;
 }
-/// @brief Ingest all input (in multiple batches) into model and run call predict()
-/// @param ctx
-bool llama_ingest_input(llama_context& ctx, const std::string& text, bool clear_existing)
-{
-    llama_state& state = *ctx.state;
-
-    // Initialize context, tokenize text and clear existing state if necessary
-    if(!state.is_initialized && !llama_update_context_with_prompt(ctx, text, clear_existing))
-    {
-        return false;
-    }
-
-    // ingest the tokens into the model one batch at a time
-    while (state.has_more_input())
-    {
-        llama_ingest_input_batch(ctx);
-        if (state.embd.size() >= 0) {
-            if(!llama_predict(ctx))
-            {
-                return false;
-            };
-        }
-        state.n_past += state.embd.size();
-        state.embd.clear();
-    }
-    return true;
-}
+/// @brief Run inference for one token and return the token id
+/// @param ctx
+/// @param id
 bool llama_infer(llama_context& ctx, gpt_vocab::id& id) {
     llama_state& state = *ctx.state;
@@ -1275,25 +1302,68 @@ bool llama_infer(llama_context& ctx, gpt_vocab::id& id) {
         return false;
     }

-    // Do prediction if we have enough tokens
-    if (state.embd.size() > 0) {
-        if(!llama_predict(ctx))
-        {
-            return false;
-        }
-    }
-
-    // sample a token
+    // Already predicted, so we just need to sample
+    // sample a token
     id = llama_sample_token(ctx);

+    // add it to the context
     state.embd.push_back(id);
-    state.n_past += 1;

     // decrement remaining sampling budget
     --state.remaining_tokens;

-    // end of text token
-    if (state.embd.back() == 2) {
-        state.remaining_tokens = 0;
-    }
     return true;
 }
+/// @brief Run inference for one token and return the token as a string
+/// @param ctx
+/// @param output
+/// @param is_end_of_text
+bool llama_infer(llama_context& ctx, std::string& output, bool& is_end_of_text) {
+    // Call overloaded llama_infer and convert to string before returning
+    gpt_vocab::id id_int;
+    is_end_of_text = false;
+    if(!llama_infer(ctx, id_int)){
+        return false;
+    }
+
+    // Pass through the "end of text" token to the user
+    is_end_of_text = (id_int == EOS_TOKEN_ID);
+
+    // Make sure to pass in the newly generated token to the model as well
+    llama_eval_model(ctx);
+    output = ctx.vocab.id_to_token.at(id_int);
+    return true;
+}
+bool llama_is_anti_prompt_present(llama_context& ctx, const std::vector<gpt_vocab::id>& antiprompt_inp)
+{
+    llama_state& state = *ctx.state;
+    return std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), state.last_n_tokens.rbegin());
+}
+
+void llama_print_startup_stats(const llama_context& ctx)
+{
+    const gpt_params& params = ctx.params;
+    const std::vector<gpt_vocab::id>& embd_inp = ctx.state->embd_inp;
+    {
+        fprintf(stderr, "\n");
+        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
"system_info: n_threads = %d / %d | %s\n", + params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + } + fprintf(stderr, "\n"); + fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + for (int i = 0; i < (int) embd_inp.size(); i++) { + fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], ctx.vocab.id_to_token.at(embd_inp[i]).c_str()); + } + fprintf(stderr, "\n"); +} + +void llama_print_end_stats(const llama_context& ctx) +{ + const llama_state& state = *ctx.state; + fprintf(stderr, "\n\n"); + fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, state.timing.t_load_us/1000.0f); + fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, state.timing.t_sample_us/1000.0f); + fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, state.timing.t_predict_us/1000.0f, state.timing.t_predict_us/1000.0f/state.n_past); +} diff --git a/llama.h b/llama.h index b9cdeeecd..adbfed173 100644 --- a/llama.h +++ b/llama.h @@ -22,7 +22,7 @@ # define LLAMA_API #endif - +static const int EOS_TOKEN_ID = 2; // default hparams (LLaMA 7B) @@ -41,13 +41,35 @@ struct llama_context; // Startup llama_context* llama_init_from_params(const gpt_params& params); +bool llama_prepare_context(llama_context& ctx); // Input processing and inference -bool llama_ingest_input(llama_context& ctx, const std::string& text, bool clear_existing = true); -bool llama_context_is_finished(const llama_context& ctx); -bool llama_update_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing = true); +// Tokenize text (never adds BOS) const std::vector llama_tokenize_text(const llama_context& ctx, const std::string& text); +// Queues up a BOS token to the model input +void llama_add_bos(llama_context& ctx); +// Queues up input text to the model input +void llama_update_input(llama_context& ctx, const std::string& text); +// Ingests input previously added using llama_update_input() +void llama_ingest_input_batch(llama_context& ctx); +// Ingests all input previously added using llama_update_input() in multiple batches +// Batch size is determined by gpt_params::n_predict +bool llama_ingest_all_pending_input(llama_context& ctx, bool print_tokens = false); +// Checks if the model has unconsumed input to be ingested using llama_ingest_input_batch() +bool llama_has_unconsumed_input(llama_context& ctx); +// Checks if the model has an anti-prompt present its most recent output +bool llama_is_anti_prompt_present(llama_context& ctx, const std::vector& antiprompt_inp); + +// Evaluate the model on a batch of input. Must call llama_ingest_input_batch() first. +bool llama_eval_model(llama_context& ctx); +// Checks if the model has finished generating output (i.e. 
+bool llama_context_is_finished(const llama_context& ctx);
+// Resets the remaining_tokens counter to the value specified in the gpt_params
+void llama_reset_remaining_tokens(const llama_context& ctx);
+
+// Overloaded functions to run inference and return either the model output or the decoded text
 bool llama_infer(llama_context& ctx, gpt_vocab::id& model_output);
+bool llama_infer(llama_context& ctx, std::string& output, bool& is_end_of_text);

 // Teardown
 void llama_free_context(llama_context* ctx);
@@ -61,5 +83,5 @@ const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx);
 bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype);

 // Stats
-void llama_print_context_info(const llama_context& ctx);
+void llama_print_startup_stats(const llama_context& ctx);
 void llama_print_end_stats(const llama_context& ctx);
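-- 
Illustrative usage sketch (not part of the diff): one plausible way a caller could wire the
refactored API together. It assumes the gpt_params fields `model`, `prompt` and `n_predict`
from the existing utils.h; the model path below is a placeholder.

#include <cstdio>
#include <string>
#include "llama.h"
#include "utils.h"   // gpt_params, gpt_vocab (assumed header location)

int main() {
    gpt_params params;
    params.model  = "./models/7B/ggml-model-q4_0.bin";  // placeholder path
    params.prompt = "The quick brown fox";

    llama_context* ctx = llama_init_from_params(params);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // Queue the prompt, then size buffers and counters before running.
    llama_add_bos(*ctx);
    llama_update_input(*ctx, params.prompt);
    llama_prepare_context(*ctx);

    // Feed the whole prompt through the model in n_batch-sized chunks.
    llama_ingest_all_pending_input(*ctx, /*print_tokens=*/true);

    // Sample one token at a time until EOS or the n_predict budget runs out.
    while (!llama_context_is_finished(*ctx)) {
        std::string token;
        bool is_end_of_text = false;
        if (!llama_infer(*ctx, token, is_end_of_text) || is_end_of_text) {
            break;
        }
        printf("%s", token.c_str());
        fflush(stdout);
    }

    llama_print_end_stats(*ctx);
    llama_free_context(ctx);
    return 0;
}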