tool-call
: fix llama_chat_apply_template signature / test-chat-template
This commit is contained in:
parent
97d0620968
commit
e983c9d0de
4 changed files with 16 additions and 13 deletions
|
@ -1521,7 +1521,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||||
const std::vector<llama_chat_msg> & msgs,
|
const std::vector<llama_chat_msg> & msgs,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
const std::string & tools,
|
const char * tools,
|
||||||
const char * bos_token,
|
const char * bos_token,
|
||||||
const char * eos_token) {
|
const char * eos_token) {
|
||||||
int alloc_size = 0;
|
int alloc_size = 0;
|
||||||
|
@ -1536,7 +1536,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||||
std::vector<char> buf(alloc_size);
|
std::vector<char> buf(alloc_size);
|
||||||
|
|
||||||
// run the first time to get the total output length
|
// run the first time to get the total output length
|
||||||
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools.empty() ? nullptr : tools.data(), bos_token, eos_token);
|
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
|
||||||
|
|
||||||
// error: chat template is not supported
|
// error: chat template is not supported
|
||||||
if (res < 0) {
|
if (res < 0) {
|
||||||
|
@ -1546,7 +1546,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||||
throw std::runtime_error("this custom template is not supported");
|
throw std::runtime_error("this custom template is not supported");
|
||||||
} else {
|
} else {
|
||||||
// If the built-in template is not supported, we default to chatml
|
// If the built-in template is not supported, we default to chatml
|
||||||
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
|
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
|
||||||
fallback = true;
|
fallback = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1557,7 +1557,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||||
res = llama_chat_apply_template(
|
res = llama_chat_apply_template(
|
||||||
fallback ? nullptr : model,
|
fallback ? nullptr : model,
|
||||||
fallback ? "chatml" : ptr_tmpl,
|
fallback ? "chatml" : ptr_tmpl,
|
||||||
chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
|
chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string formatted_chat(buf.data(), res);
|
std::string formatted_chat(buf.data(), res);
|
||||||
|
@ -1570,11 +1570,11 @@ std::string llama_chat_format_single(const struct llama_model * model,
|
||||||
const llama_chat_msg & new_msg,
|
const llama_chat_msg & new_msg,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
const std::string & tools,
|
const char * tools,
|
||||||
const char * bos_token,
|
const char * bos_token,
|
||||||
const char * eos_token) {
|
const char * eos_token) {
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, bos_token, eos_token);
|
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token);
|
||||||
std::vector<llama_chat_msg> chat_new(past_msg);
|
std::vector<llama_chat_msg> chat_new(past_msg);
|
||||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||||
|
@ -1582,7 +1582,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
|
||||||
};
|
};
|
||||||
// format chat with new_msg
|
// format chat with new_msg
|
||||||
chat_new.push_back(new_msg);
|
chat_new.push_back(new_msg);
|
||||||
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, bos_token, eos_token);
|
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token);
|
||||||
// get the diff part
|
// get the diff part
|
||||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||||
return ss.str();
|
return ss.str();
|
||||||
|
|
|
@ -493,7 +493,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||||
const std::vector<llama_chat_msg> & chat,
|
const std::vector<llama_chat_msg> & chat,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja = false,
|
bool use_jinja = false,
|
||||||
const std::string & tools = "",
|
const char * tools = nullptr,
|
||||||
const char * bos_token = nullptr,
|
const char * bos_token = nullptr,
|
||||||
const char * eos_token = nullptr);
|
const char * eos_token = nullptr);
|
||||||
|
|
||||||
|
@ -504,7 +504,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
|
||||||
const llama_chat_msg & new_msg,
|
const llama_chat_msg & new_msg,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja = false,
|
bool use_jinja = false,
|
||||||
const std::string & tools = "",
|
const char * tools = nullptr,
|
||||||
const char * bos_token = nullptr,
|
const char * bos_token = nullptr,
|
||||||
const char * eos_token = nullptr);
|
const char * eos_token = nullptr);
|
||||||
|
|
||||||
|
|
|
@ -97,7 +97,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
|
||||||
chat.emplace_back(std::move(msg));
|
chat.emplace_back(std::move(msg));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? "" : tools.dump());
|
const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? nullptr : tools.dump().c_str());
|
||||||
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
|
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
|
||||||
|
|
||||||
return formatted_chat;
|
return formatted_chat;
|
||||||
|
|
|
@ -27,6 +27,8 @@ int main(void) {
|
||||||
{"user", "Another question"},
|
{"user", "Another question"},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::string tools = "";
|
||||||
|
|
||||||
std::vector<test_template> templates {
|
std::vector<test_template> templates {
|
||||||
{
|
{
|
||||||
.name = "teknium/OpenHermes-2.5-Mistral-7B",
|
.name = "teknium/OpenHermes-2.5-Mistral-7B",
|
||||||
|
@ -160,7 +162,7 @@ int main(void) {
|
||||||
int32_t res;
|
int32_t res;
|
||||||
|
|
||||||
// test invalid chat template
|
// test invalid chat template
|
||||||
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size(), false, "<|im_start|>", "<|im_end|>");
|
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size(), false, /* tools= */ nullptr, "<|im_start|>", "<|im_end|>");
|
||||||
assert(res < 0);
|
assert(res < 0);
|
||||||
|
|
||||||
for (auto use_jinja : std::vector<bool> { false, true }) {
|
for (auto use_jinja : std::vector<bool> { false, true }) {
|
||||||
|
@ -182,6 +184,7 @@ int main(void) {
|
||||||
formatted_chat.data(),
|
formatted_chat.data(),
|
||||||
formatted_chat.size(),
|
formatted_chat.size(),
|
||||||
use_jinja,
|
use_jinja,
|
||||||
|
tools.empty() ? nullptr : tools.c_str(),
|
||||||
tmpl.bos.c_str(),
|
tmpl.bos.c_str(),
|
||||||
tmpl.eos.c_str()
|
tmpl.eos.c_str()
|
||||||
);
|
);
|
||||||
|
@ -210,7 +213,7 @@ int main(void) {
|
||||||
llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
|
llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
|
||||||
|
|
||||||
auto fmt_sys = [&](std::string tmpl) {
|
auto fmt_sys = [&](std::string tmpl) {
|
||||||
auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, false, "<|im_start|>", "<|im_end|>");
|
auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, false, /** tools= */ "", "<|im_start|>", "<|im_end|>");
|
||||||
printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
|
printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
|
||||||
printf("-------------------------\n");
|
printf("-------------------------\n");
|
||||||
return output;
|
return output;
|
||||||
|
@ -229,7 +232,7 @@ int main(void) {
|
||||||
llama_chat_msg new_msg{"user", "How are you"};
|
llama_chat_msg new_msg{"user", "How are you"};
|
||||||
|
|
||||||
auto fmt_single = [&](std::string tmpl) {
|
auto fmt_single = [&](std::string tmpl) {
|
||||||
auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, false, "<|im_start|>", "<|im_end|>");
|
auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, false, /* tools= */ nullptr, "<|im_start|>", "<|im_end|>");
|
||||||
printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
|
printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
|
||||||
printf("-------------------------\n");
|
printf("-------------------------\n");
|
||||||
return output;
|
return output;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue