Finish renaming of chat inputs vs. params [skip ci]
This commit is contained in:
parent
ed7c622d78
commit
36c776f329
3 changed files with 66 additions and 61 deletions
|
@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
try {
|
try {
|
||||||
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||||
common_chat_inputs params;
|
common_chat_inputs inputs;
|
||||||
params.messages = json::array({{
|
inputs.messages = json::array({{
|
||||||
{"role", "user"},
|
{"role", "user"},
|
||||||
{"content", "test"},
|
{"content", "test"},
|
||||||
}});
|
}});
|
||||||
common_chat_params_init(chat_template, params);
|
common_chat_params_init(chat_template, inputs);
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception & e) {
|
} catch (const std::exception & e) {
|
||||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||||
|
@ -1803,11 +1803,10 @@ std::string common_chat_apply_template(
|
||||||
for (const auto & msg : msgs) {
|
for (const auto & msg : msgs) {
|
||||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||||
}
|
}
|
||||||
common_chat_inputs params;
|
common_chat_inputs inputs;
|
||||||
params.messages = messages;
|
inputs.messages = messages;
|
||||||
params.add_generation_prompt = add_ass;
|
inputs.add_generation_prompt = add_ass;
|
||||||
auto data = common_chat_params_init(tmpl, params);
|
return common_chat_params_init(tmpl, inputs).prompt;
|
||||||
return data.prompt;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int alloc_size = 0;
|
int alloc_size = 0;
|
||||||
|
|
|
@ -1824,16 +1824,16 @@ struct server_context {
|
||||||
|
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
auto templates = common_chat_templates_from_model(model, "");
|
auto templates = common_chat_templates_from_model(model, "");
|
||||||
common_chat_inputs params;
|
common_chat_inputs inputs;
|
||||||
params.messages = json::array({{
|
inputs.messages = json::array({{
|
||||||
{"role", "user"},
|
{"role", "user"},
|
||||||
{"content", "test"},
|
{"content", "test"},
|
||||||
}});
|
}});
|
||||||
GGML_ASSERT(templates.template_default);
|
GGML_ASSERT(templates.template_default);
|
||||||
try {
|
try {
|
||||||
common_chat_params_init(*templates.template_default, params);
|
common_chat_params_init(*templates.template_default, inputs);
|
||||||
if (templates.template_tool_use) {
|
if (templates.template_tool_use) {
|
||||||
common_chat_params_init(*templates.template_tool_use, params);
|
common_chat_params_init(*templates.template_tool_use, inputs);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception & e) {
|
} catch (const std::exception & e) {
|
||||||
|
@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
common_chat_params chat_data;
|
common_chat_params chat_params;
|
||||||
bool add_special = false;
|
bool add_special = false;
|
||||||
if (tmpl && ctx_server.params_base.use_jinja) {
|
if (tmpl && ctx_server.params_base.use_jinja) {
|
||||||
chat_data = common_chat_params_init(*tmpl, {
|
chat_params = common_chat_params_init(*tmpl, {
|
||||||
/* .messages = */ json_value(data, "messages", json::array()),
|
/* .messages = */ json_value(data, "messages", json::array()),
|
||||||
/* .tools = */ json_value(data, "tools", json()),
|
/* .tools = */ json_value(data, "tools", json()),
|
||||||
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
||||||
|
@ -3799,28 +3799,28 @@ int main(int argc, char ** argv) {
|
||||||
/* .stream = */ json_value(data, "stream", false),
|
/* .stream = */ json_value(data, "stream", false),
|
||||||
/* .grammar = */ json_value(data, "grammar", std::string("")),
|
/* .grammar = */ json_value(data, "grammar", std::string("")),
|
||||||
});
|
});
|
||||||
LOG_INF("Chat format: %s\n", chat_data.format.c_str());
|
LOG_INF("Chat format: %s\n", chat_params.format.c_str());
|
||||||
LOG_DBG("Prompt: %s\n", chat_data.prompt.get<std::string>().c_str());
|
LOG_DBG("Prompt: %s\n", chat_params.prompt.get<std::string>().c_str());
|
||||||
LOG_DBG("Grammar: %s\n", chat_data.grammar.c_str());
|
LOG_DBG("Grammar: %s\n", chat_params.grammar.c_str());
|
||||||
if (data.contains("grammar")) {
|
if (data.contains("grammar")) {
|
||||||
if (!chat_data.grammar.empty()) {
|
if (!chat_params.grammar.empty()) {
|
||||||
throw std::runtime_error("Cannot provide grammar and tools");
|
throw std::runtime_error("Cannot provide grammar and tools");
|
||||||
}
|
}
|
||||||
chat_data.grammar = data.at("grammar");
|
chat_params.grammar = data.at("grammar");
|
||||||
}
|
}
|
||||||
// TODO: move inside minja:chat_template?
|
// TODO: move inside minja:chat_template?
|
||||||
add_special = tmpl->source().find("eos_token") == std::string::npos &&
|
add_special = tmpl->source().find("eos_token") == std::string::npos &&
|
||||||
tmpl->source().find("bos_token") == std::string::npos;
|
tmpl->source().find("bos_token") == std::string::npos;
|
||||||
} else {
|
} else {
|
||||||
add_special = true;
|
add_special = true;
|
||||||
chat_data.prompt = data.at("prompt");
|
chat_params.prompt = data.at("prompt");
|
||||||
if (data.contains("grammar")) {
|
if (data.contains("grammar")) {
|
||||||
chat_data.grammar = data.at("grammar");
|
chat_params.grammar = data.at("grammar");
|
||||||
} else if (data.contains("json_schema")) {
|
} else if (data.contains("json_schema")) {
|
||||||
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
|
chat_params.grammar = json_schema_to_grammar(data.at("json_schema"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
|
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_params.prompt, add_special, true);
|
||||||
tasks.reserve(tokenized_prompts.size());
|
tasks.reserve(tokenized_prompts.size());
|
||||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||||
server_task task = server_task(type);
|
server_task task = server_task(type);
|
||||||
|
@ -3838,9 +3838,9 @@ int main(int argc, char ** argv) {
|
||||||
// OAI-compat
|
// OAI-compat
|
||||||
task.params.oaicompat = oaicompat;
|
task.params.oaicompat = oaicompat;
|
||||||
task.params.oaicompat_cmpl_id = completion_id;
|
task.params.oaicompat_cmpl_id = completion_id;
|
||||||
task.params.sampling.grammar = chat_data.grammar;
|
task.params.sampling.grammar = chat_params.grammar;
|
||||||
task.params.sampling.grammar_lazy = chat_data.grammar_lazy;
|
task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
|
||||||
for (const auto & trigger : chat_data.grammar_triggers) {
|
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||||
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||||
if (ids.size() == 1) {
|
if (ids.size() == 1) {
|
||||||
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||||
|
@ -3850,8 +3850,8 @@ int main(int argc, char ** argv) {
|
||||||
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||||
task.params.sampling.grammar_trigger_words.push_back(trigger);
|
task.params.sampling.grammar_trigger_words.push_back(trigger);
|
||||||
}
|
}
|
||||||
task.params.antiprompt = chat_data.additional_stops;
|
task.params.antiprompt = chat_params.additional_stops;
|
||||||
task.params.chat_parser = chat_data.parser;
|
task.params.chat_parser = chat_params.parser;
|
||||||
if (task.params.sampling.grammar_lazy) {
|
if (task.params.sampling.grammar_lazy) {
|
||||||
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
|
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -169,18 +169,18 @@ struct delta_data {
|
||||||
};
|
};
|
||||||
|
|
||||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||||
common_chat_inputs params;
|
common_chat_inputs inputs;
|
||||||
params.parallel_tool_calls = true;
|
inputs.parallel_tool_calls = true;
|
||||||
params.messages = json::array();
|
inputs.messages = json::array();
|
||||||
params.messages.push_back(user_message);
|
inputs.messages.push_back(user_message);
|
||||||
params.tools = tools;
|
inputs.tools = tools;
|
||||||
auto prefix_data = common_chat_params_init(tmpl, params);
|
auto params_prefix = common_chat_params_init(tmpl, inputs);
|
||||||
params.messages.push_back(delta_message);
|
inputs.messages.push_back(delta_message);
|
||||||
params.add_generation_prompt = false;
|
inputs.add_generation_prompt = false;
|
||||||
auto full_data = common_chat_params_init(tmpl, params);
|
auto params_full = common_chat_params_init(tmpl, inputs);
|
||||||
|
|
||||||
std::string prefix = prefix_data.prompt;
|
std::string prefix = params_prefix.prompt;
|
||||||
std::string full = full_data.prompt;
|
std::string full = params_full.prompt;
|
||||||
|
|
||||||
// Check full starts with prefix
|
// Check full starts with prefix
|
||||||
if (full.find(prefix) != 0) {
|
if (full.find(prefix) != 0) {
|
||||||
|
@ -203,7 +203,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {delta, full_data.grammar, full_data.parser};
|
return {delta, params_full.grammar, params_full.parser};
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -220,12 +220,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
};
|
};
|
||||||
|
|
||||||
for (const auto & tool_choice : json({"auto", "required"})) {
|
for (const auto & tool_choice : json({"auto", "required"})) {
|
||||||
common_chat_inputs params;
|
|
||||||
params.tool_choice = tool_choice;
|
|
||||||
params.parallel_tool_calls = true;
|
|
||||||
params.messages = json {user_message, test_message};
|
|
||||||
params.tools = tools;
|
|
||||||
|
|
||||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
|
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
|
||||||
if (!expected_delta.empty()) {
|
if (!expected_delta.empty()) {
|
||||||
assert_equals(expected_delta, data.delta);
|
assert_equals(expected_delta, data.delta);
|
||||||
|
@ -309,17 +303,18 @@ static void test_template_output_parsers() {
|
||||||
tools_params.tools.push_back(special_function_tool);
|
tools_params.tools.push_back(special_function_tool);
|
||||||
|
|
||||||
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
|
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
|
||||||
auto data = common_chat_params_init(tmpl, params);
|
return common_chat_params_init(tmpl, params).format;
|
||||||
return data.format;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||||
|
|
||||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
|
||||||
|
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||||
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
|
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
|
||||||
|
@ -340,7 +335,8 @@ static void test_template_output_parsers() {
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "</s>" };
|
std::vector<std::string> end_tokens { "</s>" };
|
||||||
|
|
||||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||||
|
@ -351,12 +347,15 @@ static void test_template_output_parsers() {
|
||||||
/* skip_grammar_test= */ true);
|
/* skip_grammar_test= */ true);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
|
||||||
|
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
@ -369,11 +368,13 @@ static void test_template_output_parsers() {
|
||||||
"</tool_call>");
|
"</tool_call>");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
|
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
|
||||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
|
||||||
|
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||||
|
@ -384,7 +385,8 @@ static void test_template_output_parsers() {
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
|
||||||
|
@ -395,7 +397,8 @@ static void test_template_output_parsers() {
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
@ -406,7 +409,8 @@ static void test_template_output_parsers() {
|
||||||
"<function=special_function>{\"arg1\": 1}</function>");
|
"<function=special_function>{\"arg1\": 1}</function>");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||||
|
@ -420,7 +424,8 @@ static void test_template_output_parsers() {
|
||||||
"{\"arg1\": 1}");
|
"{\"arg1\": 1}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
@ -431,7 +436,8 @@ static void test_template_output_parsers() {
|
||||||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file(
|
||||||
|
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||||
|
|
||||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue