chat_template: do not use std::string for buffer

This commit is contained in:
ngxson 2024-02-16 17:04:01 +01:00
parent bba75c792f
commit 9c4422fbe9
2 changed files with 9 additions and 8 deletions

View file

@ -12542,14 +12542,14 @@ LLAMA_API int32_t llama_chat_apply_template(
if (custom_template == nullptr) {
GGML_ASSERT(model != nullptr);
// load template from model
current_template.resize(2048); // longest known template is about 1200 bytes
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), &(*current_template.begin()), current_template.size());
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), current_template.size());
if (res < 0) {
// worst case: there is no information about template, we will use chatml by default
current_template = "<|im_start|>"; // see llama_chat_apply_template_internal
} else {
current_template.resize(res);
current_template = std::string(model_template.data(), model_template.size());
}
}
// format the conversation to string

View file

@ -34,11 +34,11 @@ int main(void) {
"</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
"[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
};
std::string formatted_chat;
std::vector<char> formatted_chat(1024);
int32_t res;
// test invalid chat template
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, &(*formatted_chat.begin()), formatted_chat.size());
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
assert(res < 0);
for (size_t i = 0; i < templates.size(); i++) {
@ -51,13 +51,14 @@ int main(void) {
conversation,
message_count,
true,
&(*formatted_chat.begin()),
formatted_chat.data(),
formatted_chat.size()
);
formatted_chat.resize(res);
std::cout << formatted_chat << "\n-------------------------\n";
std::string output(formatted_chat.data(), formatted_chat.size());
std::cout << output << "\n-------------------------\n";
// expect the "formatted_chat" to contain pre-defined strings
assert(formatted_chat.find(substr) != std::string::npos);
assert(output.find(substr) != std::string::npos);
}
return 0;
}