tool-call: support tool_use variant in llama_chat_template_from_model + drop llama_get_chat_template
parent 92c384a5e8
commit 3ebdb2b805
4 changed files with 27 additions and 22 deletions
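
For orientation before the hunks: a minimal sketch (not part of this commit) of how the reworked helper is meant to be called. The signature and the .source() accessor are taken from the diff below; the includes and the wrapper function are illustrative assumptions.

// Sketch only: assumes common.h declares llama_chat_template_from_model and
// exposes minja::chat_template, as the hunks below suggest.
#include <cstdio>
#include <string>
#include "common.h"

static void log_selected_template(const struct llama_model * model, bool request_has_tools) {
    // Empty override -> resolve from GGUF metadata; prefer the
    // tokenizer.chat_template.tool_use variant when the request carries tools.
    auto tmpl = llama_chat_template_from_model(model, /* chat_template_override= */ "",
                                               /* prefer_tool_use= */ request_has_tools);
    fprintf(stderr, "chat template source: %s\n", tmpl.source().c_str());
}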
@@ -1719,12 +1719,21 @@ static std::string _llama_model_meta_val_str(const struct llama_model * model, c
 minja::chat_template llama_chat_template_from_model(
     const struct llama_model * model,
-    const char * chat_template_override)
+    const std::string & chat_template_override,
+    bool prefer_tool_use)
 {
     // TODO: handle "chatml"?
-    std::string chat_template = chat_template_override
-        ? chat_template_override
-        : _llama_model_meta_val_str(model, "tokenizer.chat_template");
+    std::string chat_template = chat_template_override;
+    if (chat_template.empty()) {
+        if (prefer_tool_use) {
+            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+            fprintf(stderr, "# tokenizer.chat_template.tool_use: %s\n", chat_template.c_str());
+        }
+        if (chat_template.empty()) {
+            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
+            fprintf(stderr, "# tokenizer.chat_template: %s\n", chat_template.c_str());
+        }
+    }
     auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
     auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
     return {std::move(chat_template), bos_token, eos_token};
 }
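
Aside: _llama_model_meta_val_str is not shown in this diff. A hedged guess at its shape, mirroring the llama_get_chat_template helper removed in the last hunk, is the usual two-call llama_model_meta_val_str() pattern (query the length, then fill a buffer):

// Assumption, not taken from this commit: wraps the public GGUF metadata
// accessor and returns "" when the key is absent.
#include <string>
#include <vector>
#include "llama.h"

static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
    int32_t len = llama_model_meta_val_str(model, key, NULL, 0);
    if (len < 0) {
        return "";
    }
    std::vector<char> buf(len + 1, 0);
    llama_model_meta_val_str(model, key, buf.data(), buf.size());
    return std::string(buf.data());
}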
@@ -529,7 +529,8 @@ std::string common_chat_format_example(const struct llama_model * model,
 minja::chat_template llama_chat_template_from_model(
     const struct llama_model * model,
-    const char * chat_template_override = nullptr);
+    const std::string & chat_template_override = "",
+    bool prefer_tool_use = false);
 
 //
 // KV cache utils
 //
@@ -2923,13 +2923,20 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        auto chat_template = llama_chat_template_from_model(ctx_server.model, ctx_server.params.chat_template, /* prefer_tool_use= */ false);
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params.n_parallel },
             { "bos_token", common_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), true) },
             { "eos_token", common_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), true) },
-            { "chat_template", llama_get_chat_template(ctx_server.model) },
+            { "chat_template", chat_template.source() },
         };
+        if (ctx_server.params.use_jinja) {
+            auto tool_use_chat_template = llama_chat_template_from_model(ctx_server.model, ctx_server.params.chat_template, /* prefer_tool_use= */ true);
+            if (tool_use_chat_template.source() != chat_template.source()) {
+                data["chat_template_tool_use"] = tool_use_chat_template.source();
+            }
+        }
 
         res_ok(res, data);
     };
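
Net effect on the server's /props endpoint (illustrative response shape only; token and template strings are made-up placeholders): chat_template now comes from llama_chat_template_from_model, and chat_template_tool_use is added only when Jinja templating is enabled and the tool-use template source differs from the default one.

{
  "default_generation_settings": { "...": "..." },
  "total_slots": 4,
  "bos_token": "<s>",
  "eos_token": "</s>",
  "chat_template": "{%- for message in messages -%}...{%- endfor -%}",
  "chat_template_tool_use": "{%- if tools -%}...{%- endif -%}"
}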
@@ -3030,13 +3037,14 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
-        static auto tool_call_style = llama_tool_call_style_detect(chat_template);
+        auto body = json::parse(req.body);
+        auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template, /* prefer_tool_use= */ body.contains("tools"));
+        auto tool_call_style = llama_tool_call_style_detect(chat_template);
         LOG_INF("Tool call style: %s\n", llama_tool_call_style_name(tool_call_style).c_str());
 
         json data;
         try {
-            data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, tool_call_style, params.use_jinja);
+            data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, tool_call_style, params.use_jinja);
         } catch (const std::exception & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -93,19 +93,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }
 
-static std::string llama_get_chat_template(const struct llama_model * model) {
-    std::string template_key = "tokenizer.chat_template";
-    // call with NULL buffer to get the total size of the string
-    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 0) {
-        return "";
-    } else {
-        std::vector<char> model_template(res, 0);
-        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size());
-    }
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //