diff --git a/llama.cpp b/llama.cpp
index 199669517..3c992d6f6 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -12459,7 +12459,10 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     return 0;
 }
 
-int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass);
+int32_t llama_chat_apply_template_internal(
+    const std::string & chat_template,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass);
 
 // trim whitespace from the beginning and end of a string
 static std::string trim(const std::string & str) {
@@ -12476,12 +12479,15 @@ static std::string trim(const std::string & str) {
 
 // Simple version of "llama_apply_chat_template" that only works with strings
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
-int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass) {
+int32_t llama_chat_apply_template_internal(
+    const std::string & chat_template,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
     if (chat_template.find("<|im_start|>") != std::string::npos) {
         // chatml template
-        for (auto message : conversation) {
+        for (auto message : chat) {
             ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
         }
         if (add_ass) {
@@ -12500,7 +12506,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
         // construct the prompt
         bool is_inside_turn = true; // skip BOS at the beginning
         ss << "[INST] ";
-        for (auto message : conversation) {
+        for (auto message : chat) {
             std::string content = strip_message ? trim(message->content) : message->content;
             std::string role(message->role);
             if (!is_inside_turn) {
@@ -12524,7 +12530,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
         // llama2 templates seem to not care about "add_generation_prompt"
     } else if (chat_template.find("<|user|>") != std::string::npos) {
         // zephyr template
-        for (auto message : conversation) {
+        for (auto message : chat) {
             ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
         }
         if (add_ass) {
@@ -12541,7 +12547,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
 LLAMA_API int32_t llama_chat_apply_template(
           const struct llama_model * model,
                         const char * custom_template,
-   const struct llama_chat_message * msg,
+   const struct llama_chat_message * chat,
                              size_t   n_msg,
                                bool   add_ass,
                                char * buf,
@@ -12560,14 +12566,14 @@ LLAMA_API int32_t llama_chat_apply_template(
             current_template = std::string(model_template.data(), model_template.size());
         }
     }
-    // format the conversation to string
-    std::vector<const llama_chat_message *> conversation_vec;
-    conversation_vec.resize(n_msg);
+    // format the chat to string
+    std::vector<const llama_chat_message *> chat_vec;
+    chat_vec.resize(n_msg);
     for (size_t i = 0; i < n_msg; i++) {
-        conversation_vec[i] = &msg[i];
+        chat_vec[i] = &chat[i];
     }
     std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(formatted_chat, current_template, conversation_vec, add_ass);
+    int32_t res = llama_chat_apply_template_internal(current_template, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
diff --git a/llama.h b/llama.h
index 64140bde2..3e5357685 100644
--- a/llama.h
+++ b/llama.h
@@ -704,18 +704,20 @@ extern "C" {
                                    char * buf,
                                 int32_t   length);
 
-    /// Apply chat template and maybe tokenize it. Inspired by hf apply_chat_template() on python.
+    /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
     /// NOTE: This function only support some known jinja templates. It is not a jinja parser.
-    /// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model’s default chat template will be used instead.
-    /// @param msg Pointer to a list of multiple llama_chat_message
+    /// @param custom_template A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
+    /// @param chat Pointer to a list of multiple llama_chat_message
+    /// @param n_msg Number of llama_chat_message in this chat
     /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
     /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
+    /// @param length The size of the allocated buffer
     /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
     LLAMA_API int32_t llama_chat_apply_template(
               const struct llama_model * model,
                             const char * custom_template,
-       const struct llama_chat_message * msg,
+       const struct llama_chat_message * chat,
                                  size_t   n_msg,
                                    bool   add_ass,
                                    char * buf,
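
For reference, here is a minimal, hypothetical caller-side sketch (not part of the patch) showing how the public API looks after this change. It assumes only what the header above documents: `custom_template` takes precedence over `model` (so `model` can be nullptr), a negative return value means the template is not supported by the heuristic parser, and a return value larger than the buffer size means the caller should re-alloc and re-apply. The template string below is illustrative; detection of the chatml format is done by the `<|im_start|>` substring check shown in the diff.

    // usage_sketch.cpp -- illustrative only, written against the llama.h in this patch
    #include "llama.h"

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        // Any template containing "<|im_start|>" is treated as chatml by the heuristic.
        const char * custom_template = "<|im_start|>";

        // llama_chat_message carries a role and a content string (see message->role /
        // message->content in the patch above).
        std::vector<llama_chat_message> chat = {
            { /* role */ "system", /* content */ "You are a helpful assistant." },
            { /* role */ "user",   /* content */ "Hello!" },
        };

        // Recommended alloc size: 2 * (total number of characters of all messages).
        std::vector<char> buf(256);

        int32_t res = llama_chat_apply_template(
            /* model           */ nullptr,           // custom_template has higher precedence
            /* custom_template */ custom_template,
            /* chat            */ chat.data(),
            /* n_msg           */ chat.size(),
            /* add_ass         */ true,
            /* buf             */ buf.data(),
            /* length          */ (int32_t) buf.size());

        if (res < 0) {
            fprintf(stderr, "template not supported by the heuristic parser\n");
            return 1;
        }
        if (res > (int32_t) buf.size()) {
            // The formatted prompt did not fit: grow the buffer and re-apply the template.
            buf.resize(res);
            res = llama_chat_apply_template(nullptr, custom_template, chat.data(), chat.size(),
                                            true, buf.data(), (int32_t) buf.size());
        }

        std::string formatted(buf.data(), res);
        printf("%s", formatted.c_str());
        return 0;
    }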