Tool call support (generic + native for Llama, Functionary, Hermes, Mistral, Firefunction, DeepSeek) w/ lazy grammars (#9639)
--------- Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
parent
27d135c970
commit
8b576b6c55
48 changed files with 3861 additions and 156 deletions
|
@ -113,10 +113,11 @@ struct slot_params {
|
|||
struct common_params_speculative speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
json to_json() const {
|
||||
std::vector<std::string> samplers;
|
||||
|
@ -164,6 +165,8 @@ struct slot_params {
|
|||
{"n_probs", sampling.n_probs},
|
||||
{"min_keep", sampling.min_keep},
|
||||
{"grammar", sampling.grammar},
|
||||
// {"grammar_trigger_words", sampling.grammar_trigger_words},
|
||||
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
|
@ -325,12 +328,50 @@ struct server_task {
|
|||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
|
||||
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
|
||||
LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
|
||||
}
|
||||
|
||||
{
|
||||
auto it = data.find("chat_format");
|
||||
if (it != data.end()) {
|
||||
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
|
||||
LOG_DBG("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
|
||||
} else {
|
||||
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto grammar_triggers = data.find("grammar_triggers");
|
||||
if (grammar_triggers != data.end()) {
|
||||
for (const auto & t : *grammar_triggers) {
|
||||
common_grammar_trigger trigger;
|
||||
trigger.word = t.at("word");
|
||||
trigger.at_start = t.at("at_start");
|
||||
|
||||
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||
continue;
|
||||
}
|
||||
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
}
|
||||
if (params.sampling.grammar_lazy) {
|
||||
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -382,22 +423,12 @@ struct server_task {
|
|||
}
|
||||
|
||||
{
|
||||
const auto & samplers = data.find("samplers");
|
||||
const auto samplers = data.find("samplers");
|
||||
if (samplers != data.end()) {
|
||||
if (samplers->is_array()) {
|
||||
std::vector<std::string> sampler_names;
|
||||
for (const auto & name : *samplers) {
|
||||
if (name.is_string()) {
|
||||
sampler_names.emplace_back(name);
|
||||
}
|
||||
}
|
||||
params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
|
||||
params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
|
||||
} else if (samplers->is_string()){
|
||||
std::string sampler_string;
|
||||
for (const auto & name : *samplers) {
|
||||
sampler_string += name;
|
||||
}
|
||||
params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
|
||||
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
|
||||
}
|
||||
} else {
|
||||
params.sampling.samplers = defaults.sampling.samplers;
|
||||
|
@ -544,7 +575,7 @@ struct completion_token_output {
|
|||
struct server_task_result_cmpl_final : server_task_result {
|
||||
int index = 0;
|
||||
|
||||
std::string content;
|
||||
std::string content;
|
||||
llama_tokens tokens;
|
||||
|
||||
bool stream;
|
||||
|
@ -566,10 +597,11 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
slot_params generation_params;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
|
@ -663,18 +695,38 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
json to_json_oaicompat_chat() {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg message;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
message = common_chat_parse(content, oaicompat_chat_format);
|
||||
finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
} else {
|
||||
message.content = content;
|
||||
}
|
||||
|
||||
json choice = json{
|
||||
json tool_calls;
|
||||
if (!message.tool_calls.empty()) {
|
||||
tool_calls = json::array();
|
||||
for (const auto & tc : message.tool_calls) {
|
||||
tool_calls.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
}},
|
||||
{"id", tc.id.empty() ? json() : json(tc.id)},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
json choice {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json {
|
||||
{"content", content},
|
||||
{"role", "assistant"}
|
||||
}
|
||||
}};
|
||||
{"content", message.content},
|
||||
{"tool_calls", tool_calls},
|
||||
{"role", "assistant"},
|
||||
}},
|
||||
};
|
||||
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
choice["logprobs"] = json{
|
||||
|
@ -716,7 +768,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
finish_reason = "stop";
|
||||
}
|
||||
|
||||
json choice = json{
|
||||
json choice = json {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}
|
||||
|
@ -1191,6 +1243,8 @@ struct server_slot {
|
|||
|
||||
llama_token sampled;
|
||||
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
|
||||
|
@ -1815,17 +1869,16 @@ struct server_context {
|
|||
|
||||
if (use_jinja) {
|
||||
auto templates = common_chat_templates_from_model(model, "");
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
GGML_ASSERT(templates.template_default);
|
||||
try {
|
||||
templates.template_default->apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
common_chat_params_init(*templates.template_default, inputs);
|
||||
if (templates.template_tool_use) {
|
||||
templates.template_tool_use->apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
common_chat_params_init(*templates.template_tool_use, inputs);
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
|
@ -2275,11 +2328,11 @@ struct server_context {
|
|||
res->id_slot = slot.id;
|
||||
|
||||
res->index = slot.index;
|
||||
res->content = slot.generated_text;
|
||||
res->tokens = slot.generated_tokens;
|
||||
res->content = std::move(slot.generated_text);
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||
res->response_fields = slot.params.response_fields;
|
||||
res->response_fields = std::move(slot.params.response_fields);
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
|
@ -2290,12 +2343,12 @@ struct server_context {
|
|||
res->stop = slot.stop;
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
|
||||
|
@ -2773,6 +2826,11 @@ struct server_context {
|
|||
// track if given slot can be batched with slots already in the batch
|
||||
server_slot * slot_batched = nullptr;
|
||||
|
||||
auto accept_special_token = [&](server_slot & slot, llama_token token) {
|
||||
const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens;
|
||||
return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
|
||||
};
|
||||
|
||||
// first, add sampled tokens from any ongoing sequences
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_GENERATING) {
|
||||
|
@ -3136,7 +3194,7 @@ struct server_context {
|
|||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
|
@ -3225,7 +3283,7 @@ struct server_context {
|
|||
completion_token_output result;
|
||||
|
||||
result.tok = ids[i];
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // set later
|
||||
|
||||
// TODO: set result.probs
|
||||
|
@ -3722,6 +3780,8 @@ int main(int argc, char ** argv) {
|
|||
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||
{ "model_path", ctx_server.params_base.model },
|
||||
{ "chat_template", ctx_server.chat_templates.template_default->source() },
|
||||
{ "bos_token", ctx_server.chat_templates.template_default->bos_token() },
|
||||
{ "eos_token", ctx_server.chat_templates.template_default->eos_token() },
|
||||
{ "build_info", build_info },
|
||||
};
|
||||
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
|
||||
|
@ -3763,7 +3823,9 @@ int main(int argc, char ** argv) {
|
|||
std::vector<server_task> tasks;
|
||||
|
||||
try {
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
|
||||
const auto & prompt = data.at("prompt");
|
||||
LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
tasks.reserve(tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
@ -3779,8 +3841,8 @@ int main(int argc, char ** argv) {
|
|||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
tasks.push_back(task);
|
||||
|
@ -3949,14 +4011,14 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
LOG_DBG("request: %s\n", req.body.c_str());
|
||||
if (ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
auto body = json::parse(req.body);
|
||||
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
|
||||
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
|
@ -3966,6 +4028,13 @@ int main(int argc, char ** argv) {
|
|||
OAICOMPAT_TYPE_CHAT);
|
||||
};
|
||||
|
||||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
|
||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||
};
|
||||
|
||||
const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
json models = {
|
||||
{"object", "list"},
|
||||
|
@ -4124,14 +4193,6 @@ int main(int argc, char ** argv) {
|
|||
res_ok(res, root);
|
||||
};
|
||||
|
||||
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
|
||||
|
||||
res_ok(res, {{ "prompt", data.at("prompt") }});
|
||||
};
|
||||
|
||||
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue