server : add Hermes-3 tool call support
This commit is contained in:
parent
7ea8d80d53
commit
7e017cfbc8
6 changed files with 240 additions and 37 deletions
|
@ -2253,6 +2253,10 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool string_contains(std::string haystack, std::string needle) {
|
||||||
|
return haystack.find(needle) != std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Filesystem utils
|
// Filesystem utils
|
||||||
//
|
//
|
||||||
|
@ -3186,6 +3190,19 @@ std::string llama_chat_format_example(const struct llama_model * model,
|
||||||
return llama_chat_apply_template(model, tmpl, msgs, true);
|
return llama_chat_apply_template(model, tmpl, msgs, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string llama_get_chat_template(const struct llama_model * model) {
|
||||||
|
std::string template_key = "tokenizer.chat_template";
|
||||||
|
// call with NULL buffer to get the total size of the string
|
||||||
|
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
|
||||||
|
if (res < 0) {
|
||||||
|
return "";
|
||||||
|
} else {
|
||||||
|
std::vector<char> model_template(res, 0);
|
||||||
|
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
||||||
|
return std::string(model_template.data(), model_template.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// KV cache utils
|
// KV cache utils
|
||||||
//
|
//
|
||||||
|
|
|
@ -320,6 +320,8 @@ static std::vector<T> string_split(const std::string & str, char delim) {
|
||||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||||
void string_process_escapes(std::string & input);
|
void string_process_escapes(std::string & input);
|
||||||
|
|
||||||
|
bool string_contains(std::string haystack, std::string needle);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Filesystem utils
|
// Filesystem utils
|
||||||
//
|
//
|
||||||
|
@ -428,6 +430,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
|
||||||
std::string llama_chat_format_example(const struct llama_model * model,
|
std::string llama_chat_format_example(const struct llama_model * model,
|
||||||
const std::string & tmpl);
|
const std::string & tmpl);
|
||||||
|
|
||||||
|
// Returns the chat template stored inside the model
|
||||||
|
// (empty string if model does not have built-in chat template)
|
||||||
|
std::string llama_get_chat_template(const struct llama_model * model);
|
||||||
|
|
||||||
//
|
//
|
||||||
// KV cache utils
|
// KV cache utils
|
||||||
//
|
//
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
#include "json-schema-to-grammar.h"
|
#include "json-schema-to-grammar.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "grammar-parser.h"
|
#include "grammar-parser.h"
|
||||||
|
#include "tool-call.hpp"
|
||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
// crash the server in debug mode, otherwise send an http 500 error
|
// crash the server in debug mode, otherwise send an http 500 error
|
||||||
|
@ -157,6 +158,7 @@ struct server_slot {
|
||||||
std::string generated_text;
|
std::string generated_text;
|
||||||
std::vector<llama_token> cache_tokens;
|
std::vector<llama_token> cache_tokens;
|
||||||
std::vector<completion_token_output> generated_token_probs;
|
std::vector<completion_token_output> generated_token_probs;
|
||||||
|
enum llama_response_state response_state = LLAMA_RESPONSE_STATE_UNKNOWN;
|
||||||
|
|
||||||
bool infill = false;
|
bool infill = false;
|
||||||
bool embedding = false;
|
bool embedding = false;
|
||||||
|
@ -207,6 +209,7 @@ struct server_slot {
|
||||||
infill = false;
|
infill = false;
|
||||||
ga_i = 0;
|
ga_i = 0;
|
||||||
n_past_se = 0;
|
n_past_se = 0;
|
||||||
|
response_state = LLAMA_RESPONSE_STATE_UNKNOWN;
|
||||||
|
|
||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
}
|
}
|
||||||
|
@ -625,6 +628,7 @@ struct server_context {
|
||||||
llama_model * model = nullptr;
|
llama_model * model = nullptr;
|
||||||
llama_context * ctx = nullptr;
|
llama_context * ctx = nullptr;
|
||||||
std::vector<llama_lora_adapter_container> lora_adapters;
|
std::vector<llama_lora_adapter_container> lora_adapters;
|
||||||
|
llama_tool_format tool_format = LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
|
||||||
|
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
|
@ -1217,7 +1221,13 @@ struct server_context {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!incomplete) {
|
if (slot.response_state == LLAMA_RESPONSE_STATE_UNKNOWN) {
|
||||||
|
slot.response_state = check_response_state(tool_format, slot.generated_text);
|
||||||
|
}
|
||||||
|
|
||||||
|
// if response is tool call, we cannot stream it
|
||||||
|
// instead, we wait for the full response, then extract JSON
|
||||||
|
if (!incomplete && slot.response_state == LLAMA_RESPONSE_STATE_TEXT) {
|
||||||
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||||
|
|
||||||
const std::string str_test = slot.generated_text.substr(pos);
|
const std::string str_test = slot.generated_text.substr(pos);
|
||||||
|
@ -1247,9 +1257,7 @@ struct server_context {
|
||||||
if (slot.params.stream) {
|
if (slot.params.stream) {
|
||||||
send_partial_response(slot, result);
|
send_partial_response(slot, result);
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
|
||||||
if (incomplete) {
|
|
||||||
slot.has_next_token = true;
|
slot.has_next_token = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1396,6 +1404,10 @@ struct server_context {
|
||||||
{"multimodal", false}
|
{"multimodal", false}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
|
||||||
|
res.data["tool_calls"] = parse_tool_response(tool_format, tkn.text_to_send);
|
||||||
|
}
|
||||||
|
|
||||||
if (slot.sparams.n_probs > 0) {
|
if (slot.sparams.n_probs > 0) {
|
||||||
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
|
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
|
||||||
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
||||||
|
@ -1444,6 +1456,10 @@ struct server_context {
|
||||||
{"timings", slot.get_formated_timings()}
|
{"timings", slot.get_formated_timings()}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
|
||||||
|
res.data["tool_calls"] = parse_tool_response(tool_format, slot.generated_text);
|
||||||
|
}
|
||||||
|
|
||||||
if (slot.sparams.n_probs > 0) {
|
if (slot.sparams.n_probs > 0) {
|
||||||
std::vector<completion_token_output> probs;
|
std::vector<completion_token_output> probs;
|
||||||
if (!slot.params.stream && slot.stopped_word) {
|
if (!slot.params.stream && slot.stopped_word) {
|
||||||
|
@ -2937,19 +2953,14 @@ int main(int argc, char ** argv) {
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
|
const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
|
||||||
std::string template_key = "tokenizer.chat_template", curr_tmpl;
|
std::string chat_tmpl = ctx_server.params.chat_template.empty()
|
||||||
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
|
? llama_get_chat_template(ctx_server.model)
|
||||||
if (tlen > 0) {
|
: ctx_server.params.chat_template;
|
||||||
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
|
|
||||||
if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
|
|
||||||
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
json data = {
|
json data = {
|
||||||
{ "system_prompt", ctx_server.system_prompt.c_str() },
|
{ "system_prompt", ctx_server.system_prompt.c_str() },
|
||||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||||
{ "total_slots", ctx_server.params.n_parallel },
|
{ "total_slots", ctx_server.params.n_parallel },
|
||||||
{ "chat_template", curr_tmpl.c_str() }
|
{ "chat_template", chat_tmpl },
|
||||||
};
|
};
|
||||||
|
|
||||||
res.set_content(data.dump(), MIMETYPE_JSON);
|
res.set_content(data.dump(), MIMETYPE_JSON);
|
||||||
|
@ -3056,7 +3067,13 @@ int main(int argc, char ** argv) {
|
||||||
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
json body = json::parse(req.body);
|
||||||
|
|
||||||
|
if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
|
||||||
|
body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
|
||||||
|
}
|
||||||
|
|
||||||
|
json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);
|
||||||
|
|
||||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||||
|
|
||||||
|
@ -3423,11 +3440,15 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decide if we can enable tool calls
|
||||||
|
ctx_server.tool_format = get_tool_format(ctx_server.ctx);
|
||||||
|
|
||||||
// print sample chat example to make it clear which template is used
|
// print sample chat example to make it clear which template is used
|
||||||
{
|
{
|
||||||
LOG_INFO("chat template", {
|
LOG_INFO("chat template", {
|
||||||
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
||||||
{"built_in", params.chat_template.empty()},
|
{"built_in", params.chat_template.empty()},
|
||||||
|
{"tool_call_enabled", ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED },
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
114
examples/server/tool-call.hpp
Normal file
114
examples/server/tool-call.hpp
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||||
|
#define JSON_ASSERT GGML_ASSERT
|
||||||
|
#include "json.hpp"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
enum llama_tool_format {
|
||||||
|
LLAMA_TOOL_FORMAT_NOT_SUPPORTED,
|
||||||
|
LLAMA_TOOL_FORMAT_HERMES_3,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum llama_response_state {
|
||||||
|
LLAMA_RESPONSE_STATE_UNKNOWN,
|
||||||
|
LLAMA_RESPONSE_STATE_TEXT,
|
||||||
|
LLAMA_RESPONSE_STATE_TOOL_CALL,
|
||||||
|
};
|
||||||
|
|
||||||
|
// get the tool call format for the loaded model
|
||||||
|
// this function does linear search, so do not call it repeatedly
|
||||||
|
inline enum llama_tool_format get_tool_format(const struct llama_context * ctx) {
|
||||||
|
auto model = llama_get_model(ctx);
|
||||||
|
auto has_token = [&](std::string piece) {
|
||||||
|
for (int i = 0; i < llama_n_vocab(model); i++) {
|
||||||
|
const std::string token_str = llama_token_to_piece(ctx, i, true);
|
||||||
|
if (token_str == piece) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
if (has_token("<|im_start|>") && has_token("<tool_call>")) {
|
||||||
|
return LLAMA_TOOL_FORMAT_HERMES_3;
|
||||||
|
}
|
||||||
|
return LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string format_chat_with_tool(enum llama_tool_format format, const std::vector<json> & messages, json tools) {
|
||||||
|
if (!tools.is_array()) {
|
||||||
|
throw std::runtime_error("tools must be an array");
|
||||||
|
}
|
||||||
|
std::stringstream ss;
|
||||||
|
auto chat = parse_chat_messages(messages);
|
||||||
|
if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
|
||||||
|
ss << "<|im_start|>system\n\n";
|
||||||
|
ss << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools>\n\n";
|
||||||
|
for (auto tool : tools) {
|
||||||
|
ss << tool.dump(1, '\t') << "\n\n";
|
||||||
|
}
|
||||||
|
ss << "</tools> Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}, \"name\": {\"title\": \"Name\", \"type\": \"string\"}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n";
|
||||||
|
ss << "<tool_call>\n";
|
||||||
|
ss << "{\"arguments\": <args-dict>, \"name\": <function-name>}\n";
|
||||||
|
ss << "</tool_call><|im_end|>\n";
|
||||||
|
for (auto & message : chat) {
|
||||||
|
std::string role(message.role);
|
||||||
|
if (role == "system") {
|
||||||
|
continue; // for optimal performance, we skip user-defined system message
|
||||||
|
}
|
||||||
|
ss << "<|im_start|>" << role << "\n\n";
|
||||||
|
if (role == "tool") {
|
||||||
|
ss << "<tool_response>\n" << string_strip(message.content) << "\n</tool_response>\n";
|
||||||
|
} else {
|
||||||
|
ss << string_strip(message.content) << "<|im_end|>\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss << "<|im_start|>assistant\n\n";
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("tool_call is not supported by this model");
|
||||||
|
}
|
||||||
|
LOG_VERBOSE("format_chat_with_tool", {{"text", ss.str()}});
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the response is text or tool_call
|
||||||
|
// if it is tool_call, we may have to disable streaming, because we must parse the whole JSON response
|
||||||
|
inline enum llama_response_state check_response_state(enum llama_tool_format format, const std::string & generated_text) {
|
||||||
|
if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
|
||||||
|
return LLAMA_RESPONSE_STATE_TEXT;
|
||||||
|
} else if (format == LLAMA_TOOL_FORMAT_HERMES_3 && generated_text.rfind("<tool_call>", 0) == 0) {
|
||||||
|
return LLAMA_RESPONSE_STATE_TOOL_CALL;
|
||||||
|
}
|
||||||
|
return LLAMA_RESPONSE_STATE_TEXT;
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert model's response to OAI format
|
||||||
|
inline json parse_tool_response(enum llama_tool_format format, const std::string & generated_text) {
|
||||||
|
if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
|
||||||
|
return json{};
|
||||||
|
} else if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
|
||||||
|
std::string tmp(generated_text);
|
||||||
|
string_replace_all(tmp, "<tool_call>", "");
|
||||||
|
string_replace_all(tmp, "</tool_call>", "");
|
||||||
|
json tool = json::parse(tmp);
|
||||||
|
std::vector<json> tool_calls = {json{
|
||||||
|
{"id", tool.at("name")},
|
||||||
|
{"type", "function"},
|
||||||
|
{"function", {
|
||||||
|
{"name", tool.at("name")},
|
||||||
|
{"arguments", tool.at("arguments").dump()}, // OAI requires this to be JSON-stringified
|
||||||
|
}},
|
||||||
|
}};
|
||||||
|
return tool_calls;
|
||||||
|
}
|
||||||
|
return generated_text;
|
||||||
|
}
|
|
@ -116,10 +116,9 @@ static inline void server_log(const char * level, const char * function, int lin
|
||||||
// chat template utils
|
// chat template utils
|
||||||
//
|
//
|
||||||
|
|
||||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
// convert input chat messages from JSON to llama_chat_msg
|
||||||
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
|
inline std::vector<llama_chat_msg> parse_chat_messages(const std::vector<json> & messages) {
|
||||||
std::vector<llama_chat_msg> chat;
|
std::vector<llama_chat_msg> chat;
|
||||||
|
|
||||||
for (size_t i = 0; i < messages.size(); ++i) {
|
for (size_t i = 0; i < messages.size(); ++i) {
|
||||||
const auto & curr_msg = messages[i];
|
const auto & curr_msg = messages[i];
|
||||||
|
|
||||||
|
@ -144,7 +143,12 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
|
||||||
|
|
||||||
chat.push_back({role, content});
|
chat.push_back({role, content});
|
||||||
}
|
}
|
||||||
|
return chat;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||||
|
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
|
||||||
|
auto chat = parse_chat_messages(messages);
|
||||||
auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
|
auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
|
||||||
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
|
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
|
||||||
return formatted_chat;
|
return formatted_chat;
|
||||||
|
@ -356,7 +360,9 @@ static json oaicompat_completion_params_parse(
|
||||||
llama_params["__oaicompat"] = true;
|
llama_params["__oaicompat"] = true;
|
||||||
|
|
||||||
// Apply chat template to the list of messages
|
// Apply chat template to the list of messages
|
||||||
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
|
if (!body.contains("prompt")) {
|
||||||
|
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
|
||||||
|
}
|
||||||
|
|
||||||
// Handle "stop" field
|
// Handle "stop" field
|
||||||
if (body.contains("stop") && body.at("stop").is_string()) {
|
if (body.contains("stop") && body.at("stop").is_string()) {
|
||||||
|
@ -391,7 +397,7 @@ static json oaicompat_completion_params_parse(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Params supported by OAI but unsupported by llama.cpp
|
// Params supported by OAI but unsupported by llama.cpp
|
||||||
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
|
static const std::vector<std::string> unsupported_params { "tool_choice" };
|
||||||
for (auto & param : unsupported_params) {
|
for (auto & param : unsupported_params) {
|
||||||
if (body.contains(param)) {
|
if (body.contains(param)) {
|
||||||
throw std::runtime_error("Unsupported param: " + param);
|
throw std::runtime_error("Unsupported param: " + param);
|
||||||
|
@ -417,20 +423,31 @@ static json format_final_response_oaicompat(const json & request, json result, c
|
||||||
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
||||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||||
std::string content = json_value(result, "content", std::string(""));
|
std::string content = json_value(result, "content", std::string(""));
|
||||||
|
bool has_tool_calls = result.contains("tool_calls");
|
||||||
|
|
||||||
std::string finish_reason = "length";
|
std::string finish_reason = "length";
|
||||||
if (stopped_word || stopped_eos) {
|
if (stopped_word || stopped_eos) {
|
||||||
finish_reason = "stop";
|
finish_reason = has_tool_calls ? "tool_calls" : "stop";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
json message = has_tool_calls
|
||||||
|
? json{
|
||||||
|
{"content", nullptr},
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"tool_calls", result.at("tool_calls")},
|
||||||
|
}
|
||||||
|
: json{
|
||||||
|
{"content", content},
|
||||||
|
{"role", "assistant"},
|
||||||
|
};
|
||||||
|
|
||||||
json choices =
|
json choices =
|
||||||
streaming ? json::array({json{{"finish_reason", finish_reason},
|
streaming ? json::array({json{{"finish_reason", finish_reason},
|
||||||
{"index", 0},
|
{"index", 0},
|
||||||
{"delta", json::object()}}})
|
{"delta", json::object()}}})
|
||||||
: json::array({json{{"finish_reason", finish_reason},
|
: json::array({json{{"finish_reason", finish_reason},
|
||||||
{"index", 0},
|
{"index", 0},
|
||||||
{"message", json{{"content", content},
|
{"message", message}}});
|
||||||
{"role", "assistant"}}}}});
|
|
||||||
|
|
||||||
std::time_t t = std::time(0);
|
std::time_t t = std::time(0);
|
||||||
|
|
||||||
|
@ -472,10 +489,11 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
|
||||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||||
bool stopped_limit = json_value(result, "stopped_limit", false);
|
bool stopped_limit = json_value(result, "stopped_limit", false);
|
||||||
std::string content = json_value(result, "content", std::string(""));
|
std::string content = json_value(result, "content", std::string(""));
|
||||||
|
bool has_tool_calls = result.contains("tool_calls");
|
||||||
|
|
||||||
std::string finish_reason;
|
std::string finish_reason;
|
||||||
if (stopped_word || stopped_eos) {
|
if (stopped_word || stopped_eos) {
|
||||||
finish_reason = "stop";
|
finish_reason = has_tool_calls ? "tool_calls" : "stop";
|
||||||
}
|
}
|
||||||
if (stopped_limit) {
|
if (stopped_limit) {
|
||||||
finish_reason = "length";
|
finish_reason = "length";
|
||||||
|
@ -484,11 +502,41 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
|
||||||
std::time_t t = std::time(0);
|
std::time_t t = std::time(0);
|
||||||
|
|
||||||
json choices;
|
json choices;
|
||||||
|
json delta = has_tool_calls
|
||||||
|
? json{
|
||||||
|
{"content", nullptr},
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"tool_calls", result.at("tool_calls")},
|
||||||
|
}
|
||||||
|
: json{
|
||||||
|
{"content", content},
|
||||||
|
{"role", "assistant"},
|
||||||
|
};
|
||||||
|
|
||||||
if (!finish_reason.empty()) {
|
if (!finish_reason.empty()) {
|
||||||
choices = json::array({json{{"finish_reason", finish_reason},
|
choices = json::array({json{{"finish_reason", finish_reason},
|
||||||
{"index", 0},
|
{"index", 0},
|
||||||
{"delta", json::object()}}});
|
{"delta", json::object()}}});
|
||||||
|
if (has_tool_calls) {
|
||||||
|
// tool call must be send as two updates
|
||||||
|
json initial_ret = json{{"choices", json::array({json{
|
||||||
|
{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", delta}}})},
|
||||||
|
{"created", t},
|
||||||
|
{"id", completion_id},
|
||||||
|
{"model", modelname},
|
||||||
|
{"object", "chat.completion.chunk"}};
|
||||||
|
|
||||||
|
json second_ret = json{
|
||||||
|
{"choices", choices},
|
||||||
|
{"created", t},
|
||||||
|
{"id", completion_id},
|
||||||
|
{"model", modelname},
|
||||||
|
{"object", "chat.completion.chunk"}};
|
||||||
|
|
||||||
|
return std::vector<json>({initial_ret, second_ret});
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (first) {
|
if (first) {
|
||||||
if (content.empty()) {
|
if (content.empty()) {
|
||||||
|
@ -511,9 +559,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
|
||||||
json second_ret = json{
|
json second_ret = json{
|
||||||
{"choices", json::array({json{{"finish_reason", nullptr},
|
{"choices", json::array({json{{"finish_reason", nullptr},
|
||||||
{"index", 0},
|
{"index", 0},
|
||||||
{"delta", json{
|
{"delta", delta}}})},
|
||||||
{"content", content}}}
|
|
||||||
}})},
|
|
||||||
{"created", t},
|
{"created", t},
|
||||||
{"id", completion_id},
|
{"id", completion_id},
|
||||||
{"model", modelname},
|
{"model", modelname},
|
||||||
|
@ -531,10 +577,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
|
||||||
choices = json::array({json{
|
choices = json::array({json{
|
||||||
{"finish_reason", nullptr},
|
{"finish_reason", nullptr},
|
||||||
{"index", 0},
|
{"index", 0},
|
||||||
{"delta",
|
{"delta", delta},
|
||||||
json{
|
|
||||||
{"content", content},
|
|
||||||
}},
|
|
||||||
}});
|
}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18058,7 +18058,7 @@ int32_t llama_model_meta_val_str(const struct llama_model * model, const char *
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
return snprintf(buf, buf_size, "%s", it->second.c_str());
|
return buf != NULL ? snprintf(buf, buf_size, "%s", it->second.c_str()) : it->second.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_model_meta_count(const struct llama_model * model) {
|
int32_t llama_model_meta_count(const struct llama_model * model) {
|
||||||
|
@ -19757,8 +19757,8 @@ static int32_t llama_chat_apply_template_internal(
|
||||||
std::string & dest, bool add_ass) {
|
std::string & dest, bool add_ass) {
|
||||||
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
|
auto tmpl_contains = [&tmpl](std::string part) -> bool {
|
||||||
return tmpl.find(haystack) != std::string::npos;
|
return tmpl.find(part) != std::string::npos;
|
||||||
};
|
};
|
||||||
if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
|
if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
|
||||||
// chatml template
|
// chatml template
|
||||||
|
@ -20026,13 +20026,15 @@ int32_t llama_chat_apply_template(
|
||||||
if (tmpl == nullptr) {
|
if (tmpl == nullptr) {
|
||||||
GGML_ASSERT(model != nullptr);
|
GGML_ASSERT(model != nullptr);
|
||||||
// load template from model
|
// load template from model
|
||||||
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
|
|
||||||
std::string template_key = "tokenizer.chat_template";
|
std::string template_key = "tokenizer.chat_template";
|
||||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
// call with NULL buffer to get the total size of the string
|
||||||
|
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
|
||||||
if (res < 0) {
|
if (res < 0) {
|
||||||
// worst case: there is no information about template, we will use chatml by default
|
// worst case: there is no information about template, we will use chatml by default
|
||||||
curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
|
curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
|
||||||
} else {
|
} else {
|
||||||
|
std::vector<char> model_template(res, 0);
|
||||||
|
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
||||||
curr_tmpl = std::string(model_template.data(), model_template.size());
|
curr_tmpl = std::string(model_template.data(), model_template.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue