Finish renaming of chat inputs vs. params [skip ci]
This commit is contained in:
parent
ed7c622d78
commit
36c776f329
3 changed files with 66 additions and 61 deletions
|
@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
|||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||
common_chat_inputs params;
|
||||
params.messages = json::array({{
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
common_chat_params_init(chat_template, params);
|
||||
common_chat_params_init(chat_template, inputs);
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||
|
@ -1803,11 +1803,10 @@ std::string common_chat_apply_template(
|
|||
for (const auto & msg : msgs) {
|
||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||
}
|
||||
common_chat_inputs params;
|
||||
params.messages = messages;
|
||||
params.add_generation_prompt = add_ass;
|
||||
auto data = common_chat_params_init(tmpl, params);
|
||||
return data.prompt;
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = messages;
|
||||
inputs.add_generation_prompt = add_ass;
|
||||
return common_chat_params_init(tmpl, inputs).prompt;
|
||||
}
|
||||
|
||||
int alloc_size = 0;
|
||||
|
|
|
@ -1824,16 +1824,16 @@ struct server_context {
|
|||
|
||||
if (use_jinja) {
|
||||
auto templates = common_chat_templates_from_model(model, "");
|
||||
common_chat_inputs params;
|
||||
params.messages = json::array({{
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
GGML_ASSERT(templates.template_default);
|
||||
try {
|
||||
common_chat_params_init(*templates.template_default, params);
|
||||
common_chat_params_init(*templates.template_default, inputs);
|
||||
if (templates.template_tool_use) {
|
||||
common_chat_params_init(*templates.template_tool_use, params);
|
||||
common_chat_params_init(*templates.template_tool_use, inputs);
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
|
@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
|
|||
std::vector<server_task> tasks;
|
||||
|
||||
try {
|
||||
common_chat_params chat_data;
|
||||
common_chat_params chat_params;
|
||||
bool add_special = false;
|
||||
if (tmpl && ctx_server.params_base.use_jinja) {
|
||||
chat_data = common_chat_params_init(*tmpl, {
|
||||
chat_params = common_chat_params_init(*tmpl, {
|
||||
/* .messages = */ json_value(data, "messages", json::array()),
|
||||
/* .tools = */ json_value(data, "tools", json()),
|
||||
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
||||
|
@ -3799,28 +3799,28 @@ int main(int argc, char ** argv) {
|
|||
/* .stream = */ json_value(data, "stream", false),
|
||||
/* .grammar = */ json_value(data, "grammar", std::string("")),
|
||||
});
|
||||
LOG_INF("Chat format: %s\n", chat_data.format.c_str());
|
||||
LOG_DBG("Prompt: %s\n", chat_data.prompt.get<std::string>().c_str());
|
||||
LOG_DBG("Grammar: %s\n", chat_data.grammar.c_str());
|
||||
LOG_INF("Chat format: %s\n", chat_params.format.c_str());
|
||||
LOG_DBG("Prompt: %s\n", chat_params.prompt.get<std::string>().c_str());
|
||||
LOG_DBG("Grammar: %s\n", chat_params.grammar.c_str());
|
||||
if (data.contains("grammar")) {
|
||||
if (!chat_data.grammar.empty()) {
|
||||
if (!chat_params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot provide grammar and tools");
|
||||
}
|
||||
chat_data.grammar = data.at("grammar");
|
||||
chat_params.grammar = data.at("grammar");
|
||||
}
|
||||
// TODO: move inside minja:chat_template?
|
||||
add_special = tmpl->source().find("eos_token") == std::string::npos &&
|
||||
tmpl->source().find("bos_token") == std::string::npos;
|
||||
} else {
|
||||
add_special = true;
|
||||
chat_data.prompt = data.at("prompt");
|
||||
chat_params.prompt = data.at("prompt");
|
||||
if (data.contains("grammar")) {
|
||||
chat_data.grammar = data.at("grammar");
|
||||
chat_params.grammar = data.at("grammar");
|
||||
} else if (data.contains("json_schema")) {
|
||||
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
|
||||
chat_params.grammar = json_schema_to_grammar(data.at("json_schema"));
|
||||
}
|
||||
}
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_params.prompt, add_special, true);
|
||||
tasks.reserve(tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
@ -3838,9 +3838,9 @@ int main(int argc, char ** argv) {
|
|||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.sampling.grammar = chat_data.grammar;
|
||||
task.params.sampling.grammar_lazy = chat_data.grammar_lazy;
|
||||
for (const auto & trigger : chat_data.grammar_triggers) {
|
||||
task.params.sampling.grammar = chat_params.grammar;
|
||||
task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
|
@ -3850,8 +3850,8 @@ int main(int argc, char ** argv) {
|
|||
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
task.params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
task.params.antiprompt = chat_data.additional_stops;
|
||||
task.params.chat_parser = chat_data.parser;
|
||||
task.params.antiprompt = chat_params.additional_stops;
|
||||
task.params.chat_parser = chat_params.parser;
|
||||
if (task.params.sampling.grammar_lazy) {
|
||||
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
|
||||
}
|
||||
|
|
|
@ -169,18 +169,18 @@ struct delta_data {
|
|||
};
|
||||
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||
common_chat_inputs params;
|
||||
params.parallel_tool_calls = true;
|
||||
params.messages = json::array();
|
||||
params.messages.push_back(user_message);
|
||||
params.tools = tools;
|
||||
auto prefix_data = common_chat_params_init(tmpl, params);
|
||||
params.messages.push_back(delta_message);
|
||||
params.add_generation_prompt = false;
|
||||
auto full_data = common_chat_params_init(tmpl, params);
|
||||
common_chat_inputs inputs;
|
||||
inputs.parallel_tool_calls = true;
|
||||
inputs.messages = json::array();
|
||||
inputs.messages.push_back(user_message);
|
||||
inputs.tools = tools;
|
||||
auto params_prefix = common_chat_params_init(tmpl, inputs);
|
||||
inputs.messages.push_back(delta_message);
|
||||
inputs.add_generation_prompt = false;
|
||||
auto params_full = common_chat_params_init(tmpl, inputs);
|
||||
|
||||
std::string prefix = prefix_data.prompt;
|
||||
std::string full = full_data.prompt;
|
||||
std::string prefix = params_prefix.prompt;
|
||||
std::string full = params_full.prompt;
|
||||
|
||||
// Check full starts with prefix
|
||||
if (full.find(prefix) != 0) {
|
||||
|
@ -203,7 +203,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
|||
break;
|
||||
}
|
||||
}
|
||||
return {delta, full_data.grammar, full_data.parser};
|
||||
return {delta, params_full.grammar, params_full.parser};
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -220,12 +220,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
};
|
||||
|
||||
for (const auto & tool_choice : json({"auto", "required"})) {
|
||||
common_chat_inputs params;
|
||||
params.tool_choice = tool_choice;
|
||||
params.parallel_tool_calls = true;
|
||||
params.messages = json {user_message, test_message};
|
||||
params.tools = tools;
|
||||
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
|
||||
if (!expected_delta.empty()) {
|
||||
assert_equals(expected_delta, data.delta);
|
||||
|
@ -309,17 +303,18 @@ static void test_template_output_parsers() {
|
|||
tools_params.tools.push_back(special_function_tool);
|
||||
|
||||
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
|
||||
auto data = common_chat_params_init(tmpl, params);
|
||||
return data.format;
|
||||
return common_chat_params_init(tmpl, params).format;
|
||||
};
|
||||
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||
|
||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
|
||||
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
|
||||
|
@ -340,7 +335,8 @@ static void test_template_output_parsers() {
|
|||
"}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "</s>" };
|
||||
|
||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||
|
@ -351,12 +347,15 @@ static void test_template_output_parsers() {
|
|||
/* skip_grammar_test= */ true);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
|
||||
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
|
||||
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
|
@ -369,11 +368,13 @@ static void test_template_output_parsers() {
|
|||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
|
||||
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
|
@ -384,7 +385,8 @@ static void test_template_output_parsers() {
|
|||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
|
||||
|
@ -395,7 +397,8 @@ static void test_template_output_parsers() {
|
|||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
|
@ -406,7 +409,8 @@ static void test_template_output_parsers() {
|
|||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||
|
@ -420,7 +424,8 @@ static void test_template_output_parsers() {
|
|||
"{\"arg1\": 1}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||
|
@ -431,7 +436,8 @@ static void test_template_output_parsers() {
|
|||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue