chat_template: do not use std::string for buffer
This commit is contained in:
parent
bba75c792f
commit
9c4422fbe9
2 changed files with 9 additions and 8 deletions
|
@ -12542,14 +12542,14 @@ LLAMA_API int32_t llama_chat_apply_template(
|
|||
if (custom_template == nullptr) {
|
||||
GGML_ASSERT(model != nullptr);
|
||||
// load template from model
|
||||
current_template.resize(2048); // longest known template is about 1200 bytes
|
||||
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
|
||||
std::string template_key = "tokenizer.chat_template";
|
||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), &(*current_template.begin()), current_template.size());
|
||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), current_template.size());
|
||||
if (res < 0) {
|
||||
// worst case: there is no information about template, we will use chatml by default
|
||||
current_template = "<|im_start|>"; // see llama_chat_apply_template_internal
|
||||
} else {
|
||||
current_template.resize(res);
|
||||
current_template = std::string(model_template.data(), model_template.size());
|
||||
}
|
||||
}
|
||||
// format the conversation to string
|
||||
|
|
|
@ -34,11 +34,11 @@ int main(void) {
|
|||
"</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
|
||||
"[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
||||
};
|
||||
std::string formatted_chat;
|
||||
std::vector<char> formatted_chat(1024);
|
||||
int32_t res;
|
||||
|
||||
// test invalid chat template
|
||||
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, &(*formatted_chat.begin()), formatted_chat.size());
|
||||
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
|
||||
assert(res < 0);
|
||||
|
||||
for (size_t i = 0; i < templates.size(); i++) {
|
||||
|
@ -51,13 +51,14 @@ int main(void) {
|
|||
conversation,
|
||||
message_count,
|
||||
true,
|
||||
&(*formatted_chat.begin()),
|
||||
formatted_chat.data(),
|
||||
formatted_chat.size()
|
||||
);
|
||||
formatted_chat.resize(res);
|
||||
std::cout << formatted_chat << "\n-------------------------\n";
|
||||
std::string output(formatted_chat.data(), formatted_chat.size());
|
||||
std::cout << output << "\n-------------------------\n";
|
||||
// expect the "formatted_chat" to contain pre-defined strings
|
||||
assert(formatted_chat.find(substr) != std::string::npos);
|
||||
assert(output.find(substr) != std::string::npos);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue