Finish renaming of chat inputs vs. params [skip ci]

ochafik 2025-01-29 21:29:45 +00:00
parent ed7c622d78
commit 36c776f329
3 changed files with 66 additions and 61 deletions
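
For context, the rename separates the two sides of common_chat_params_init():
common_chat_inputs is what the caller supplies, common_chat_params is what
template expansion returns. Below is a minimal sketch of that relationship,
with struct members and types inferred purely from their use in this diff;
the real definitions in llama.cpp's chat code may have more members and
different defaults:

    // Hypothetical reconstruction from the call sites in this commit;
    // not the verbatim llama.cpp definitions.
    #include <string>
    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    struct common_chat_inputs {            // filled in by the caller
        json        messages;              // OAI-style [{role, content}, ...]
        json        tools;                 // optional tool schemas
        std::string tool_choice = "auto";
        bool        parallel_tool_calls = false;  // assumed default
        bool        add_generation_prompt = true; // assumed default
        bool        stream = false;
        std::string grammar;               // user-supplied grammar, if any
    };

    struct common_chat_params {            // returned by common_chat_params_init
        std::string format;                // e.g. "generic tool calls"
        json        prompt;                // rendered prompt, read via .get<std::string>()
        std::string grammar;               // grammar derived from the tools
        bool        grammar_lazy = false;
        // grammar_triggers, additional_stops and parser are also read below
    };

    // The function both structs meet at, per the call sites in this diff:
    // common_chat_params common_chat_params_init(
    //     const common_chat_template & tmpl, const common_chat_inputs & inputs);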

View file

@@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     if (use_jinja) {
         try {
             auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
-            common_chat_inputs params;
-            params.messages = json::array({{
+            common_chat_inputs inputs;
+            inputs.messages = json::array({{
                 {"role", "user"},
                 {"content", "test"},
             }});
-            common_chat_params_init(chat_template, params);
+            common_chat_params_init(chat_template, inputs);
             return true;
         } catch (const std::exception & e) {
             LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@@ -1803,11 +1803,10 @@ std::string common_chat_apply_template(
         for (const auto & msg : msgs) {
             messages.push_back({{"role", msg.role}, {"content", msg.content}});
         }
-        common_chat_inputs params;
-        params.messages = messages;
-        params.add_generation_prompt = add_ass;
-        auto data = common_chat_params_init(tmpl, params);
-        return data.prompt;
+        common_chat_inputs inputs;
+        inputs.messages = messages;
+        inputs.add_generation_prompt = add_ass;
+        return common_chat_params_init(tmpl, inputs).prompt;
     }

     int alloc_size = 0;

View file

@@ -1824,16 +1824,16 @@ struct server_context {
         if (use_jinja) {
             auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs params;
-            params.messages = json::array({{
+            common_chat_inputs inputs;
+            inputs.messages = json::array({{
                 {"role", "user"},
                 {"content", "test"},
             }});

             GGML_ASSERT(templates.template_default);
             try {
-                common_chat_params_init(*templates.template_default, params);
+                common_chat_params_init(*templates.template_default, inputs);
                 if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, params);
+                    common_chat_params_init(*templates.template_tool_use, inputs);
                 }
                 return true;
             } catch (const std::exception & e) {
@@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
        std::vector<server_task> tasks;

        try {
-           common_chat_params chat_data;
+           common_chat_params chat_params;
            bool add_special = false;
            if (tmpl && ctx_server.params_base.use_jinja) {
-               chat_data = common_chat_params_init(*tmpl, {
+               chat_params = common_chat_params_init(*tmpl, {
                    /* .messages = */ json_value(data, "messages", json::array()),
                    /* .tools = */ json_value(data, "tools", json()),
                    /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
@@ -3799,28 +3799,28 @@ int main(int argc, char ** argv) {
                    /* .stream = */ json_value(data, "stream", false),
                    /* .grammar = */ json_value(data, "grammar", std::string("")),
                });
-               LOG_INF("Chat format: %s\n", chat_data.format.c_str());
-               LOG_DBG("Prompt: %s\n", chat_data.prompt.get<std::string>().c_str());
-               LOG_DBG("Grammar: %s\n", chat_data.grammar.c_str());
+               LOG_INF("Chat format: %s\n", chat_params.format.c_str());
+               LOG_DBG("Prompt: %s\n", chat_params.prompt.get<std::string>().c_str());
+               LOG_DBG("Grammar: %s\n", chat_params.grammar.c_str());
                if (data.contains("grammar")) {
-                   if (!chat_data.grammar.empty()) {
+                   if (!chat_params.grammar.empty()) {
                        throw std::runtime_error("Cannot provide grammar and tools");
                    }
-                   chat_data.grammar = data.at("grammar");
+                   chat_params.grammar = data.at("grammar");
                }
                // TODO: move inside minja:chat_template?
                add_special = tmpl->source().find("eos_token") == std::string::npos &&
                              tmpl->source().find("bos_token") == std::string::npos;
            } else {
                add_special = true;
-               chat_data.prompt = data.at("prompt");
+               chat_params.prompt = data.at("prompt");
                if (data.contains("grammar")) {
-                   chat_data.grammar = data.at("grammar");
+                   chat_params.grammar = data.at("grammar");
                } else if (data.contains("json_schema")) {
-                   chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
+                   chat_params.grammar = json_schema_to_grammar(data.at("json_schema"));
                }
            }

-           std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
+           std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_params.prompt, add_special, true);
            tasks.reserve(tokenized_prompts.size());
            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                server_task task = server_task(type);
@@ -3838,9 +3838,9 @@ int main(int argc, char ** argv) {
                // OAI-compat
                task.params.oaicompat = oaicompat;
                task.params.oaicompat_cmpl_id = completion_id;
-               task.params.sampling.grammar = chat_data.grammar;
-               task.params.sampling.grammar_lazy = chat_data.grammar_lazy;
-               for (const auto & trigger : chat_data.grammar_triggers) {
+               task.params.sampling.grammar = chat_params.grammar;
+               task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
+               for (const auto & trigger : chat_params.grammar_triggers) {
                    auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
                    if (ids.size() == 1) {
                        LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
@@ -3850,8 +3850,8 @@ int main(int argc, char ** argv) {
                    LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
                    task.params.sampling.grammar_trigger_words.push_back(trigger);
                }
-               task.params.antiprompt = chat_data.additional_stops;
-               task.params.chat_parser = chat_data.parser;
+               task.params.antiprompt = chat_params.additional_stops;
+               task.params.chat_parser = chat_params.parser;
                if (task.params.sampling.grammar_lazy) {
                    GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
                }

View file

@@ -169,18 +169,18 @@ struct delta_data {
 };

 static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
-    common_chat_inputs params;
-    params.parallel_tool_calls = true;
-    params.messages = json::array();
-    params.messages.push_back(user_message);
-    params.tools = tools;
-    auto prefix_data = common_chat_params_init(tmpl, params);
-    params.messages.push_back(delta_message);
-    params.add_generation_prompt = false;
-    auto full_data = common_chat_params_init(tmpl, params);
+    common_chat_inputs inputs;
+    inputs.parallel_tool_calls = true;
+    inputs.messages = json::array();
+    inputs.messages.push_back(user_message);
+    inputs.tools = tools;
+    auto params_prefix = common_chat_params_init(tmpl, inputs);
+    inputs.messages.push_back(delta_message);
+    inputs.add_generation_prompt = false;
+    auto params_full = common_chat_params_init(tmpl, inputs);

-    std::string prefix = prefix_data.prompt;
-    std::string full = full_data.prompt;
+    std::string prefix = params_prefix.prompt;
+    std::string full = params_full.prompt;

     // Check full starts with prefix
     if (full.find(prefix) != 0) {
@@ -203,7 +203,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
             break;
         }
     }
-    return {delta, full_data.grammar, full_data.parser};
+    return {delta, params_full.grammar, params_full.parser};
 }

 /*
@@ -220,12 +220,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     };

     for (const auto & tool_choice : json({"auto", "required"})) {
-        common_chat_inputs params;
-        params.tool_choice = tool_choice;
-        params.parallel_tool_calls = true;
-        params.messages = json {user_message, test_message};
-        params.tools = tools;
-
         auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
         if (!expected_delta.empty()) {
             assert_equals(expected_delta, data.delta);
@@ -309,17 +303,18 @@ static void test_template_output_parsers() {
     tools_params.tools.push_back(special_function_tool);

     auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
-        auto data = common_chat_params_init(tmpl, params);
-        return data.format;
+        return common_chat_params_init(tmpl, params).format;
     };

     {
-        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<end_of_turn>" };

         assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
         assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
+            "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));

         // Generic tool calls doesn't generate / parse content-only messages symmetrically.
         assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
@@ -340,7 +335,8 @@ static void test_template_output_parsers() {
             "}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "</s>" };

         assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
@@ -351,12 +347,15 @@ static void test_template_output_parsers() {
             /* skip_grammar_test= */ true);
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|im_end|>" };

         assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
+            "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
+            "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));

         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -369,11 +368,13 @@ static void test_template_output_parsers() {
             "</tool_call>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };

         assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
-        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
+            "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));

         // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -384,7 +385,8 @@ static void test_template_output_parsers() {
             "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };

         assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
@@ -395,7 +397,8 @@ static void test_template_output_parsers() {
             "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };

         assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
@@ -406,7 +409,8 @@ static void test_template_output_parsers() {
             "<function=special_function>{\"arg1\": 1}</function>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };

         assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
@@ -420,7 +424,8 @@ static void test_template_output_parsers() {
             "{\"arg1\": 1}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eot_id|>" };

         assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
@@ -431,7 +436,8 @@ static void test_template_output_parsers() {
             " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<end▁of▁sentence>" };

         assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));