From 037af53bfa603770fd93c91508df4379e4b7d6e2 Mon Sep 17 00:00:00 2001 From: tristandruyen Date: Thu, 23 May 2024 13:40:53 +0200 Subject: [PATCH] Fix phi3 jinja test templates & match by <|end|> --- llama.cpp | 2 +- tests/test-chat-template.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 36f030682..ca5a939dd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17628,7 +17628,7 @@ static int32_t llama_chat_apply_template_internal( } } // llama2 templates seem to not care about "add_generation_prompt" - } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("+ ' '") != std::string::npos )) { + } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos )) { // Phi 3 for (auto message : chat) { std::string role(message->role); diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index cbc088ade..8a9f5dc23 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -50,9 +50,9 @@ int main(void) { // Llama-3 "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}", // Phi-3 - "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + ' ' + message['content'] + '<|end|> ' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> ' }}{% else %}{{ eos_token }}{% endif %}" + "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|> ' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> ' }}{% else %}{{ eos_token }}{% endif %}" // Phi-3 New - "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + ' ' + message['content'] + '<|end|>' + ' ' + '<|assistant|>' + ' '}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + ' '}}{% endif %}{% endfor %}" + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B