Merge d25cd7f9e4 into b60074f1c2
commit 843d97b1c7
6 changed files with 281 additions and 63 deletions

common/common.cpp

@@ -428,6 +428,7 @@ void gpt_params_parse_from_env(gpt_params & params) {
     get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
     get_env("LLAMA_ARG_HOST", params.hostname);
     get_env("LLAMA_ARG_PORT", params.port);
+    get_env("LLAMA_ARG_TOOL_CALLS", params.enable_tool_calls);
 }

 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
@@ -1046,6 +1047,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.lora_init_without_apply = true;
         return true;
     }
+    if (arg == "--tool-call" || arg == "--tool-calls") {
+        params.enable_tool_calls = true;
+        return true;
+    }
     if (arg == "--control-vector") {
         CHECK_ARG
         params.control_vectors.push_back({ 1.0f, argv[i], });
@@ -2036,6 +2041,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
                                   "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
     options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
+    options.push_back({ "server", " --tool-call(s)", "enable OAI tool calls for chat completion endpoint (default: %s)", params.enable_tool_calls ? "enabled" : "disabled"});

 #ifndef LOG_DISABLE_LOGS
     options.push_back({ "logging" });
@@ -2253,6 +2259,10 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
     return true;
 }

+bool string_contains(std::string haystack, std::string needle) {
+    return haystack.find(needle) != std::string::npos;
+}
+
 //
 // Filesystem utils
 //
@@ -3186,6 +3196,19 @@ std::string llama_chat_format_example(const struct llama_model * model,
     return llama_chat_apply_template(model, tmpl, msgs, true);
 }

+std::string llama_get_chat_template(const struct llama_model * model) {
+    std::string template_key = "tokenizer.chat_template";
+    // call with NULL buffer to get the total size of the string
+    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
+    if (res < 0) {
+        return "";
+    } else {
+        std::vector<char> model_template(res, 0);
+        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+        return std::string(model_template.data(), model_template.size());
+    }
+}
+
 //
 // KV cache utils
 //
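
Note (not part of the diff): a minimal usage sketch for the new llama_get_chat_template() helper, assuming a model loaded through the regular llama.cpp C API of that time; the model path is a placeholder.

// Illustration only: read the built-in chat template of a model using the
// helper declared in common/common.h above; empty string means "no template".
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <string>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    std::string tmpl = llama_get_chat_template(model);
    if (tmpl.empty()) {
        printf("model has no built-in chat template\n");
    } else {
        printf("chat template:\n%s\n", tmpl.c_str());
    }

    llama_free_model(model);
    llama_backend_free();
    return 0;
}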

common/common.h

@@ -221,6 +221,7 @@ struct gpt_params {
     std::string chat_template = "";
     std::string system_prompt = "";
     bool enable_chat_template = true;
+    bool enable_tool_calls = false;

     std::vector<std::string> api_keys;

@@ -320,6 +321,8 @@ static std::vector<T> string_split(const std::string & str, char delim) {
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);

+bool string_contains(std::string haystack, std::string needle);
+
 //
 // Filesystem utils
 //
@@ -428,6 +431,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
 std::string llama_chat_format_example(const struct llama_model * model,
         const std::string & tmpl);

+// Returns the chat template stored inside the model
+// (empty string if model does not have built-in chat template)
+std::string llama_get_chat_template(const struct llama_model * model);
+
 //
 // KV cache utils
 //

examples/server/server.cpp

@@ -4,6 +4,7 @@
 #include "json-schema-to-grammar.h"
 #include "llama.h"
 #include "grammar-parser.h"
+#include "tool-call.hpp"

 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -157,6 +158,7 @@ struct server_slot {
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
+    enum llama_response_state response_state = LLAMA_RESPONSE_STATE_UNKNOWN;

     bool infill = false;
     bool embedding = false;
@@ -207,6 +209,7 @@ struct server_slot {
         infill = false;
         ga_i = 0;
         n_past_se = 0;
+        response_state = LLAMA_RESPONSE_STATE_UNKNOWN;

         generated_token_probs.clear();
     }
@@ -625,6 +628,7 @@ struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::vector<llama_lora_adapter_container> lora_adapters;
+    llama_tool_format tool_format = LLAMA_TOOL_FORMAT_NOT_SUPPORTED;

     gpt_params params;

@@ -1217,7 +1221,13 @@ struct server_context {
                 break;
             }

-            if (!incomplete) {
+            if (slot.response_state == LLAMA_RESPONSE_STATE_UNKNOWN) {
+                slot.response_state = check_response_state(tool_format, slot.generated_text);
+            }
+
+            // if response is tool call, we cannot stream it
+            // instead, we wait for the full response, then extract JSON
+            if (!incomplete && slot.response_state == LLAMA_RESPONSE_STATE_TEXT) {
                 size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

                 const std::string str_test = slot.generated_text.substr(pos);
@@ -1247,9 +1257,7 @@ struct server_context {
                 if (slot.params.stream) {
                     send_partial_response(slot, result);
                 }
-            }
-
-            if (incomplete) {
+            } else {
                 slot.has_next_token = true;
             }

@@ -1396,6 +1404,10 @@ struct server_context {
             {"multimodal", false}
         };

+        if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
+            res.data["tool_calls"] = parse_tool_response(tool_format, tkn.text_to_send);
+        }
+
         if (slot.sparams.n_probs > 0) {
             const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
             const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
@@ -1444,6 +1456,10 @@ struct server_context {
             {"timings", slot.get_formated_timings()}
         };

+        if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
+            res.data["tool_calls"] = parse_tool_response(tool_format, slot.generated_text);
+        }
+
         if (slot.sparams.n_probs > 0) {
             std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
@@ -2937,19 +2953,14 @@ int main(int argc, char ** argv) {
     };

     const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
-        std::string template_key = "tokenizer.chat_template", curr_tmpl;
-        int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
-        if (tlen > 0) {
-            std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-            if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-                curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
-            }
-        }
+        std::string chat_tmpl = ctx_server.params.chat_template.empty()
+            ? llama_get_chat_template(ctx_server.model)
+            : ctx_server.params.chat_template;
         json data = {
             { "system_prompt",               ctx_server.system_prompt.c_str() },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params.n_parallel },
-            { "chat_template",               curr_tmpl.c_str() }
+            { "chat_template",               chat_tmpl },
         };

         res.set_content(data.dump(), MIMETYPE_JSON);
@@ -3056,7 +3067,19 @@ int main(int argc, char ** argv) {
            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
-       json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+       json body = json::parse(req.body);
+
+       if (body.contains("tools")) {
+           if (ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
+               body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
+               body.erase(body.find("tools"));
+           } else {
+               res_error(res, format_error_response("This server does not support tool calls. Start it with `--tool-calls`", ERROR_TYPE_NOT_SUPPORTED));
+               return;
+           }
+       }
+
+       json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);

        const int id_task = ctx_server.queue_tasks.get_new_id();

@@ -3423,11 +3446,27 @@ int main(int argc, char ** argv) {
         }
     }

+    // decide if we can enable tool calls
+    bool tool_call_support = false;
+    if (ctx_server.params.enable_tool_calls) {
+        ctx_server.tool_format = get_tool_format(ctx_server.ctx);
+        tool_call_support = ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
+        if (tool_call_support) {
+            LOG_WARNING("Tool call is EXPERIMENTAL and maybe unstable. Use with your own risk", {});
+        } else {
+            LOG_ERROR("Tool call is not supported for this model. Please remove --tool-call or use with a supported model", {});
+            clean_up();
+            t.join();
+            return 1;
+        }
+    }
+
     // print sample chat example to make it clear which template is used
     {
         LOG_INFO("chat template", {
             {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
             {"built_in",     params.chat_template.empty()},
+            {"tool_call_support", tool_call_support},
         });
     }

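
Note (not part of the diff): a sketch of the request body the updated /v1/chat/completions handler acts on when tool calls are enabled. The tool name and schema are invented for illustration; per the change above, the handler turns `messages` + `tools` into a prompt via format_chat_with_tool() and erases `tools` before oaicompat_completion_params_parse() runs.

// Illustration only: building a chat-completion request body that carries a
// "tools" array, using the same vendored json.hpp the server uses.
#include "json.hpp"

#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    json body = {
        {"messages", json::array({
            json{{"role", "user"}, {"content", "What is the weather in Paris?"}},
        })},
        {"tools", json::array({
            json{
                {"type", "function"},
                {"function", {
                    {"name", "get_weather"}, // hypothetical tool
                    {"description", "Get the current weather for a city"},
                    {"parameters", {
                        {"type", "object"},
                        {"properties", {
                            {"location", {{"type", "string"}}},
                        }},
                        {"required", json::array({"location"})},
                    }},
                }},
            },
        })},
    };

    // the handler above then does, in effect:
    //   body["prompt"] = format_chat_with_tool(tool_format, body.at("messages"), body.at("tools"));
    //   body.erase(body.find("tools"));
    std::cout << body.dump(2) << std::endl;
    return 0;
}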

examples/server/tool-call.hpp (new file, 114 lines)

@@ -0,0 +1,114 @@
+#pragma once
+
+#include "llama.h"
+#include "common.h"
+#include "utils.hpp"
+
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
+#include "json.hpp"
+
+#include <string>
+#include <vector>
+#include <sstream>
+
+using json = nlohmann::ordered_json;
+
+enum llama_tool_format {
+    LLAMA_TOOL_FORMAT_NOT_SUPPORTED,
+    LLAMA_TOOL_FORMAT_HERMES_3,
+};
+
+enum llama_response_state {
+    LLAMA_RESPONSE_STATE_UNKNOWN,
+    LLAMA_RESPONSE_STATE_TEXT,
+    LLAMA_RESPONSE_STATE_TOOL_CALL,
+};
+
+// get the tool call format for the loaded model
+// this function does linear search, so do not call it repeatedly
+inline enum llama_tool_format get_tool_format(const struct llama_context * ctx) {
+    auto model = llama_get_model(ctx);
+    auto has_token = [&](std::string piece) {
+        for (int i = 0; i < llama_n_vocab(model); i++) {
+            const std::string token_str = llama_token_to_piece(ctx, i, true);
+            if (token_str == piece) {
+                return true;
+            }
+        }
+        return false;
+    };
+    if (has_token("<|im_start|>") && has_token("<tool_call>")) {
+        return LLAMA_TOOL_FORMAT_HERMES_3;
+    }
+    return LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
+}
+
+inline std::string format_chat_with_tool(enum llama_tool_format format, const std::vector<json> & messages, json tools) {
+    if (!tools.is_array()) {
+        throw std::runtime_error("tools must be an array");
+    }
+    std::stringstream ss;
+    auto chat = parse_chat_messages(messages);
+    if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
+        ss << "<|im_start|>system\n\n";
+        ss << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools>\n\n";
+        for (auto tool : tools) {
+            ss << tool.dump(1, '\t') << "\n\n";
+        }
+        ss << "</tools> Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}, \"name\": {\"title\": \"Name\", \"type\": \"string\"}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n";
+        ss << "<tool_call>\n";
+        ss << "{\"arguments\": <args-dict>, \"name\": <function-name>}\n";
+        ss << "</tool_call><|im_end|>\n";
+        for (auto & message : chat) {
+            std::string role(message.role);
+            if (role == "system") {
+                continue; // for optimal performance, we skip user-defined system message
+            }
+            ss << "<|im_start|>" << role << "\n\n";
+            if (role == "tool") {
+                ss << "<tool_response>\n" << string_strip(message.content) << "\n</tool_response>\n";
+            } else {
+                ss << string_strip(message.content) << "<|im_end|>\n";
+            }
+        }
+        ss << "<|im_start|>assistant\n\n";
+    } else {
+        throw std::runtime_error("tool_call is not supported by this model");
+    }
+    LOG_VERBOSE("format_chat_with_tool", {{"text", ss.str()}});
+    return ss.str();
+}
+
+// check if the response is text or tool_call
+// if it is tool_call, we may have to disable streaming, because we must parse the whole JSON response
+inline enum llama_response_state check_response_state(enum llama_tool_format format, const std::string & generated_text) {
+    if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
+        return LLAMA_RESPONSE_STATE_TEXT;
+    } else if (format == LLAMA_TOOL_FORMAT_HERMES_3 && generated_text.rfind("<tool_call>", 0) == 0) {
+        return LLAMA_RESPONSE_STATE_TOOL_CALL;
+    }
+    return LLAMA_RESPONSE_STATE_TEXT;
+}
+
+// convert model's response to OAI format
+inline json parse_tool_response(enum llama_tool_format format, const std::string & generated_text) {
+    if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
+        return json{};
+    } else if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
+        std::string tmp(generated_text);
+        string_replace_all(tmp, "<tool_call>", "");
+        string_replace_all(tmp, "</tool_call>", "");
+        json tool = json::parse(tmp);
+        std::vector<json> tool_calls = {json{
+            {"id", tool.at("name")},
+            {"type", "function"},
+            {"function", {
+                {"name", tool.at("name")},
+                {"arguments", tool.at("arguments").dump()}, // OAI requires this to be JSON-stringified
+            }},
+        }};
+        return tool_calls;
+    }
+    return generated_text;
+}
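
Note (not part of the diff): a standalone sketch of the Hermes-3 round trip that check_response_state() and parse_tool_response() implement — a generation starting with <tool_call> is treated as a tool call, the tags are stripped, and the JSON payload is re-wrapped into an OAI-style tool_calls array. The generation text is invented, and the tag stripping is done inline here instead of with the server's string_replace_all().

// Illustration only: mapping a "<tool_call>...</tool_call>" generation to the
// OAI tool_calls shape, without depending on the server helpers.
#include "json.hpp"

#include <iostream>
#include <string>

using json = nlohmann::ordered_json;

int main() {
    // hypothetical model output in the Hermes-3 format
    std::string generated_text =
        "<tool_call>\n"
        "{\"arguments\": {\"location\": \"Paris\"}, \"name\": \"get_weather\"}\n"
        "</tool_call>";

    // check_response_state(): text beginning with "<tool_call>" is a tool call
    if (generated_text.rfind("<tool_call>", 0) != 0) {
        std::cout << generated_text << std::endl; // plain text response
        return 0;
    }

    // parse_tool_response(): strip the tags, parse the JSON payload, re-wrap it
    std::string tmp = generated_text;
    for (const char * tag : {"<tool_call>", "</tool_call>"}) {
        size_t pos;
        while ((pos = tmp.find(tag)) != std::string::npos) {
            tmp.erase(pos, std::string(tag).size());
        }
    }
    json tool = json::parse(tmp);
    json tool_calls = json::array({json{
        {"id",   tool.at("name")},
        {"type", "function"},
        {"function", {
            {"name",      tool.at("name")},
            {"arguments", tool.at("arguments").dump()}, // JSON-encoded string, as OAI expects
        }},
    }});
    std::cout << tool_calls.dump(2) << std::endl;
    return 0;
}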

examples/server/utils.hpp

@@ -116,10 +116,9 @@ static inline void server_log(const char * level, const char * function, int lin
 // chat template utils
 //

-// Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+// convert input chat messages from JSON to llama_chat_msg
+inline std::vector<llama_chat_msg> parse_chat_messages(const std::vector<json> & messages) {
     std::vector<llama_chat_msg> chat;
-
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];

@@ -144,7 +143,12 @@ inline std::string format_chat(const struct llama_model * model, const std::stri

         chat.push_back({role, content});
     }
+    return chat;
+}

+// Format given chat. If tmpl is empty, we take the template from model metadata
+inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+    auto chat = parse_chat_messages(messages);
     auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
     LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
     return formatted_chat;
@@ -356,7 +360,9 @@ static json oaicompat_completion_params_parse(
     llama_params["__oaicompat"] = true;

     // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    if (!body.contains("prompt")) {
+        llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    }

     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -417,20 +423,31 @@ static json format_final_response_oaicompat(const json & request, json result, c
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
     int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
     std::string content = json_value(result, "content", std::string(""));
+    bool has_tool_calls = result.contains("tool_calls");

     std::string finish_reason = "length";
     if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
+        finish_reason = has_tool_calls ? "tool_calls" : "stop";
     }

+    json message = has_tool_calls
+        ? json{
+            {"content", nullptr},
+            {"role", "assistant"},
+            {"tool_calls", result.at("tool_calls")},
+        }
+        : json{
+            {"content", content},
+            {"role", "assistant"},
+        };
+
     json choices =
         streaming ? json::array({json{{"finish_reason", finish_reason},
                                       {"index", 0},
                                       {"delta", json::object()}}})
                   : json::array({json{{"finish_reason", finish_reason},
                                       {"index", 0},
-                                      {"message", json{{"content", content},
-                                                       {"role", "assistant"}}}}});
+                                      {"message", message}}});

     std::time_t t = std::time(0);

@@ -472,23 +489,54 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
    bool stopped_eos = json_value(result, "stopped_eos", false);
    bool stopped_limit = json_value(result, "stopped_limit", false);
    std::string content = json_value(result, "content", std::string(""));
+   bool has_tool_calls = result.contains("tool_calls");

    std::string finish_reason;
    if (stopped_word || stopped_eos) {
-       finish_reason = "stop";
+       finish_reason = has_tool_calls ? "tool_calls" : "stop";
    }
    if (stopped_limit) {
        finish_reason = "length";
    }

-   std::time_t t = std::time(0);
+   auto wrap_choices = [&completion_id, &modelname](json choices) -> json {
+       return json{
+           {"choices", choices},
+           {"created", std::time(0)},
+           {"id", completion_id},
+           {"model", modelname},
+           {"object", "chat.completion.chunk"}
+       };
+   };

    json choices;
+   json delta = has_tool_calls
+       ? json{
+           {"content", nullptr},
+           {"role", "assistant"},
+           {"tool_calls", result.at("tool_calls")},
+       }
+       : json{
+           {"content", content},
+           {"role", "assistant"},
+       };

    if (!finish_reason.empty()) {
        choices = json::array({json{{"finish_reason", finish_reason},
                                    {"index", 0},
                                    {"delta", json::object()}}});
+       if (has_tool_calls) {
+           // tool call must be send as two updates
+           json initial_ret = wrap_choices(json::array({
+               json{
+                   {"finish_reason", nullptr},
+                   {"index", 0},
+                   {"delta", delta},
+               }
+           }));
+           json second_ret = wrap_choices(choices);
+           return std::vector<json>({initial_ret, second_ret});
+       }
    } else {
        if (first) {
            if (content.empty()) {
@@ -497,28 +545,22 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
                                    {"delta", json{{"role", "assistant"}}}}});
            } else {
                // We have to send this as two updates to conform to openai behavior
-               json initial_ret = json{{"choices", json::array({json{
+               json initial_ret = wrap_choices(json::array({
+                   json{
                        {"finish_reason", nullptr},
                        {"index", 0},
                        {"delta", json{
-                           {"role", "assistant"}
-                       }}}})},
-                       {"created", t},
-                       {"id", completion_id},
-                       {"model", modelname},
-                       {"object", "chat.completion.chunk"}};
-
-               json second_ret = json{
-                       {"choices", json::array({json{{"finish_reason", nullptr},
+                           {"role", "assistant"},
+                       }},
+                   }
+               }));
+               json second_ret = wrap_choices(json::array({
+                   json{
+                       {"finish_reason", nullptr},
                        {"index", 0},
-                       {"delta", json{
-                       {"content", content}}}
-                       }})},
-                       {"created", t},
-                       {"id", completion_id},
-                       {"model", modelname},
-                       {"object", "chat.completion.chunk"}};
-
+                       {"delta", delta},
+                   }
+               }));
                return std::vector<json>({initial_ret, second_ret});
            }
        } else {
@@ -531,21 +573,12 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
            choices = json::array({json{
                {"finish_reason", nullptr},
                {"index", 0},
-               {"delta",
-                   json{
-                       {"content", content},
-                   }},
+               {"delta", delta},
            }});
        }
    }

-   json ret = json {
-       {"choices", choices},
-       {"created", t},
-       {"id", completion_id},
-       {"model", modelname},
-       {"object", "chat.completion.chunk"}
-   };
+   json ret = wrap_choices(choices);
    if (!finish_reason.empty()) {
        int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
        int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
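
Note (not part of the diff): a sketch of the two chat.completion.chunk updates that format_partial_response_oaicompat() emits for a tool call after this change — first a delta carrying tool_calls with a null finish_reason, then an empty delta closing the stream with finish_reason "tool_calls". The id, model name, and tool_calls contents are placeholders.

// Illustration only: the two streamed chunks for a tool-call response, built
// with the same wrap_choices() pattern used in the hunks above.
#include "json.hpp"

#include <ctime>
#include <iostream>
#include <string>

using json = nlohmann::ordered_json;

int main() {
    std::string completion_id = "chatcmpl-123";          // placeholder id
    std::string modelname     = "hermes-3-llama-3.1-8b"; // placeholder model name

    auto wrap_choices = [&](json choices) -> json {
        return json{
            {"choices", choices},
            {"created", std::time(0)},
            {"id", completion_id},
            {"model", modelname},
            {"object", "chat.completion.chunk"}
        };
    };

    json delta = {
        {"content", nullptr},
        {"role", "assistant"},
        {"tool_calls", json::array()}, // would hold the output of parse_tool_response()
    };

    // first update: the tool call itself, finish_reason not set yet
    json initial_ret = wrap_choices(json::array({
        json{{"finish_reason", nullptr}, {"index", 0}, {"delta", delta}}
    }));
    // second update: empty delta that closes the stream with "tool_calls"
    json second_ret = wrap_choices(json::array({
        json{{"finish_reason", "tool_calls"}, {"index", 0}, {"delta", json::object()}}
    }));

    std::cout << initial_ret.dump(2) << "\n" << second_ret.dump(2) << std::endl;
    return 0;
}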

src/llama.cpp

@@ -18565,7 +18565,7 @@ int32_t llama_model_meta_val_str(const struct llama_model * model, const char *
        }
        return -1;
    }
-   return snprintf(buf, buf_size, "%s", it->second.c_str());
+   return buf != NULL ? snprintf(buf, buf_size, "%s", it->second.c_str()) : it->second.size();
 }

 int32_t llama_model_meta_count(const struct llama_model * model) {
@@ -20265,8 +20265,8 @@ static int32_t llama_chat_apply_template_internal(
        std::string & dest, bool add_ass) {
    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
    std::stringstream ss;
-   auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
-       return tmpl.find(haystack) != std::string::npos;
+   auto tmpl_contains = [&tmpl](std::string part) -> bool {
+       return tmpl.find(part) != std::string::npos;
    };
    if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
        // chatml template
@@ -20534,13 +20534,15 @@ int32_t llama_chat_apply_template(
    if (tmpl == nullptr) {
        GGML_ASSERT(model != nullptr);
        // load template from model
-       std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
        std::string template_key = "tokenizer.chat_template";
-       int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+       // call with NULL buffer to get the total size of the string
+       int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
        if (res < 0) {
            // worst case: there is no information about template, we will use chatml by default
            curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
        } else {
+           std::vector<char> model_template(res, 0);
+           llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
            curr_tmpl = std::string(model_template.data(), model_template.size());
        }
    }