diff --git a/common/common.cpp b/common/common.cpp
index 781d35f86..3be74ace3 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1719,12 +1719,21 @@ static std::string _llama_model_meta_val_str(const struct llama_model * model, c
 minja::chat_template llama_chat_template_from_model(
     const struct llama_model * model,
-    const char * chat_template_override)
+    const std::string & chat_template_override,
+    bool prefer_tool_use)
 {
     // TODO: handle "chatml"?
-    std::string chat_template = chat_template_override
-        ? chat_template_override
-        : _llama_model_meta_val_str(model, "tokenizer.chat_template");
+    std::string chat_template = chat_template_override;
+    if (chat_template.empty()) {
+        if (prefer_tool_use) {
+            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+            fprintf(stderr, "# tokenizer.chat_template.tool_use: %s\n", chat_template.c_str());
+        }
+        if (chat_template.empty()) {
+            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
+            fprintf(stderr, "# tokenizer.chat_template: %s\n", chat_template.c_str());
+        }
+    }
     auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
     auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
     return {std::move(chat_template), bos_token, eos_token};
 }
diff --git a/common/common.h b/common/common.h
index 844afa3f1..971ed2d98 100644
--- a/common/common.h
+++ b/common/common.h
@@ -529,7 +529,8 @@ std::string common_chat_format_example(const struct llama_model * model,
 
 minja::chat_template llama_chat_template_from_model(
     const struct llama_model * model,
-    const char * chat_template_override = nullptr);
+    const std::string & chat_template_override = "",
+    bool prefer_tool_use = false);
 
 //
 // KV cache utils
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d7bfa0180..411010ddb 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2923,13 +2923,20 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        auto chat_template = llama_chat_template_from_model(ctx_server.model, ctx_server.params.chat_template, /* prefer_tool_use= */ false);
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params.n_parallel },
             { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), true) },
             { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), true) },
-            { "chat_template",               llama_get_chat_template(ctx_server.model) },
+            { "chat_template",               chat_template.source() },
         };
+        if (ctx_server.params.use_jinja) {
+            auto tool_use_chat_template = llama_chat_template_from_model(ctx_server.model, ctx_server.params.chat_template, /* prefer_tool_use= */ true);
+            if (tool_use_chat_template.source() != chat_template.source()) {
+                data["chat_template_tool_use"] = tool_use_chat_template.source();
+            }
+        }
 
         res_ok(res, data);
     };
@@ -3030,13 +3037,14 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
-        static auto tool_call_style = llama_tool_call_style_detect(chat_template);
+        auto body = json::parse(req.body);
+        auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template, /* prefer_tool_use= */ body.contains("tools"));
+        auto tool_call_style = llama_tool_call_style_detect(chat_template);
         LOG_INF("Tool call style: %s\n", llama_tool_call_style_name(tool_call_style).c_str());
 
         json data;
         try {
-            data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, tool_call_style, params.use_jinja);
+            data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, tool_call_style, params.use_jinja);
         } catch (const std::exception & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
             return;
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index f58e7171a..aa5fbbe7e 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -93,19 +93,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }
 
-static std::string llama_get_chat_template(const struct llama_model * model) {
-    std::string template_key = "tokenizer.chat_template";
-    // call with NULL buffer to get the total size of the string
-    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 0) {
-        return "";
-    } else {
-        std::vector<char> model_template(res, 0);
-        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size());
-    }
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //