Rename: common/chat.*, common_chat_{inputs -> params}

ochafik 2025-01-29 21:18:49 +00:00
parent 6e676c8030
commit ed7c622d78
9 changed files with 117 additions and 123 deletions
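In short: common/chat-handler.{cpp,hpp} become common/chat.{cpp,hpp}, the input struct formerly named common_chat_params is now common_chat_inputs, the result struct formerly named common_chat_data is now common_chat_params, and common_chat_init() becomes common_chat_params_init(). Below is a minimal caller-side sketch of the renamed API, mirroring the usage in common_chat_verify_template / common_chat_apply_template further down; the helper function, the message content, and the assumption that a common_chat_template is already constructed are illustrative only, not part of the commit.

#include "chat.hpp"            // was: #include "chat-handler.hpp"
#include "chat-template.hpp"

// Illustrative helper (not in the commit): build inputs and render a prompt with the renamed API.
static std::string render_test_prompt(const common_chat_template & chat_template) {
    common_chat_inputs inputs;   // formerly struct common_chat_params
    inputs.messages = json::array({{
        {"role",    "user"},
        {"content", "test"},
    }});
    inputs.add_generation_prompt = true;

    // formerly: common_chat_data data = common_chat_init(chat_template, params);
    common_chat_params chat_params = common_chat_params_init(chat_template, inputs);
    return chat_params.prompt;   // grammar, grammar_triggers, parser, additional_stops come along as before
}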


@@ -52,7 +52,7 @@ TEST_TARGETS = \
tests/test-arg-parser \
tests/test-autorelease \
tests/test-backend-ops \
tests/test-chat-handler \
tests/test-chat \
tests/test-chat-template \
tests/test-double-float \
tests/test-grammar-integration \
@@ -984,7 +984,7 @@ OBJ_COMMON = \
$(DIR_COMMON)/ngram-cache.o \
$(DIR_COMMON)/sampling.o \
$(DIR_COMMON)/speculative.o \
$(DIR_COMMON)/chat-handler.o \
$(DIR_COMMON)/chat.o \
$(DIR_COMMON)/build-info.o \
$(DIR_COMMON)/json-schema-to-grammar.o
@@ -1363,8 +1363,8 @@ llama-server: \
examples/server/httplib.h \
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat-handler.cpp \
common/chat-handler.hpp \
common/chat.cpp \
common/chat.hpp \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \
@@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-chat-handler: tests/test-chat-handler.cpp \
tests/test-chat: tests/test-chat.cpp \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)


@@ -56,8 +56,8 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-handler.cpp
chat-handler.hpp
chat.cpp
chat.hpp
chat-template.hpp
common.cpp
common.h


@@ -1,4 +1,4 @@
#include "chat-handler.hpp"
#include "chat.hpp"
#include "chat-template.hpp"
#include "json-schema-to-grammar.h"
#include "log.h"
@@ -170,11 +170,11 @@ static common_chat_msg no_op_text_parser(const std::string & input) {
};
}
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
static common_chat_params common_chat_params_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
auto tool_call_schemas = json::array();
foreach_function(params.tools, [&](const json & tool) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
auto tool_schema = json {
{"type", "object"},
@@ -190,7 +190,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
if (function.contains("description")) {
tool_schema["description"] = function["description"];
}
if (params.parallel_tool_calls) {
if (inputs.parallel_tool_calls) {
tool_schema["properties"]["id"] = {
{"type", "string"},
{"minLength", 4},
@@ -200,7 +200,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
tool_call_schemas.emplace_back(tool_schema);
});
const auto tool_call =
params.parallel_tool_calls
inputs.parallel_tool_calls
? json {
{"type", "object"},
{"properties", {
@@ -224,16 +224,16 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
{"required", json::array({"tool_call"})},
};
const auto schema =
params.tool_choice != "required"
inputs.tool_choice != "required"
? json {
{"anyOf", json::array({
tool_call,
{
{"type", "object"},
{"properties", {
{"response", params.json_schema.is_null()
{"response", inputs.json_schema.is_null()
? json {{"type", "string"}}
: params.json_schema
: inputs.json_schema
},
}},
{"required", json::array({"response"})},
@@ -248,10 +248,10 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
}, grammar_options);
auto tweaked_messages = common_chat_template::add_system(
params.messages,
inputs.messages,
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "generic tool calls";
data.parser = [&](const std::string & input) {
json data = json::parse(input);
@@ -280,12 +280,12 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
return data;
}
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(params.tools, [&](const json & tool) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
@@ -311,13 +311,13 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!params.parallel_tool_calls) {
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
}, grammar_options);
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "mistral nemo tool calls";
data.parser = [](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
@@ -344,10 +344,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
}
}
static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool allow_python_tag_builtin_tools) {
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
@@ -380,7 +380,7 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
};
auto has_function = false;
foreach_function(params.tools, [&](const json & tool) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
@@ -410,12 +410,12 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options);
data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
{"tools_in_user_message", false},
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
});
data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
data.parser = [params, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
data.parser = [inputs, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
static std::regex close_regex("\\}");
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
@@ -447,17 +447,17 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
};
}
}
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
};
return data;
}
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
@@ -466,27 +466,27 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
});
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "deepseek r1 tool calls";
data.parser = [params](const std::string & input) {
data.parser = [inputs](const std::string & input) {
static std::regex trigger_regex("<tool▁calls▁begin>");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```<tool▁call▁end>");
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
return parse_json_tool_calls(inputs.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
};
return data;
}
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
static common_chat_params common_chat_params_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
fprintf(stderr, "%s\n", __func__);
common_chat_data data;
if (!params.tools.is_null() && !params.tools.empty()) {
data.grammar_lazy = params.tool_choice != "required";
common_chat_params data;
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(params.tools, [&](const json & tool) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
@@ -505,7 +505,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!params.parallel_tool_calls) {
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
@@ -519,24 +519,24 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
data.parser = no_op_text_parser;
data.format = "firefunction v2 text-only";
}
data.prompt = tmpl.apply(params.messages, /* tools= */ nullptr, params.add_generation_prompt, {
data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(params.tools.empty() ? "" : params.tools.dump(2))},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
}, /* adjust_inputs= */ false);
return data;
}
static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_data data;
common_chat_params data;
if (!params.tools.is_null() && !params.tools.empty()) {
data.grammar_lazy = params.tool_choice != "required";
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
foreach_function(params.tools, [&](const json & tool) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
@@ -547,7 +547,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
});
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
if (params.parallel_tool_calls) {
if (inputs.parallel_tool_calls) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
} else {
@@ -560,12 +560,12 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
data.format = "functionary v3.2 content-only";
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = [params](const std::string & input) {
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.parser = [inputs](const std::string & input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
auto res = parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
if (res.content.find("all\n") == 0) {
res.content = res.content.substr(4);
}
@@ -574,17 +574,17 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
return data;
}
static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_data data;
json tools = params.tools.is_null() ? params.tools : json::array();
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar_lazy = params.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
const auto & parameters = function["parameters"];
std::string name = function["name"];
@@ -618,13 +618,13 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "functionary v3.1 llama 3.1 tool calls";
data.parser = [params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
data.parser = [inputs, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
@@ -644,18 +644,18 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
};
return data;
}
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar_lazy = params.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
@@ -670,11 +670,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
}));
});
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "hermes 2 pro tool calls";
data.parser = [&](const std::string & input) -> common_chat_msg {
try {
@@ -733,60 +733,60 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
return data;
}
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "content-only";
data.parser = no_op_text_parser;
data.grammar_lazy = false;
if (!params.json_schema.is_null()) {
if (!params.grammar.empty()) {
if (!inputs.json_schema.is_null()) {
if (!inputs.grammar.empty()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
data.grammar = json_schema_to_grammar(params.json_schema);
data.grammar = json_schema_to_grammar(inputs.json_schema);
} else {
data.grammar = params.grammar.empty();
data.grammar = inputs.grammar.empty();
}
return data;
}
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
auto has_tools = params.tools.is_null() || params.tool_choice == "none";
if (has_tools && !params.grammar.empty()) {
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
auto has_tools = inputs.tools.is_null() || inputs.tool_choice == "none";
if (has_tools && !inputs.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
const auto & src = tmpl.source();
if (src.find(">>>all") != std::string::npos) {
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
}
if (src.find(" functools[") != std::string::npos) {
// Firefunction v2 requires datetime and functions in the context
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
return common_chat_params_init_firefunction_v2_tool_call(tmpl, inputs);
}
if (has_tools) {
return common_chat_init_without_tools(tmpl, params);
return common_chat_params_init_without_tools(tmpl, inputs);
}
if (src.find("<tool_call>") != std::string::npos) {
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
return common_chat_params_init_hermes_2_pro_tool_call(tmpl, inputs);
}
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
return common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(tmpl, inputs);
}
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
}
if (src.find("<tool▁calls▁begin>") != std::string::npos) {
return common_chat_init_deepseek_r1_tool_call(tmpl, params);
return common_chat_params_init_deepseek_r1_tool_call(tmpl, inputs);
}
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_init_mistral_nemo_tool_call(tmpl, params);
return common_chat_params_init_mistral_nemo_tool_call(tmpl, inputs);
}
return common_chat_init_generic_tool_call(tmpl, params);
return common_chat_params_init_generic_tool_call(tmpl, inputs);
}


@@ -1,11 +1,5 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
@@ -16,7 +10,7 @@
using json = nlohmann::ordered_json;
struct common_chat_params {
struct common_chat_inputs {
json messages;
json tools;
json tool_choice;
@@ -29,7 +23,7 @@ struct common_chat_params {
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
struct common_chat_data {
struct common_chat_params {
json prompt;
std::string grammar;
std::vector<common_grammar_trigger> grammar_triggers;
@@ -39,4 +33,4 @@ struct common_chat_data {
bool grammar_lazy = false;
};
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);


@@ -12,7 +12,7 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat-handler.hpp"
#include "chat.hpp"
#include "chat-template.hpp"
#include <algorithm>
@@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
common_chat_params params;
common_chat_inputs params;
params.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
common_chat_init(chat_template, params);
common_chat_params_init(chat_template, params);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@@ -1803,10 +1803,10 @@ std::string common_chat_apply_template(
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
common_chat_params params;
common_chat_inputs params;
params.messages = messages;
params.add_generation_prompt = add_ass;
auto data = common_chat_init(tmpl, params);
auto data = common_chat_params_init(tmpl, params);
return data.prompt;
}


@@ -1824,16 +1824,16 @@ struct server_context {
if (use_jinja) {
auto templates = common_chat_templates_from_model(model, "");
common_chat_params params;
common_chat_inputs params;
params.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
GGML_ASSERT(templates.template_default);
try {
common_chat_init(*templates.template_default, params);
common_chat_params_init(*templates.template_default, params);
if (templates.template_tool_use) {
common_chat_init(*templates.template_tool_use, params);
common_chat_params_init(*templates.template_tool_use, params);
}
return true;
} catch (const std::exception & e) {
@@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
std::vector<server_task> tasks;
try {
common_chat_data chat_data;
common_chat_params chat_data;
bool add_special = false;
if (tmpl && ctx_server.params_base.use_jinja) {
chat_data = common_chat_init(*tmpl, {
chat_data = common_chat_params_init(*tmpl, {
/* .messages = */ json_value(data, "messages", json::array()),
/* .tools = */ json_value(data, "tools", json()),
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),


@@ -17,7 +17,7 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "minja.hpp"
#include "chat-handler.hpp"
#include "chat.hpp"
#include "chat-template.hpp"
#include <random>


@@ -93,7 +93,7 @@ if (NOT WIN32)
llama_target_and_test(test-grammar-parser.cpp)
llama_target_and_test(test-grammar-integration.cpp)
llama_target_and_test(test-llama-grammar.cpp)
llama_target_and_test(test-chat-handler.cpp)
llama_target_and_test(test-chat.cpp)
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)


@@ -1,4 +1,4 @@
#include "chat-handler.hpp"
#include "chat.hpp"
#include "chat-template.hpp"
#include "llama-grammar.h"
#include "unicode.h"
@@ -169,15 +169,15 @@ struct delta_data {
};
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
common_chat_params params;
common_chat_inputs params;
params.parallel_tool_calls = true;
params.messages = json::array();
params.messages.push_back(user_message);
params.tools = tools;
auto prefix_data = common_chat_init(tmpl, params);
auto prefix_data = common_chat_params_init(tmpl, params);
params.messages.push_back(delta_message);
params.add_generation_prompt = false;
auto full_data = common_chat_init(tmpl, params);
auto full_data = common_chat_params_init(tmpl, params);
std::string prefix = prefix_data.prompt;
std::string full = full_data.prompt;
@@ -220,7 +220,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
};
for (const auto & tool_choice : json({"auto", "required"})) {
common_chat_params params;
common_chat_inputs params;
params.tool_choice = tool_choice;
params.parallel_tool_calls = true;
params.messages = json {user_message, test_message};
@@ -301,15 +301,15 @@ static void test_template_output_parsers() {
};
common_chat_params no_tools_params;
common_chat_inputs no_tools_params;
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
common_chat_params tools_params = no_tools_params;
common_chat_inputs tools_params = no_tools_params;
tools_params.tools = json::array();
tools_params.tools.push_back(special_function_tool);
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
auto data = common_chat_init(tmpl, params);
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
auto data = common_chat_params_init(tmpl, params);
return data.format;
};
@@ -322,7 +322,7 @@ static void test_template_output_parsers() {
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
"{\n"
" \"response\": \"Hello, world!\"\n"
"}"));