Update test-chat-template.cpp
This commit is contained in:
parent
d1adeb95b7
commit
a54ccb910b
1 changed files with 6 additions and 6 deletions
|
@ -17,12 +17,12 @@ int main(void) {
|
|||
{"assistant", " I am an assistant "},
|
||||
{"user", "Another question"},
|
||||
};
|
||||
struct ChatTemplate {
|
||||
struct TestCase {
|
||||
std::string name;
|
||||
std::string template_str;
|
||||
std::string expected_output;
|
||||
};
|
||||
std::vector<ChatTemplate> templates {
|
||||
std::vector<TestCase> test_cases {
|
||||
{
|
||||
/* .name= */ "teknium/OpenHermes-2.5-Mistral-7B",
|
||||
/* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
||||
|
@ -191,11 +191,11 @@ int main(void) {
|
|||
res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size());
|
||||
assert(res < 0);
|
||||
|
||||
for (const auto & tmpl : templates) {
|
||||
printf("\n\n=== %s ===\n\n", tmpl.name.c_str());
|
||||
for (const auto & test_case : test_cases) {
|
||||
printf("\n\n=== %s ===\n\n", test_case.name.c_str());
|
||||
formatted_chat.resize(1024);
|
||||
res = llama_chat_apply_template(
|
||||
tmpl.template_str.c_str(),
|
||||
test_case.template_str.c_str(),
|
||||
conversation.data(),
|
||||
conversation.size(),
|
||||
true,
|
||||
|
@ -206,7 +206,7 @@ int main(void) {
|
|||
std::string output(formatted_chat.data(), formatted_chat.size());
|
||||
printf("%s\n", output.c_str());
|
||||
printf("-------------------------\n");
|
||||
assert(output == tmpl.expected_output);
|
||||
assert(output == test_case.expected_output);
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue