diff --git a/Makefile b/Makefile
index 0a2070b53..8f1535b55 100644
--- a/Makefile
+++ b/Makefile
@@ -862,3 +862,7 @@ tests/test-model-load-cancel: tests/test-model-load-cancel.cpp ggml.o llama.o te
 tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
+tests/test-chat-template: tests/test-chat-template.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
diff --git a/llama.cpp b/llama.cpp
index 08e7b02b4..b13c23839 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -12459,6 +12459,114 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     return 0;
 }
 
+int32_t llama_chat_apply_template_internal(std::string & dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass);
+
+// trim whitespace from the beginning and end of a string
+static std::string trim(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && isspace(str[start])) {
+        start += 1;
+    }
+    while (end > start && isspace(str[end - 1])) {
+        end -= 1;
+    }
+    return str.substr(start, end - start);
+}
+
+// Simple version of "llama_chat_apply_template" that only works with strings
+// This function uses heuristic checks to determine commonly used templates. It is not a jinja parser.
+int32_t llama_chat_apply_template_internal(std::string & dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass) {
+    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
+    std::stringstream ss;
+    if (chat_template.find("<|im_start|>") != std::string::npos) {
+        // chatml template
+        for (auto message : conversation) {
+            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|im_start|>assistant\n";
+        }
+    } else if (chat_template.find("[INST]") != std::string::npos) {
+        // llama2 template and its variants
+        // [variant] support system message
+        bool support_system_message = chat_template.find("<<SYS>>") != std::string::npos;
+        // [variant] space before + after response
+        bool space_around_response = chat_template.find("' ' + eos_token") != std::string::npos;
+        // [variant] add BOS inside history
+        bool add_bos_inside_history = chat_template.find("bos_token + '[INST]") != std::string::npos;
+        // [variant] trim spaces from the input message
+        bool strip_message = chat_template.find("content.strip()") != std::string::npos;
+        // construct the prompt
+        bool is_inside_turn = true; // skip BOS at the beginning
+        ss << "[INST] ";
+        for (auto message : conversation) {
+            std::string content = strip_message ? trim(message->content) : message->content;
+            std::string role(message->role);
+            if (!is_inside_turn) {
+                is_inside_turn = true;
+                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
+            }
+            if (role == "system") {
+                if (support_system_message) {
+                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
+                } else {
+                    // if the model does not support system message, we still include it in the first message, but without <<SYS>>
+                    ss << content << "\n";
+                }
+            } else if (role == "user") {
+                ss << content << " [/INST]";
+            } else {
+                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
+                is_inside_turn = false;
+            }
+        }
+        // llama2 templates seem to not care about "add_generation_prompt"
+    } else {
+        // template not supported
+        return -1;
+    }
+    dest = ss.str();
+    return dest.size();
+}
+
+LLAMA_API int32_t llama_chat_apply_template(
+        const struct llama_model * model,
+        const char * custom_template,
+        const struct llama_chat_message * msg,
+        size_t n_msg,
+        bool add_ass,
+        char * buf,
+        int32_t length) {
+    std::string current_template(custom_template == nullptr ? "" : custom_template);
+    if (custom_template == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        // load template from model
+        current_template.resize(2048); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), &(*current_template.begin()), current_template.size());
+        if (res < 0) {
+            // worst case: there is no information about template, we will use chatml by default
+            current_template = "<|im_start|>"; // see llama_chat_apply_template_internal
+        } else {
+            current_template.resize(res);
+        }
+    }
+    // format the conversation to string
+    std::vector<const llama_chat_message *> conversation_vec;
+    conversation_vec.resize(n_msg);
+    for (size_t i = 0; i < n_msg; i++) {
+        conversation_vec[i] = &msg[i];
+    }
+    std::string formatted_chat;
+    int32_t res = llama_chat_apply_template_internal(formatted_chat, current_template, conversation_vec, add_ass);
+    if (res < 0) {
+        return res;
+    }
+    strncpy(buf, formatted_chat.c_str(), length);
+    return res;
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
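Review note on the hunk above: the template "detection" is deliberately not a Jinja engine. It greps the template text for distinctive markers, first to pick a family (chatml vs. llama2), then to toggle the llama2 variant behaviours. A standalone sketch of the same substring checks, for illustration only (the template fragments below are invented, not taken from any real model):

#include <iostream>
#include <string>

// Sketch (not part of the patch): classify a template the same way the
// patch does, by searching for marker substrings instead of parsing Jinja.
static const char * classify(const std::string & tmpl) {
    if (tmpl.find("<|im_start|>") != std::string::npos) {
        return "chatml";
    }
    if (tmpl.find("[INST]") != std::string::npos) {
        return "llama2 (or a variant)";
    }
    return "unsupported"; // llama_chat_apply_template_internal returns -1 here
}

int main() {
    std::cout << std::boolalpha;
    // invented fragment carrying the chatml marker
    std::cout << classify("{{'<|im_start|>' + message['role'] + ...}}") << "\n";
    // invented llama2-style fragment; the variant flags are read off the same way
    std::string tmpl = "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}";
    std::cout << classify(tmpl) << "\n";
    std::cout << "  system block:        " << (tmpl.find("<<SYS>>")             != std::string::npos) << "\n";
    std::cout << "  space around reply:  " << (tmpl.find("' ' + eos_token")     != std::string::npos) << "\n";
    std::cout << "  BOS inside history:  " << (tmpl.find("bos_token + '[INST]") != std::string::npos) << "\n";
    std::cout << "  strip message:       " << (tmpl.find("content.strip()")     != std::string::npos) << "\n";
    return 0;
}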
diff --git a/llama.h b/llama.h
index f4ec6ea63..ff2c47019 100644
--- a/llama.h
+++ b/llama.h
@@ -304,6 +304,12 @@ extern "C" {
         int32_t n_eval;
     };
 
+    // used in chat template
+    typedef struct llama_chat_message {
+        const char * role;
+        const char * content;
+    } llama_chat_message;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -698,6 +704,22 @@ extern "C" {
             char * buf,
             int32_t length);
 
+    /// Apply chat template. Inspired by hf apply_chat_template() on python.
+    /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model".
+    /// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model's default chat template will be used instead.
+    /// @param msg Pointer to a list of multiple llama_chat_message
+    /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+    /// @return The total length of the formatted prompt (which may be larger than "length").
+    ///         If it is, the output written to "buf" is truncated; re-alloc "buf" and apply the template again.
+    LLAMA_API int32_t llama_chat_apply_template(
+              const struct llama_model * model,
+                            const char * custom_template,
+       const struct llama_chat_message * msg,
+                                  size_t n_msg,
+                                    bool add_ass,
+                                    char * buf,
+                                 int32_t length);
+
     //
     // Grammar
     //
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 3e40a78cd..10326d531 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -28,6 +28,7 @@ endfunction()
 llama_build_and_test_executable(test-quantize-fns.cpp)
 llama_build_and_test_executable(test-quantize-perf.cpp)
 llama_build_and_test_executable(test-sampling.cpp)
+llama_build_and_test_executable(test-chat-template.cpp)
 llama_build_executable(test-tokenizer-0-llama.cpp)
 llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
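Caller-side note on the header above: the function copies the result with strncpy and returns the full formatted length, so a too-small buffer yields a truncated (and possibly non-null-terminated) copy. A minimal usage sketch against this patch; the chatml template literal is borrowed from the test below, and the deliberately tiny initial buffer is my own choice to show the grow-and-retry path:

#include <cstdio>
#include <string>
#include <vector>

#include "llama.h"

int main() {
    // the chatml template from the test below (teknium/OpenHermes-2.5-Mistral-7B)
    const char * tmpl =
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}";
    llama_chat_message msgs[] = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };
    std::vector<char> buf(64); // deliberately small to demonstrate the retry path
    int32_t res = llama_chat_apply_template(nullptr, tmpl, msgs, 2, true, buf.data(), buf.size());
    if (res < 0) {
        fprintf(stderr, "unsupported template\n");
        return 1;
    }
    if ((size_t) res > buf.size()) {
        // the return value is the full formatted length; the copy into buf was
        // truncated and not null-terminated, so grow and apply the template again
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, tmpl, msgs, 2, true, buf.data(), buf.size());
    }
    std::string prompt(buf.data(), res); // copy exactly res bytes
    printf("%s", prompt.c_str());
    return 0;
}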
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
new file mode 100644
index 000000000..23d02a7d2
--- /dev/null
+++ b/tests/test-chat-template.cpp
@@ -0,0 +1,68 @@
+#include <iostream>
+#include <string>
+#include <vector>
+#include <sstream>
+
+#undef NDEBUG
+#include <cassert>
+
+#include "llama.h"
+
+int main(void) {
+    llama_chat_message conversation[] = {
+        {"system", "You are a helpful assistant"},
+        {"user", "Hello"},
+        {"assistant", "Hi there"},
+        {"user", "Who are you"},
+        {"assistant", "   I am an assistant   "}, // padded with spaces to exercise the llama2 strip variant
+        {"user", "Another question"},
+    };
+    size_t message_count = 6;
+    std::vector<const llama_chat_message *> conversation_vec;
+    conversation_vec.resize(message_count);
+    for (size_t i = 0; i < message_count; i++) {
+        conversation_vec[i] = &conversation[i];
+    }
+    std::vector<std::string> templates = {
+        // teknium/OpenHermes-2.5-Mistral-7B
+        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
+        // mistralai/Mistral-7B-Instruct-v0.2
+        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
+        // TheBloke/FusionNet_34Bx2_MoE-AWQ
+        "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
+        // bofenghuang/vigogne-2-70b-chat
+        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+    };
+    std::vector<std::string> expected_substr = {
+        "<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
+        "[/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
+        "</s><s>[INST] Who are you [/INST]    I am an assistant    </s><s>[INST] Another question [/INST]",
+        "[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+    };
+    std::string formatted_chat(1024, 0); // pre-sized so "begin()" below is always dereferenceable
+    int32_t res;
+
+    // test invalid chat template
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, &(*formatted_chat.begin()), formatted_chat.size());
+    assert(res < 0);
+
+    for (size_t i = 0; i < templates.size(); i++) {
+        std::string custom_template = templates[i];
+        std::string substr = expected_substr[i];
+        formatted_chat.resize(1024);
+        res = llama_chat_apply_template(
+            nullptr,
+            custom_template.c_str(),
+            conversation,
+            message_count,
+            true,
+            &(*formatted_chat.begin()),
+            formatted_chat.size()
+        );
+        formatted_chat.resize(res);
+        std::cout << formatted_chat << "\n-------------------------\n";
+        // expect the "formatted_chat" to contain pre-defined strings
+        assert(formatted_chat.find(substr) != std::string::npos);
+    }
+    return 0;
+}
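The test above only exercises the custom_template path. The other documented mode, where custom_template is nullptr and the template is read from the model's "tokenizer.chat_template" GGUF metadata (falling back to chatml when the key is absent), could be driven roughly as below. This is a sketch, not part of the patch: "model.gguf" is a placeholder path, and backend initialization/teardown and error handling are elided.

#include <cstdio>
#include <vector>

#include "llama.h"

int main() {
    // "model.gguf" is a placeholder; substitute a real GGUF file
    struct llama_model * model = llama_load_model_from_file("model.gguf", llama_model_default_params());
    if (model == nullptr) {
        return 1;
    }
    llama_chat_message msgs[] = {
        {"user", "Hello"},
    };
    std::vector<char> buf(2048);
    // nullptr template -> the patch reads "tokenizer.chat_template" from the
    // model's metadata, or falls back to chatml if the key is missing
    int32_t res = llama_chat_apply_template(model, nullptr, msgs, 1, true, buf.data(), buf.size());
    if (res >= 0 && (size_t) res <= buf.size()) {
        printf("%.*s\n", res, buf.data());
    }
    llama_free_model(model);
    return 0;
}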