From 98e6651e2f9d1543b764ecf08cc6788e1763a4fc Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 29 Nov 2024 16:48:52 +0100 Subject: [PATCH] add more function into llama-cpp.h --- examples/run/run.cpp | 81 ++++++++++++++++++++------------------------ include/llama-cpp.h | 47 +++++++++++++++++++++++-- src/llama.cpp | 61 +++++++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 48 deletions(-) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index cac2faefc..511e193ac 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -111,9 +111,9 @@ class ArgumentParser { class LlamaData { public: - llama_model_ptr model; - llama_sampler_ptr sampler; - llama_context_ptr context; + llama_cpp::model model; + llama_cpp::sampler sampler; + llama_cpp::context context; std::vector messages; int init(const Options & opt) { @@ -133,11 +133,11 @@ class LlamaData { private: // Initializes the model and returns a unique pointer to it - llama_model_ptr initialize_model(const std::string & model_path, const int ngl) { + llama_cpp::model initialize_model(const std::string & model_path, const int ngl) { llama_model_params model_params = llama_model_default_params(); model_params.n_gpu_layers = ngl; - llama_model_ptr model(llama_load_model_from_file(model_path.c_str(), model_params)); + llama_cpp::model model(llama_cpp::load_model_from_file(model_path, model_params)); if (!model) { fprintf(stderr, "%s: error: unable to load model\n", __func__); } @@ -146,12 +146,12 @@ class LlamaData { } // Initializes the context with the specified parameters - llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) { + llama_cpp::context initialize_context(const llama_cpp::model & model, const int n_ctx) { llama_context_params ctx_params = llama_context_default_params(); ctx_params.n_ctx = n_ctx; ctx_params.n_batch = n_ctx; - llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params)); + llama_cpp::context context(llama_cpp::new_context_with_model(model, ctx_params)); if (!context) { fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__); } @@ -160,8 +160,8 @@ class LlamaData { } // Initializes and configures the sampler - llama_sampler_ptr initialize_sampler() { - llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); + llama_cpp::sampler initialize_sampler() { + llama_cpp::sampler sampler(llama_cpp::sampler_chain_init(llama_sampler_chain_default_params())); llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1)); llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f)); llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); @@ -179,34 +179,20 @@ static void add_message(const char * role, const std::string & text, LlamaData & owned_content.push_back(std::move(content)); } -// Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const LlamaData & llama_data, std::vector & formatted, const bool append) { - int result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), - llama_data.messages.size(), append, formatted.data(), formatted.size()); - if (result > static_cast(formatted.size())) { - formatted.resize(result); - result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), - llama_data.messages.size(), append, formatted.data(), formatted.size()); - } - - return result; -} - // Function to tokenize the prompt -static int tokenize_prompt(const llama_model_ptr & model, const std::string & prompt, +static int tokenize_prompt(const llama_cpp::model & model, const std::string & prompt, std::vector & prompt_tokens) { - const int n_prompt_tokens = -llama_tokenize(model.get(), prompt.c_str(), prompt.size(), NULL, 0, true, true); - prompt_tokens.resize(n_prompt_tokens); - if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, - true) < 0) { - GGML_ABORT("failed to tokenize the prompt\n"); + try { + prompt_tokens = llama_cpp::tokenize(model, prompt, false, true); + return prompt_tokens.size(); + } catch (const std::exception & e) { + fprintf(stderr, "failed to tokenize the prompt: %s\n", e.what()); + return -1; } - - return n_prompt_tokens; } // Check if we have enough space in the context to evaluate this batch -static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { +static int check_context_size(const llama_cpp::context & ctx, const llama_batch & batch) { const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); if (n_ctx_used + batch.n_tokens > n_ctx) { @@ -219,15 +205,14 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch & } // convert the token to a string -static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) { - char buf[256]; - int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); - if (n < 0) { - GGML_ABORT("failed to convert token to piece\n"); +static int convert_token_to_string(const llama_cpp::model & model, const llama_token token_id, std::string & piece) { + try { + piece = llama_cpp::token_to_piece(model, token_id, 0, true); + return 0; + } catch (const std::exception & e) { + fprintf(stderr, "failed to convert token to piece: %s\n", e.what()); + return -1; } - - piece = std::string(buf, n); - return 0; } static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) { @@ -308,14 +293,20 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, // Helper function to apply the chat template and handle errors static int apply_chat_template_with_error_handling(const LlamaData & llama_data, std::vector & formatted, const bool is_user_input, int & output_length) { - const int new_len = apply_chat_template(llama_data, formatted, is_user_input); - if (new_len < 0) { - fprintf(stderr, "failed to apply the chat template\n"); + try { + std::string res = llama_cpp::chat_apply_template( + llama_data.model, + "", + llama_data.messages, + is_user_input); + output_length = res.size(); + formatted.resize(output_length); + std::memcpy(formatted.data(), res.c_str(), output_length); + return output_length; + } catch (const std::exception & e) { + fprintf(stderr, "failed to apply chat template: %s\n", e.what()); return -1; } - - output_length = new_len; - return 0; } // Helper function to handle user input diff --git a/include/llama-cpp.h b/include/llama-cpp.h index daa04d4d8..e5f9613f5 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -5,9 +5,12 @@ #endif #include +#include #include "llama.h" +namespace llama_cpp { + struct llama_model_deleter { void operator()(llama_model * model) { llama_free_model(model); } }; @@ -20,6 +23,44 @@ struct llama_sampler_deleter { void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } }; -typedef std::unique_ptr llama_model_ptr; -typedef std::unique_ptr llama_context_ptr; -typedef std::unique_ptr llama_sampler_ptr; +typedef std::unique_ptr model; +typedef std::unique_ptr context; +typedef std::unique_ptr sampler; + +inline model load_model_from_file(const std::string & path_model, llama_model_params params) { + return model(llama_load_model_from_file(path_model.c_str(), params)); +} + +inline context new_context_with_model(const model & model, llama_context_params params) { + return context(llama_new_context_with_model(model.get(), params)); +} + +inline sampler sampler_chain_init(llama_sampler_chain_params params) { + return sampler(llama_sampler_chain_init(params)); +} + +std::vector tokenize( + const llama_cpp::model & model, + const std::string & raw_text, + bool add_special, + bool parse_special = false); + +std::string token_to_piece( + const llama_cpp::model & model, + llama_token token, + int32_t lstrip, + bool special); + +std::string detokenize( + const llama_cpp::model & model, + const std::vector & tokens, + bool remove_special, + bool unparse_special); + +std::string chat_apply_template( + const llama_cpp::model & model, + const std::string & tmpl, + const std::vector & chat, + bool add_ass); + +} // namespace llama_cpp diff --git a/src/llama.cpp b/src/llama.cpp index 22b951ba2..3d42c3469 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,3 +1,4 @@ +#include "llama-cpp.h" #include "llama-impl.h" #include "llama-vocab.h" #include "llama-sampling.h" @@ -21818,6 +21819,14 @@ int32_t llama_tokenize( return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special); } +std::vector llama_cpp::tokenize( + const llama_cpp::model & model, + const std::string & raw_text, + bool add_special, + bool parse_special) { + return llama_tokenize_internal(model->vocab, raw_text, add_special, parse_special); +} + int32_t llama_token_to_piece( const struct llama_model * model, llama_token token, @@ -21828,6 +21837,23 @@ int32_t llama_token_to_piece( return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special); } +std::string llama_cpp::token_to_piece( + const llama_cpp::model & model, + llama_token token, + int32_t lstrip, + bool special) { + std::vector buf(64); + int32_t n = llama_token_to_piece_impl(model->vocab, token, buf.data(), buf.size(), lstrip, special); + if (n > (int32_t) buf.size()) { + buf.resize(n); + llama_token_to_piece_impl(model->vocab, token, buf.data(), buf.size(), lstrip, special); + } else if (n < 0) { + // TODO: make special type of expection here + throw std::runtime_error("failed to convert token to piece"); + } + return std::string(buf.data(), n); +} + int32_t llama_detokenize( const struct llama_model * model, const llama_token * tokens, @@ -21839,6 +21865,23 @@ int32_t llama_detokenize( return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special); } +std::string llama_cpp::detokenize( + const llama_cpp::model & model, + const std::vector & tokens, + bool remove_special, + bool unparse_special) { + std::vector buf(1024); + int32_t n = llama_detokenize_impl(model->vocab, tokens.data(), tokens.size(), buf.data(), buf.size(), remove_special, unparse_special); + if (n > (int32_t) buf.size()) { + buf.resize(n); + llama_detokenize_impl(model->vocab, tokens.data(), tokens.size(), buf.data(), buf.size(), remove_special, unparse_special); + } else if (n < 0) { + // TODO: make special type of expection here + throw std::runtime_error("failed to detokenize"); + } + return std::string(buf.data(), n); +} + // // chat templates // @@ -22172,6 +22215,24 @@ int32_t llama_chat_apply_template( return res; } +std::string llama_cpp::chat_apply_template( + const llama_cpp::model & model, + const std::string & tmpl, + const std::vector & chat, + bool add_ass) { + std::vector buf; + const char * tmpl_c = tmpl.empty() ? nullptr : tmpl.c_str(); + int32_t n = llama_chat_apply_template(model.get(), tmpl_c, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + if (n > (int32_t) buf.size()) { + buf.resize(n); + llama_chat_apply_template(model.get(), tmpl_c, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + } else if (n < 0) { + // TODO: make special type of expection here + throw std::runtime_error("failed to format chat template"); + } + return std::string(buf.data(), n); +} + // // sampling //