Xuan Son Nguyen 2024-09-02 16:59:45 +02:00 committed by GitHub
commit 843d97b1c7
6 changed files with 281 additions and 63 deletions

View file

@@ -428,6 +428,7 @@ void gpt_params_parse_from_env(gpt_params & params) {
get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
get_env("LLAMA_ARG_HOST", params.hostname);
get_env("LLAMA_ARG_PORT", params.port);
get_env("LLAMA_ARG_TOOL_CALLS", params.enable_tool_calls);
}
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
@@ -1046,6 +1047,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.lora_init_without_apply = true;
return true;
}
if (arg == "--tool-call" || arg == "--tool-calls") {
params.enable_tool_calls = true;
return true;
}
if (arg == "--control-vector") {
CHECK_ARG
params.control_vectors.push_back({ 1.0f, argv[i], });
@@ -2036,6 +2041,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
options.push_back({ "server", " --tool-call(s)", "enable OAI tool calls for chat completion endpoint (default: %s)", params.enable_tool_calls ? "enabled" : "disabled"});
#ifndef LOG_DISABLE_LOGS
options.push_back({ "logging" });
@@ -2253,6 +2259,10 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
return true;
}
bool string_contains(std::string haystack, std::string needle) {
return haystack.find(needle) != std::string::npos;
}
//
// Filesystem utils
//
@@ -3186,6 +3196,19 @@ std::string llama_chat_format_example(const struct llama_model * model,
return llama_chat_apply_template(model, tmpl, msgs, true);
}
std::string llama_get_chat_template(const struct llama_model * model) {
std::string template_key = "tokenizer.chat_template";
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
if (res < 0) {
return "";
} else {
std::vector<char> model_template(res + 1, 0); // +1 for the null terminator written by snprintf
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
return std::string(model_template.data(), res);
}
}
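The NULL-buffer sizing pattern used above works for any GGUF metadata key, not just the chat template. A minimal sketch of the same two-call idiom, assuming a loaded `model` and the standard `general.architecture` key (the helper name is hypothetical, not part of this change):
static std::string get_meta_value(const struct llama_model * model, const char * key) {
    // first call with a NULL buffer only queries the required size
    int32_t len = llama_model_meta_val_str(model, key, NULL, 0);
    if (len < 0) {
        return ""; // key not present in the model metadata
    }
    std::vector<char> buf(len + 1, 0); // +1 for the null terminator written by snprintf
    llama_model_meta_val_str(model, key, buf.data(), buf.size());
    return std::string(buf.data(), len);
}
// usage: std::string arch = get_meta_value(model, "general.architecture");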
//
// KV cache utils
//

View file

@@ -221,6 +221,7 @@ struct gpt_params {
std::string chat_template = "";
std::string system_prompt = "";
bool enable_chat_template = true;
bool enable_tool_calls = false;
std::vector<std::string> api_keys;
@@ -320,6 +321,8 @@ static std::vector<T> string_split(const std::string & str, char delim) {
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
bool string_contains(std::string haystack, std::string needle);
//
// Filesystem utils
//
@@ -428,6 +431,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
std::string llama_chat_format_example(const struct llama_model * model,
const std::string & tmpl);
// Returns the chat template stored inside the model
// (empty string if the model does not have a built-in chat template)
std::string llama_get_chat_template(const struct llama_model * model);
//
// KV cache utils
//

View file

@@ -4,6 +4,7 @@
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "grammar-parser.h"
#include "tool-call.hpp"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
@@ -157,6 +158,7 @@ struct server_slot {
std::string generated_text;
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
enum llama_response_state response_state = LLAMA_RESPONSE_STATE_UNKNOWN;
bool infill = false;
bool embedding = false;
@@ -207,6 +209,7 @@ struct server_slot {
infill = false;
ga_i = 0;
n_past_se = 0;
response_state = LLAMA_RESPONSE_STATE_UNKNOWN;
generated_token_probs.clear();
}
@@ -625,6 +628,7 @@ struct server_context {
llama_model * model = nullptr;
llama_context * ctx = nullptr;
std::vector<llama_lora_adapter_container> lora_adapters;
llama_tool_format tool_format = LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
gpt_params params;
@@ -1217,7 +1221,13 @@ struct server_context {
break;
}
if (!incomplete) {
if (slot.response_state == LLAMA_RESPONSE_STATE_UNKNOWN) {
slot.response_state = check_response_state(tool_format, slot.generated_text);
}
// if the response is a tool call, we cannot stream it;
// instead, we wait for the full response, then extract the JSON
if (!incomplete && slot.response_state == LLAMA_RESPONSE_STATE_TEXT) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos);
@@ -1247,9 +1257,7 @@ struct server_context {
if (slot.params.stream) {
send_partial_response(slot, result);
}
}
if (incomplete) {
} else {
slot.has_next_token = true;
}
@@ -1396,6 +1404,10 @@ struct server_context {
{"multimodal", false}
};
if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
res.data["tool_calls"] = parse_tool_response(tool_format, tkn.text_to_send);
}
if (slot.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
@@ -1444,6 +1456,10 @@ struct server_context {
{"timings", slot.get_formated_timings()}
};
if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
res.data["tool_calls"] = parse_tool_response(tool_format, slot.generated_text);
}
if (slot.sparams.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
@@ -2937,19 +2953,14 @@ int main(int argc, char ** argv) {
};
const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
std::string template_key = "tokenizer.chat_template", curr_tmpl;
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
}
}
std::string chat_tmpl = ctx_server.params.chat_template.empty()
? llama_get_chat_template(ctx_server.model)
: ctx_server.params.chat_template;
json data = {
{ "system_prompt", ctx_server.system_prompt.c_str() },
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params.n_parallel },
{ "chat_template", curr_tmpl.c_str() }
{ "chat_template", chat_tmpl },
};
res.set_content(data.dump(), MIMETYPE_JSON);
@@ -3056,7 +3067,19 @@ int main(int argc, char ** argv) {
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
json body = json::parse(req.body);
if (body.contains("tools")) {
if (ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
body.erase(body.find("tools"));
} else {
res_error(res, format_error_response("This server does not support tool calls. Start it with `--tool-calls`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
}
json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);
const int id_task = ctx_server.queue_tasks.get_new_id();
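For reference, a request body that takes the tool-call branch above might look like the following. This is a sketch only; the `get_weather` tool and its schema are hypothetical, not part of this change:
// hypothetical request body that exercises the "tools" branch above
const json example_body = json::parse(R"({
    "messages": [
        {"role": "user", "content": "What is the weather in Hanoi?"}
    ],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"]
            }
        }
    }]
})");
// with --tool-calls enabled and a supported model, the "tools" entry is folded
// into the prompt via format_chat_with_tool() and then erased from the body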
@@ -3423,11 +3446,27 @@ int main(int argc, char ** argv) {
}
}
// decide if we can enable tool calls
bool tool_call_support = false;
if (ctx_server.params.enable_tool_calls) {
ctx_server.tool_format = get_tool_format(ctx_server.ctx);
tool_call_support = ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
if (tool_call_support) {
LOG_WARNING("Tool calls are EXPERIMENTAL and may be unstable. Use at your own risk", {});
} else {
LOG_ERROR("Tool calls are not supported for this model. Please remove --tool-call or use a supported model", {});
clean_up();
t.join();
return 1;
}
}
// print sample chat example to make it clear which template is used
{
LOG_INFO("chat template", {
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
{"built_in", params.chat_template.empty()},
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
{"built_in", params.chat_template.empty()},
{"tool_call_support", tool_call_support},
});
}

View file

@@ -0,0 +1,114 @@
#pragma once
#include "llama.h"
#include "common.h"
#include "utils.hpp"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include <string>
#include <vector>
#include <sstream>
using json = nlohmann::ordered_json;
enum llama_tool_format {
LLAMA_TOOL_FORMAT_NOT_SUPPORTED,
LLAMA_TOOL_FORMAT_HERMES_3,
};
enum llama_response_state {
LLAMA_RESPONSE_STATE_UNKNOWN,
LLAMA_RESPONSE_STATE_TEXT,
LLAMA_RESPONSE_STATE_TOOL_CALL,
};
// get the tool call format for the loaded model
// this function does a linear search over the vocabulary, so do not call it repeatedly
inline enum llama_tool_format get_tool_format(const struct llama_context * ctx) {
auto model = llama_get_model(ctx);
auto has_token = [&](std::string piece) {
for (int i = 0; i < llama_n_vocab(model); i++) {
const std::string token_str = llama_token_to_piece(ctx, i, true);
if (token_str == piece) {
return true;
}
}
return false;
};
if (has_token("<|im_start|>") && has_token("<tool_call>")) {
return LLAMA_TOOL_FORMAT_HERMES_3;
}
return LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
}
inline std::string format_chat_with_tool(enum llama_tool_format format, const std::vector<json> & messages, json tools) {
if (!tools.is_array()) {
throw std::runtime_error("tools must be an array");
}
std::stringstream ss;
auto chat = parse_chat_messages(messages);
if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
ss << "<|im_start|>system\n\n";
ss << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools>\n\n";
for (auto tool : tools) {
ss << tool.dump(1, '\t') << "\n\n";
}
ss << "</tools> Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}, \"name\": {\"title\": \"Name\", \"type\": \"string\"}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n";
ss << "<tool_call>\n";
ss << "{\"arguments\": <args-dict>, \"name\": <function-name>}\n";
ss << "</tool_call><|im_end|>\n";
for (auto & message : chat) {
std::string role(message.role);
if (role == "system") {
continue; // for optimal performance, we skip user-defined system messages
}
ss << "<|im_start|>" << role << "\n\n";
if (role == "tool") {
ss << "<tool_response>\n" << string_strip(message.content) << "\n</tool_response>\n";
} else {
ss << string_strip(message.content) << "<|im_end|>\n";
}
}
ss << "<|im_start|>assistant\n\n";
} else {
throw std::runtime_error("tool_call is not supported by this model");
}
LOG_VERBOSE("format_chat_with_tool", {{"text", ss.str()}});
return ss.str();
}
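A usage sketch for the formatter above, with one user message and a hypothetical `get_weather` tool; the exact prompt text follows the Hermes 3 template built by this function:
static std::string example_tool_prompt() {
    std::vector<json> msgs = {
        json{{"role", "user"}, {"content", "What is the weather in Hanoi?"}}
    };
    json tools = json::array({
        json{
            {"type", "function"},
            {"function", json{{"name", "get_weather"}, {"parameters", json{{"type", "object"}}}}}
        }
    });
    // the result contains the <tools> system block, the user turn, and a
    // trailing "<|im_start|>assistant" so generation starts with the reply
    return format_chat_with_tool(LLAMA_TOOL_FORMAT_HERMES_3, msgs, tools);
}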
// check whether the response is plain text or a tool_call
// if it is a tool_call, we may have to disable streaming, because we must parse the whole JSON response
inline enum llama_response_state check_response_state(enum llama_tool_format format, const std::string & generated_text) {
if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
return LLAMA_RESPONSE_STATE_TEXT;
} else if (format == LLAMA_TOOL_FORMAT_HERMES_3 && generated_text.rfind("<tool_call>", 0) == 0) {
return LLAMA_RESPONSE_STATE_TOOL_CALL;
}
return LLAMA_RESPONSE_STATE_TEXT;
}
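A quick sanity-check sketch for the classification above, assuming the Hermes 3 format; the literal strings are illustrative only:
static void check_response_state_examples() {
    // a response that starts with the tool-call tag is buffered, not streamed
    GGML_ASSERT(check_response_state(LLAMA_TOOL_FORMAT_HERMES_3,
        "<tool_call>\n{\"name\": \"get_weather\"") == LLAMA_RESPONSE_STATE_TOOL_CALL);
    // anything else streams as normal text
    GGML_ASSERT(check_response_state(LLAMA_TOOL_FORMAT_HERMES_3,
        "It is sunny in Hanoi today.") == LLAMA_RESPONSE_STATE_TEXT);
}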
// convert the model's tool-call response to the OAI format
inline json parse_tool_response(enum llama_tool_format format, const std::string & generated_text) {
if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
return json{};
} else if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
std::string tmp(generated_text);
string_replace_all(tmp, "<tool_call>", "");
string_replace_all(tmp, "</tool_call>", "");
json tool = json::parse(tmp);
std::vector<json> tool_calls = {json{
{"id", tool.at("name")},
{"type", "function"},
{"function", {
{"name", tool.at("name")},
{"arguments", tool.at("arguments").dump()}, // OAI requires this to be JSON-stringified
}},
}};
return tool_calls;
}
return generated_text;
}
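A sketch of what the conversion above produces for a typical Hermes 3 reply; the payload is illustrative, not taken from a real generation:
static json parse_tool_response_example() {
    std::string generated =
        "<tool_call>\n"
        "{\"arguments\": {\"city\": \"Hanoi\"}, \"name\": \"get_weather\"}\n"
        "</tool_call>";
    json calls = parse_tool_response(LLAMA_TOOL_FORMAT_HERMES_3, generated);
    // calls[0] == { "id": "get_weather", "type": "function",
    //               "function": { "name": "get_weather",
    //                             "arguments": "{\"city\":\"Hanoi\"}" } }
    return calls;
}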

View file

@@ -116,10 +116,9 @@ static inline void server_log(const char * level, const char * function, int lin
// chat template utils
//
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
// convert input chat messages from JSON to llama_chat_msg
inline std::vector<llama_chat_msg> parse_chat_messages(const std::vector<json> & messages) {
std::vector<llama_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i];
@@ -144,7 +143,12 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
chat.push_back({role, content});
}
return chat;
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
auto chat = parse_chat_messages(messages);
auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
return formatted_chat;
@@ -356,7 +360,9 @@ static json oaicompat_completion_params_parse(
llama_params["__oaicompat"] = true;
// Apply chat template to the list of messages
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
if (!body.contains("prompt")) {
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
}
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
@@ -417,20 +423,31 @@ static json format_final_response_oaicompat(const json & request, json result, c
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
bool has_tool_calls = result.contains("tool_calls");
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
finish_reason = "stop";
finish_reason = has_tool_calls ? "tool_calls" : "stop";
}
json message = has_tool_calls
? json{
{"content", nullptr},
{"role", "assistant"},
{"tool_calls", result.at("tool_calls")},
}
: json{
{"content", content},
{"role", "assistant"},
};
json choices =
streaming ? json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"role", "assistant"}}}}});
{"message", message}}});
std::time_t t = std::time(0);
@@ -472,23 +489,54 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
bool stopped_eos = json_value(result, "stopped_eos", false);
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
bool has_tool_calls = result.contains("tool_calls");
std::string finish_reason;
if (stopped_word || stopped_eos) {
finish_reason = "stop";
finish_reason = has_tool_calls ? "tool_calls" : "stop";
}
if (stopped_limit) {
finish_reason = "length";
}
std::time_t t = std::time(0);
auto wrap_choices = [&completion_id, &modelname](json choices) -> json {
return json{
{"choices", choices},
{"created", std::time(0)},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
};
json choices;
json delta = has_tool_calls
? json{
{"content", nullptr},
{"role", "assistant"},
{"tool_calls", result.at("tool_calls")},
}
: json{
{"content", content},
{"role", "assistant"},
};
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
if (has_tool_calls) {
// a tool call must be sent as two updates
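// (illustration only, not from a real run) the two chunks look roughly like:
//   {"choices":[{"finish_reason":null,"index":0,"delta":{"content":null,"role":"assistant","tool_calls":[...]}}], "object":"chat.completion.chunk", ...}
//   {"choices":[{"finish_reason":"tool_calls","index":0,"delta":{}}], "object":"chat.completion.chunk", ...}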
json initial_ret = wrap_choices(json::array({
json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", delta},
}
}));
json second_ret = wrap_choices(choices);
return std::vector<json>({initial_ret, second_ret});
}
} else {
if (first) {
if (content.empty()) {
@@ -497,28 +545,22 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to OpenAI behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}};
json initial_ret = wrap_choices(json::array({
json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"},
}},
}
}));
json second_ret = wrap_choices(json::array({
json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", delta},
}
}));
return std::vector<json>({initial_ret, second_ret});
}
} else {
@@ -531,21 +573,12 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
{"delta", delta},
}});
}
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
json ret = wrap_choices(choices);
if (!finish_reason.empty()) {
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);

View file

@@ -18565,7 +18565,7 @@ int32_t llama_model_meta_val_str(const struct llama_model * model, const char *
}
return -1;
}
return snprintf(buf, buf_size, "%s", it->second.c_str());
return buf != NULL ? snprintf(buf, buf_size, "%s", it->second.c_str()) : it->second.size();
}
int32_t llama_model_meta_count(const struct llama_model * model) {
@@ -20265,8 +20265,8 @@ static int32_t llama_chat_apply_template_internal(
std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
return tmpl.find(haystack) != std::string::npos;
auto tmpl_contains = [&tmpl](std::string part) -> bool {
return tmpl.find(part) != std::string::npos;
};
if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
// chatml template
@@ -20534,13 +20534,15 @@ int32_t llama_chat_apply_template(
if (tmpl == nullptr) {
GGML_ASSERT(model != nullptr);
// load template from model
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
if (res < 0) {
// worst case: there is no information about template, we will use chatml by default
curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
} else {
std::vector<char> model_template(res + 1, 0); // +1 for the null terminator written by snprintf
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
curr_tmpl = std::string(model_template.data(), res);
}
}