tool-call: support tool_use variant in llama_chat_template_from_model + drop llama_get_chat_template

ochafik 2024-10-30 10:07:10 +00:00
parent 92c384a5e8
commit 3ebdb2b805
4 changed files with 27 additions and 22 deletions
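
Models can ship a second chat template under the tokenizer.chat_template.tool_use GGUF metadata key, specialized for tool calling. llama_chat_template_from_model now takes a prefer_tool_use flag that selects this variant when present (falling back to the default template), the server chooses it per request based on whether the body contains "tools", and /props reports both templates when they differ. The server's local llama_get_chat_template helper becomes redundant and is removed.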

common/common.cpp

@@ -1719,12 +1719,21 @@ static std::string _llama_model_meta_val_str(const struct llama_model * model, c
 minja::chat_template llama_chat_template_from_model(
     const struct llama_model * model,
-    const char * chat_template_override)
+    const std::string & chat_template_override,
+    bool prefer_tool_use)
 {
     // TODO: handle "chatml"?
-    std::string chat_template = chat_template_override
-        ? chat_template_override
-        : _llama_model_meta_val_str(model, "tokenizer.chat_template");
+    std::string chat_template = chat_template_override;
+    if (chat_template.empty()) {
+        if (prefer_tool_use) {
+            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+            fprintf(stderr, "# tokenizer.chat_template.tool_use: %s\n", chat_template.c_str());
+        }
+        if (chat_template.empty()) {
+            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
+            fprintf(stderr, "# tokenizer.chat_template: %s\n", chat_template.c_str());
+        }
+    }
 
     auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
     auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
     return {std::move(chat_template), bos_token, eos_token};
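
The resolution order is: explicit override first, then the tool_use variant if preferred, then the default template. A minimal caller sketch (hypothetical usage, not part of the patch):

    // With an empty override and prefer_tool_use=true, this tries the
    // "tokenizer.chat_template.tool_use" metadata key first and falls back to
    // "tokenizer.chat_template" when the model has no tool_use variant.
    auto tmpl = llama_chat_template_from_model(model, /* chat_template_override= */ "",
                                               /* prefer_tool_use= */ true);
    fprintf(stdout, "%s\n", tmpl.source().c_str()); // raw Jinja source of the selected template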

common/common.h

@@ -529,7 +529,8 @@ std::string common_chat_format_example(const struct llama_model * model,
 minja::chat_template llama_chat_template_from_model(
     const struct llama_model * model,
-    const char * chat_template_override = nullptr);
+    const std::string & chat_template_override = "",
+    bool prefer_tool_use = false);
 
 //
 // KV cache utils

examples/server/server.cpp

@@ -2923,13 +2923,20 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        auto chat_template = llama_chat_template_from_model(ctx_server.model, ctx_server.params.chat_template, /* prefer_tool_use= */ false);
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params.n_parallel },
             { "bos_token", common_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), true) },
             { "eos_token", common_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), true) },
-            { "chat_template", llama_get_chat_template(ctx_server.model) },
+            { "chat_template", chat_template.source() },
         };
+        if (ctx_server.params.use_jinja) {
+            auto tool_use_chat_template = llama_chat_template_from_model(ctx_server.model, ctx_server.params.chat_template, /* prefer_tool_use= */ true);
+            if (tool_use_chat_template.source() != chat_template.source()) {
+                data["chat_template_tool_use"] = tool_use_chat_template.source();
+            }
+        }
 
         res_ok(res, data);
     };
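
With --jinja enabled and a model that ships a distinct tool_use template, /props now exposes both sources. A sketch of the resulting payload shape (all values are placeholders):

    {
      "default_generation_settings": { ... },
      "total_slots": 1,
      "bos_token": "<s>",
      "eos_token": "</s>",
      "chat_template": "{%- for message in messages -%} ...",
      "chat_template_tool_use": "{%- for message in messages -%} ... {{ tools }} ..."
    }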
@@ -3030,13 +3037,14 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
-        static auto tool_call_style = llama_tool_call_style_detect(chat_template);
+        auto body = json::parse(req.body);
+        auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template, /* prefer_tool_use= */ body.contains("tools"));
+        auto tool_call_style = llama_tool_call_style_detect(chat_template);
         LOG_INF("Tool call style: %s\n", llama_tool_call_style_name(tool_call_style).c_str());
 
         json data;
         try {
-            data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, tool_call_style, params.use_jinja);
+            data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, tool_call_style, params.use_jinja);
         } catch (const std::exception & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
             return;
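
Dropping the static locals makes template selection per-request: the body is parsed once up front, and the presence of a "tools" array flips prefer_tool_use, so tool and non-tool requests no longer share a single cached template. A hypothetical illustration of the selection:

    auto body = json::parse(R"({"messages": [], "tools": [{"type": "function"}]})");
    bool prefer_tool_use = body.contains("tools"); // true -> prefer tokenizer.chat_template.tool_use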

examples/server/utils.hpp

@@ -93,19 +93,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }
 
-static std::string llama_get_chat_template(const struct llama_model * model) {
-    std::string template_key = "tokenizer.chat_template";
-    // call with NULL buffer to get the total size of the string
-    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 0) {
-        return "";
-    } else {
-        std::vector<char> model_template(res, 0);
-        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size());
-    }
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
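
The removed helper duplicated a lookup that llama_chat_template_from_model now owns. Callers that still need the raw metadata string can use the same two-call llama_model_meta_val_str pattern directly; a sketch, assuming the usual snprintf-style semantics (negative return on a missing key, full string length otherwise):

    // Query with a NULL buffer to get the string length, then fetch the value.
    std::string tmpl;
    int32_t res = llama_model_meta_val_str(model, "tokenizer.chat_template", NULL, 0);
    if (res >= 0) {
        std::vector<char> buf(res + 1, 0); // room for the terminating NUL
        llama_model_meta_val_str(model, "tokenizer.chat_template", buf.data(), buf.size());
        tmpl.assign(buf.data());
    }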