chat_template: do not use std::string for buffer

Author: ngxson
Date:   2024-02-16 17:04:01 +01:00
Parent: bba75c792f
Commit: 9c4422fbe9
2 changed files with 9 additions and 8 deletions

llama.cpp

@@ -12542,14 +12542,14 @@ LLAMA_API int32_t llama_chat_apply_template(
     if (custom_template == nullptr) {
         GGML_ASSERT(model != nullptr);
         // load template from model
-        current_template.resize(2048); // longest known template is about 1200 bytes
+        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
         std::string template_key = "tokenizer.chat_template";
-        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), &(*current_template.begin()), current_template.size());
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
         if (res < 0) {
             // worst case: there is no information about template, we will use chatml by default
             current_template = "<|im_start|>"; // see llama_chat_apply_template_internal
         } else {
-            current_template.resize(res);
+            current_template = std::string(model_template.data(), model_template.size());
         }
     }
     // format the conversation to string
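The change above stops writing through a std::string's internal buffer and instead hands the C-style API a std::vector<char>, building the string from the buffer afterwards. A minimal standalone sketch of the same pattern, with get_meta_val as a hypothetical stand-in for llama_model_meta_val_str:

    #include <cstdint>
    #include <cstring>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for llama_model_meta_val_str: copies the value
    // for `key` into `buf` (NUL-terminated) and returns its length, or -1.
    static int32_t get_meta_val(const char * /*key*/, char * buf, size_t buf_size) {
        const char * value = "{% for message in messages %}...{% endfor %}";
        size_t len = std::strlen(value);
        if (len + 1 > buf_size) {
            return -1;
        }
        std::memcpy(buf, value, len + 1);
        return (int32_t) len;
    }

    static std::string load_template_from_model() {
        std::vector<char> model_template(2048, 0); // scratch buffer for the C-style call
        int32_t res = get_meta_val("tokenizer.chat_template", model_template.data(), model_template.size());
        if (res < 0) {
            // no template metadata: fall back to chatml, as the diff above does
            return "<|im_start|>";
        }
        // build the string from the filled buffer, sized by the returned length
        return std::string(model_template.data(), res);
    }

Mutating a std::string through &(*begin()) never updates the string's own length, which is why the old code had to resize() twice around the call; the vector makes the scratch buffer explicit, and the final string is constructed in one step.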

tests/test-chat-template.cpp

@@ -34,11 +34,11 @@ int main(void) {
         "</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
         "[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
     };
-    std::string formatted_chat;
+    std::vector<char> formatted_chat(1024);
     int32_t res;
     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, &(*formatted_chat.begin()), formatted_chat.size());
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
     for (size_t i = 0; i < templates.size(); i++) {
@@ -51,13 +51,14 @@ int main(void) {
             conversation,
             message_count,
             true,
-            &(*formatted_chat.begin()),
+            formatted_chat.data(),
             formatted_chat.size()
         );
         formatted_chat.resize(res);
-        std::cout << formatted_chat << "\n-------------------------\n";
+        std::string output(formatted_chat.data(), formatted_chat.size());
+        std::cout << output << "\n-------------------------\n";
         // expect the "formatted_chat" to contain pre-defined strings
-        assert(formatted_chat.find(substr) != std::string::npos);
+        assert(output.find(substr) != std::string::npos);
     }
     return 0;
 }
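Taken together, the caller-side pattern from the updated test looks roughly like this sketch; it assumes the llama_chat_message struct and the seven-argument signature the test uses, where llama_chat_apply_template returns the number of bytes written or a negative value on error:

    #include <cstdint>
    #include <string>
    #include <vector>
    #include "llama.h"

    // Sketch: format `conversation` with the model's built-in chat template,
    // assuming the buffer is large enough for the formatted output.
    static std::string format_chat(const llama_model * model,
                                   const llama_chat_message * conversation,
                                   size_t message_count) {
        std::vector<char> formatted_chat(1024); // fixed-size output buffer, as in the test
        int32_t res = llama_chat_apply_template(model, nullptr, conversation,
                                                message_count, true,
                                                formatted_chat.data(), formatted_chat.size());
        if (res < 0) {
            return ""; // invalid template or other failure
        }
        // keep only the bytes the call actually wrote before using it as text
        return std::string(formatted_chat.data(), res);
    }

Passing nullptr as the template argument selects the model's own tokenizer.chat_template metadata, which is exactly the branch the first diff rewrites.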