commit 05b028b9c7
bandoti, 2025-02-09 22:24:28 -08:00 (committed by GitHub)
7 changed files with 207 additions and 32 deletions

View file

@@ -2006,6 +2006,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                std::back_inserter(params.chat_template));
        }
    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
+    add_opt(common_arg(
+        {"--tools"}, "JINJA_TOOLS",
+        "set to JSON array of tool definitions used for assistant function-calling (requires --jinja)",
+        [](common_params &params, const std::string & value) {
+            params.jinja_tools.tools(value);
+        }).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"--tool-choice"}, "JINJA_TOOL_CHOICE",
+        "set to \"auto\", \"required\", \"none\" or a JSON object specifying a tool function (default: \"auto\")",
+        [](common_params &params, const std::string & value) {
+            params.jinja_tools.choice(value);
+        }).set_examples({LLAMA_EXAMPLE_MAIN}));
    add_opt(common_arg(
        {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
        string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

View file

@@ -7,9 +7,6 @@
#include "common.h"
#include "log.h"
-// Change JSON_ASSERT from assert() to GGML_ASSERT:
-#define JSON_ASSERT GGML_ASSERT
-#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat.hpp"
@@ -1772,6 +1769,46 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
// Chat template utils
//
+common_params_tools::common_params_tools(std::string tools, std::string choice) {
+    this->tools(tools);
+    this->choice(choice);
+}
+void common_params_tools::tools(std::string tools) {
+    if (tools.empty()) {
+        tools_.reset();
+        return;
+    }
+    try {
+        tools_ = std::make_shared<json>(json::parse(tools));
+        if (! tools_->is_array()) {
+            throw std::invalid_argument("tools must be a valid JSON array");
+        }
+    } catch (const json::exception & err) {
+        throw std::invalid_argument(err.what());
+    }
+}
+void common_params_tools::choice(std::string choice) {
+    try {
+        if (choice == "auto" || choice == "required" || choice == "none") {
+            tool_choice_ = std::move(choice);
+        } else {
+            auto choice_ptr = std::make_shared<json>(json::parse(choice));
+            tool_choice_ = choice_ptr;
+            if (! choice_ptr->is_object()) {
+                throw std::invalid_argument(
+                    "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\"");
+            }
+        }
+    } catch (const json::exception & err) {
+        throw std::invalid_argument(err.what());
+    }
+}
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
    if (use_jinja) {
        try {
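A quick sketch (not commit code, assumes common.h is included) of the error contract implemented above: both setters surface bad input as std::invalid_argument, so a caller such as the CLI parser needs only one catch.

#include <cstdio>
#include <stdexcept>

static void sketch_validate_tools() {
    common_params_tools t;  // defaults: no tools, choice "auto"
    try {
        t.tools(R"([{"type":"function","function":{"name":"get_weather"}}])");  // OK: JSON array
        t.choice(R"({"type":"function","function":{"name":"get_weather"}})");   // OK: JSON object
        t.tools(R"({"not":"an array"})");  // parses, but is_array() fails -> throws
    } catch (const std::invalid_argument & e) {
        fprintf(stderr, "bad tool configuration: %s\n", e.what());
    }
}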
@@ -1793,20 +1830,85 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
    return res >= 0;
}
+static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
+{
+    GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
+    auto & dst = *update_sparams->sparams;
+    auto vocab = update_sparams->vocab;
+    dst.grammar = src.grammar;
+    dst.grammar_lazy = src.grammar_lazy;
+    for (const auto & trigger : src.grammar_triggers) {
+        auto ids = common_tokenize(vocab, trigger.word, false, true);
+        if (ids.size() == 1) {
+            LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
+            dst.grammar_trigger_tokens.push_back(ids[0]);
+            dst.preserved_tokens.insert(ids[0]);
+            continue;
+        }
+        LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
+        dst.grammar_trigger_words.push_back(trigger);
+    }
+    for (const auto & preserved : src.preserved_tokens) {
+        auto ids = common_tokenize(vocab, preserved, false, true);
+        if (ids.size() == 1) {
+            LOG_DBG("Preserved token: %d\n", ids[0]);
+            dst.preserved_tokens.insert(ids[0]);
+        } else {
+            // This may happen when using a tool call style meant for a model
+            // with special tokens to preserve on a model without said tokens.
+            LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n",
+                preserved.c_str());
+        }
+    }
+}
std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
    const std::vector<common_chat_msg> & msgs,
    bool add_ass,
-    bool use_jinja) {
+    bool use_jinja,
+    const common_params_tools & tools,
+    common_chat_sampling_updater * update_sparams)
+{
+    const auto & tmpl_selected =
+        tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
    if (use_jinja) {
+        common_chat_inputs inputs;
        auto messages = json::array();
        for (const auto & msg : msgs) {
            messages.push_back({{"role", msg.role}, {"content", msg.content}});
        }
-        common_chat_inputs inputs;
+        if (tools.tools() != nullptr) {
+            inputs.tools = *tools.tools();
+        }
+        auto choice = tools.choice();
+        if (std::holds_alternative<std::string>(choice)) {
+            inputs.tool_choice = std::get<std::string>(choice);
+        } else {
+            auto choice_ptr = std::get<common_params_tools::json_ptr>(choice);
+            if (choice_ptr != nullptr) {
+                inputs.tool_choice = *choice_ptr;
+            }
+        }
        inputs.messages = messages;
        inputs.add_generation_prompt = add_ass;
-        return common_chat_params_init(tmpl, inputs).prompt;
+        auto chat_params = common_chat_params_init(tmpl_selected, inputs);
+        if (update_sparams) {
+            copy_chat_params(chat_params, update_sparams);
+        }
+        return chat_params.prompt;
    }
    int alloc_size = 0;
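Side note on the std::variant handling above, as a hedged sketch (assumed caller-side code, not part of the commit): the configured tool choice is either one of the three strings or a parsed JSON object.

#include <variant>

static void sketch_print_tool_choice(const common_params_tools & tools) {
    const auto & choice = tools.choice();
    if (std::holds_alternative<std::string>(choice)) {
        // "auto", "required" or "none"
        LOG_DBG("tool_choice = %s\n", std::get<std::string>(choice).c_str());
    } else if (auto ptr = std::get<common_params_tools::json_ptr>(choice)) {
        // a specific function, stored as a parsed JSON object
        LOG_DBG("tool_choice = %s\n", ptr->dump().c_str());
    }
}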
@@ -1819,7 +1921,7 @@ std::string common_chat_apply_template(
    std::vector<char> buf(alloc_size);
    // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
    // error: chat template is not supported
    if (res < 0) {
@@ -1831,7 +1933,7 @@ std::string common_chat_apply_template(
    // if it turns out that our buffer is too small, we resize it
    if ((size_t) res > buf.size()) {
        buf.resize(res);
-        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
    }
    std::string formatted_chat(buf.data(), res);
@@ -1839,13 +1941,18 @@ std::string common_chat_apply_template(
}
std::string common_chat_format_single(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
    const std::vector<common_chat_msg> & past_msg,
    const common_chat_msg & new_msg,
    bool add_ass,
-    bool use_jinja) {
+    bool use_jinja,
+    const common_params_tools & tools,
+    common_chat_sampling_updater * update_sparams)
+{
    std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
+    auto fmt_past_msg = past_msg.empty() ? ""
+        : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams);
    std::vector<common_chat_msg> chat_new(past_msg);
    // if the past_msg ends with a newline, we must preserve it in the formatted version
    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1853,13 +1960,13 @@ std::string common_chat_format_single(
    };
    // format chat with new_msg
    chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams);
    // get the diff part
    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
    return ss.str();
}
-std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
+std::string common_chat_format_example(const common_chat_templates & tmpl, bool use_jinja) {
    std::vector<common_chat_msg> msgs = {
        {"system", "You are a helpful assistant", {}},
        {"user", "Hello", {}},
@@ -2195,4 +2302,3 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
    return result;
}

View file

@@ -8,6 +8,10 @@
#include <string>
#include <vector>
#include <sstream>
+#include <variant>
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
+#include "json.hpp"
#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
@@ -202,6 +206,31 @@ struct common_params_vocoder {
    bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};
+class common_params_tools {
+public:
+    using json = nlohmann::ordered_json;
+    using json_ptr = std::shared_ptr<json>;
+    using tool_choice_t = std::variant<std::string, json_ptr>;
+
+    common_params_tools(std::string tools = "",
+                        std::string choice = "auto");
+
+    common_params_tools(const common_params_tools & other) = default;
+    common_params_tools(common_params_tools && other) noexcept = default;
+    common_params_tools & operator=(const common_params_tools & other) = default;
+    common_params_tools & operator=(common_params_tools && other) noexcept = default;
+
+    void tools(std::string tools);
+    const json * tools() const { return tools_.get(); }
+
+    void choice(std::string choice);
+    const tool_choice_t & choice() const { return tool_choice_; }
+
+private:
+    json_ptr tools_;
+    tool_choice_t tool_choice_;
+};
struct common_params {
    int32_t n_predict = -1; // new tokens to predict
    int32_t n_ctx = 4096; // context size
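Design note on the declaration above: the parsed tools live behind a shared_ptr, so copying a common_params_tools (for example when common_params is copied) is cheap and both copies refer to one parsed JSON document. A tiny sketch of that behavior, assuming common.h (and thus GGML_ASSERT) is in scope:

static void sketch_copy_semantics() {
    common_params_tools a(R"([{"type":"function","function":{"name":"f"}}])");
    common_params_tools b = a;            // copies the shared_ptr, no re-parse
    GGML_ASSERT(a.tools() == b.tools());  // same underlying json object
    GGML_ASSERT(std::holds_alternative<std::string>(b.choice()));  // default choice is "auto"
}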
@@ -346,6 +375,7 @@ struct common_params {
    std::string chat_template = ""; // NOLINT
    bool use_jinja = false; // NOLINT
    bool enable_chat_template = true;
+    common_params_tools jinja_tools;
    std::vector<std::string> api_keys;
@@ -641,26 +671,35 @@ struct common_chat_templates {
    std::unique_ptr<common_chat_template> template_tool_use;
};
+struct common_chat_sampling_updater {
+    common_params_sampling * sparams;
+    const llama_vocab * vocab;
+};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
    const std::vector<common_chat_msg> & chat,
    bool add_ass,
-    bool use_jinja);
+    bool use_jinja,
+    const common_params_tools & tools = common_params_tools(),
+    common_chat_sampling_updater * update_sparams = nullptr);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
    const std::vector<common_chat_msg> & past_msg,
    const common_chat_msg & new_msg,
    bool add_ass,
-    bool use_jinja);
+    bool use_jinja,
+    const common_params_tools & tools = common_params_tools(),
+    common_chat_sampling_updater * update_sparams = nullptr);
// Returns an example of formatted chat
std::string common_chat_format_example(
-    const common_chat_template & tmpl, bool use_jinja);
+    const common_chat_templates & tmpl, bool use_jinja);
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
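The defaulted trailing parameters above are what keep existing call sites compiling unchanged; tool-aware callers opt in explicitly. A sketch of both call shapes (the surrounding variables are assumed, not taken from a specific call site):

static std::string sketch_apply(const common_chat_templates & chat_templates,
                                const std::vector<common_chat_msg> & msgs,
                                const common_params & params,
                                common_params_sampling & sparams,
                                const llama_vocab * vocab) {
    // old-style call: no tools, no sampling update
    auto plain = common_chat_apply_template(chat_templates, msgs, /*add_ass=*/true, params.use_jinja);
    (void) plain;

    // tool-aware call: selects template_tool_use when tools are set and the model provides one,
    // and lets the chat handler install a (possibly lazy) grammar into the sampling params
    common_chat_sampling_updater updater{&sparams, vocab};
    return common_chat_apply_template(chat_templates, msgs, true, params.use_jinja,
                                      params.jinja_tools, &updater);
}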

View file

@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
    // print chat template example in conversation mode
    if (params.conversation_mode) {
        if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates, params.use_jinja).c_str());
        } else {
            LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
        }
@@ -263,20 +263,31 @@ int main(int argc, char ** argv) {
    std::vector<llama_token> embd_inp;
-    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+    auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab](
+        const std::string & role, const std::string & content,
+        const common_params_tools & tools = common_params_tools())
+    {
+        bool add_ass = (role == "user");
        common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        common_chat_sampling_updater updater{&sparams, vocab};
+        auto formatted =
+            common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja,
+                tools, &updater);
        chat_msgs.push_back({role, content, {}});
        LOG_DBG("formatted: '%s'\n", formatted.c_str());
        return formatted;
    };
    {
-        auto prompt = (params.conversation_mode && params.enable_chat_template)
-            // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
-            // otherwise use the prompt as is
+        std::string system_prompt (params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt);
+        bool use_conversation_prompt (params.conversation_mode && params.enable_chat_template);
+        auto prompt = use_conversation_prompt ?
+            chat_add_and_format("system", system_prompt, params.jinja_tools)
            : params.prompt;
        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
            LOG_DBG("tokenize the prompt\n");
            embd_inp = common_tokenize(ctx, prompt, true, true);
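Reviewer note on the block above: after the system prompt has gone through chat_add_and_format with params.jinja_tools, the updater may have written a (possibly lazy) grammar into sparams. A hedged sketch of how that could be checked at this point in main; the log line is illustrative only, not part of the commit:

// (immediately after the chat_add_and_format("system", ...) call above)
if (!sparams.grammar.empty()) {
    // the tool-call handler produced a grammar; with grammar_lazy set it only
    // kicks in once one of the registered trigger tokens/words is sampled
    LOG_INF("tool-call grammar installed (lazy: %s)\n", sparams.grammar_lazy ? "true" : "false");
}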

View file

@@ -4479,7 +4479,7 @@ int main(int argc, char ** argv) {
    // print sample chat example to make it clear which template is used
    LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
        ctx_server.chat_templates.template_default->source().c_str(),
-        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+        common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str());
    ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
        ctx_server.process_single_task(task);

View file

@@ -348,7 +348,7 @@ static llama_tokens format_infill(
}
// Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const common_chat_templates & tmpl, const std::vector<json> & messages) {
    std::vector<common_chat_msg> chat;
    for (size_t i = 0; i < messages.size(); ++i) {
@@ -663,7 +663,7 @@ static json oaicompat_completion_params_parse(
            llama_params["stop"].push_back(stop);
        }
    } else {
-        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
+        llama_params["prompt"] = format_chat(chat_templates, body.at("messages"));
    }
    // Handle "n" field

View file

@@ -339,7 +339,8 @@ int main(void) {
    common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
    auto fmt_sys = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
+        common_chat_templates tmpl;
+        tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
        printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
        printf("-------------------------\n");
@@ -366,7 +367,8 @@ int main(void) {
    common_chat_msg new_msg{"user", "How are you", {}};
    auto fmt_single = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
+        common_chat_templates tmpl;
+        tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
        printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
        printf("-------------------------\n");