Rename: common/chat.*, common_chat_{inputs -> params}
This commit is contained in:
parent
6e676c8030
commit
ed7c622d78
9 changed files with 117 additions and 123 deletions
10
Makefile
10
Makefile
|
@ -52,7 +52,7 @@ TEST_TARGETS = \
|
|||
tests/test-arg-parser \
|
||||
tests/test-autorelease \
|
||||
tests/test-backend-ops \
|
||||
tests/test-chat-handler \
|
||||
tests/test-chat \
|
||||
tests/test-chat-template \
|
||||
tests/test-double-float \
|
||||
tests/test-grammar-integration \
|
||||
|
@ -984,7 +984,7 @@ OBJ_COMMON = \
|
|||
$(DIR_COMMON)/ngram-cache.o \
|
||||
$(DIR_COMMON)/sampling.o \
|
||||
$(DIR_COMMON)/speculative.o \
|
||||
$(DIR_COMMON)/chat-handler.o \
|
||||
$(DIR_COMMON)/chat.o \
|
||||
$(DIR_COMMON)/build-info.o \
|
||||
$(DIR_COMMON)/json-schema-to-grammar.o
|
||||
|
||||
|
@ -1363,8 +1363,8 @@ llama-server: \
|
|||
examples/server/httplib.h \
|
||||
examples/server/index.html.hpp \
|
||||
examples/server/loading.html.hpp \
|
||||
common/chat-handler.cpp \
|
||||
common/chat-handler.hpp \
|
||||
common/chat.cpp \
|
||||
common/chat.hpp \
|
||||
common/chat-template.hpp \
|
||||
common/json.hpp \
|
||||
common/minja.hpp \
|
||||
|
@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
|
|||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-chat-handler: tests/test-chat-handler.cpp \
|
||||
tests/test-chat: tests/test-chat.cpp \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
|
|
@ -56,8 +56,8 @@ add_library(${TARGET} STATIC
|
|||
arg.cpp
|
||||
arg.h
|
||||
base64.hpp
|
||||
chat-handler.cpp
|
||||
chat-handler.hpp
|
||||
chat.cpp
|
||||
chat.hpp
|
||||
chat-template.hpp
|
||||
common.cpp
|
||||
common.h
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#include "chat-handler.hpp"
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
|
@ -170,11 +170,11 @@ static common_chat_msg no_op_text_parser(const std::string & input) {
|
|||
};
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
common_chat_data data;
|
||||
static common_chat_params common_chat_params_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
auto tool_call_schemas = json::array();
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
auto tool_schema = json {
|
||||
{"type", "object"},
|
||||
|
@ -190,7 +190,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
|||
if (function.contains("description")) {
|
||||
tool_schema["description"] = function["description"];
|
||||
}
|
||||
if (params.parallel_tool_calls) {
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_schema["properties"]["id"] = {
|
||||
{"type", "string"},
|
||||
{"minLength", 4},
|
||||
|
@ -200,7 +200,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
|||
tool_call_schemas.emplace_back(tool_schema);
|
||||
});
|
||||
const auto tool_call =
|
||||
params.parallel_tool_calls
|
||||
inputs.parallel_tool_calls
|
||||
? json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
|
@ -224,16 +224,16 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
|||
{"required", json::array({"tool_call"})},
|
||||
};
|
||||
const auto schema =
|
||||
params.tool_choice != "required"
|
||||
inputs.tool_choice != "required"
|
||||
? json {
|
||||
{"anyOf", json::array({
|
||||
tool_call,
|
||||
{
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"response", params.json_schema.is_null()
|
||||
{"response", inputs.json_schema.is_null()
|
||||
? json {{"type", "string"}}
|
||||
: params.json_schema
|
||||
: inputs.json_schema
|
||||
},
|
||||
}},
|
||||
{"required", json::array({"response"})},
|
||||
|
@ -248,10 +248,10 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
|||
}, grammar_options);
|
||||
|
||||
auto tweaked_messages = common_chat_template::add_system(
|
||||
params.messages,
|
||||
inputs.messages,
|
||||
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
|
||||
|
||||
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "generic tool calls";
|
||||
data.parser = [&](const std::string & input) {
|
||||
json data = json::parse(input);
|
||||
|
@ -280,12 +280,12 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
|||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
common_chat_data data;
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
|
@ -311,13 +311,13 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
|
|||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!params.parallel_tool_calls) {
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "mistral nemo tool calls";
|
||||
data.parser = [](const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||
|
@ -344,10 +344,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
|||
}
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool allow_python_tag_builtin_tools) {
|
||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
|
||||
auto builtin_tools = json::array();
|
||||
common_chat_data data;
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
|
@ -380,7 +380,7 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
|
|||
};
|
||||
|
||||
auto has_function = false;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
|
@ -410,12 +410,12 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
|
|||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
}, grammar_options);
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
|
||||
data.parser = [params, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
|
||||
data.parser = [inputs, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||
|
@ -447,17 +447,17 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
|
|||
};
|
||||
}
|
||||
}
|
||||
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
|
||||
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
|
||||
};
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
common_chat_data data;
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
|
@ -466,27 +466,27 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
|
|||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
|
||||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
||||
}, grammar_options);
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "deepseek r1 tool calls";
|
||||
data.parser = [params](const std::string & input) {
|
||||
data.parser = [inputs](const std::string & input) {
|
||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
|
||||
return parse_json_tool_calls(inputs.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
|
||||
};
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
static common_chat_params common_chat_params_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
fprintf(stderr, "%s\n", __func__);
|
||||
common_chat_data data;
|
||||
if (!params.tools.is_null() && !params.tools.empty()) {
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
common_chat_params data;
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
|
@ -505,7 +505,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
|
|||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!params.parallel_tool_calls) {
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||
|
@ -519,24 +519,24 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
|
|||
data.parser = no_op_text_parser;
|
||||
data.format = "firefunction v2 text-only";
|
||||
}
|
||||
data.prompt = tmpl.apply(params.messages, /* tools= */ nullptr, params.add_generation_prompt, {
|
||||
data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(params.tools.empty() ? "" : params.tools.dump(2))},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
}, /* adjust_inputs= */ false);
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
common_chat_data data;
|
||||
common_chat_params data;
|
||||
|
||||
if (!params.tools.is_null() && !params.tools.empty()) {
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
|
@ -547,7 +547,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
|||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||
});
|
||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
if (params.parallel_tool_calls) {
|
||||
if (inputs.parallel_tool_calls) {
|
||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||
} else {
|
||||
|
@ -560,12 +560,12 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
|||
data.format = "functionary v3.2 content-only";
|
||||
}
|
||||
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.parser = [params](const std::string & input) {
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.parser = [inputs](const std::string & input) {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
|
||||
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
|
||||
auto res = parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
|
||||
if (res.content.find("all\n") == 0) {
|
||||
res.content = res.content.substr(4);
|
||||
}
|
||||
|
@ -574,17 +574,17 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
|||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_data data;
|
||||
json tools = params.tools.is_null() ? params.tools : json::array();
|
||||
common_chat_params data;
|
||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
const auto & parameters = function["parameters"];
|
||||
std::string name = function["name"];
|
||||
|
@ -618,13 +618,13 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
|
|||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
|
||||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "functionary v3.1 llama 3.1 tool calls";
|
||||
data.parser = [params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
|
||||
data.parser = [inputs, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
std::smatch match;
|
||||
|
@ -644,18 +644,18 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
|
|||
}
|
||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static std::regex close_regex(R"(</function>)");
|
||||
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
|
||||
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
|
||||
};
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
common_chat_data data;
|
||||
static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
|
@ -670,11 +670,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
|||
}));
|
||||
});
|
||||
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
||||
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
|
||||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "hermes 2 pro tool calls";
|
||||
data.parser = [&](const std::string & input) -> common_chat_msg {
|
||||
try {
|
||||
|
@ -733,60 +733,60 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
|||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
common_chat_data data;
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "content-only";
|
||||
data.parser = no_op_text_parser;
|
||||
data.grammar_lazy = false;
|
||||
if (!params.json_schema.is_null()) {
|
||||
if (!params.grammar.empty()) {
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
if (!inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
data.grammar = json_schema_to_grammar(params.json_schema);
|
||||
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
||||
} else {
|
||||
data.grammar = params.grammar.empty();
|
||||
data.grammar = inputs.grammar.empty();
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
auto has_tools = params.tools.is_null() || params.tool_choice == "none";
|
||||
if (has_tools && !params.grammar.empty()) {
|
||||
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
auto has_tools = inputs.tools.is_null() || inputs.tool_choice == "none";
|
||||
if (has_tools && !inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
|
||||
const auto & src = tmpl.source();
|
||||
if (src.find(">>>all") != std::string::npos) {
|
||||
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
||||
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
|
||||
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
|
||||
}
|
||||
if (src.find(" functools[") != std::string::npos) {
|
||||
// Firefunction v2 requires datetime and functions in the context
|
||||
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
|
||||
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
|
||||
return common_chat_params_init_firefunction_v2_tool_call(tmpl, inputs);
|
||||
}
|
||||
|
||||
if (has_tools) {
|
||||
return common_chat_init_without_tools(tmpl, params);
|
||||
return common_chat_params_init_without_tools(tmpl, inputs);
|
||||
}
|
||||
|
||||
if (src.find("<tool_call>") != std::string::npos) {
|
||||
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
|
||||
return common_chat_params_init_hermes_2_pro_tool_call(tmpl, inputs);
|
||||
}
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(tmpl, inputs);
|
||||
}
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
||||
}
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
||||
return common_chat_init_deepseek_r1_tool_call(tmpl, params);
|
||||
return common_chat_params_init_deepseek_r1_tool_call(tmpl, inputs);
|
||||
}
|
||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||
return common_chat_init_mistral_nemo_tool_call(tmpl, params);
|
||||
return common_chat_params_init_mistral_nemo_tool_call(tmpl, inputs);
|
||||
}
|
||||
return common_chat_init_generic_tool_call(tmpl, params);
|
||||
return common_chat_params_init_generic_tool_call(tmpl, inputs);
|
||||
}
|
||||
|
|
@ -1,11 +1,5 @@
|
|||
/*
|
||||
Copyright 2024 Google LLC
|
||||
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
||||
|
||||
Use of this source code is governed by an MIT-style
|
||||
license that can be found in the LICENSE file or at
|
||||
https://opensource.org/licenses/MIT.
|
||||
*/
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
@ -16,7 +10,7 @@
|
|||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
struct common_chat_params {
|
||||
struct common_chat_inputs {
|
||||
json messages;
|
||||
json tools;
|
||||
json tool_choice;
|
||||
|
@ -29,7 +23,7 @@ struct common_chat_params {
|
|||
|
||||
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
|
||||
|
||||
struct common_chat_data {
|
||||
struct common_chat_params {
|
||||
json prompt;
|
||||
std::string grammar;
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
|
@ -39,4 +33,4 @@ struct common_chat_data {
|
|||
bool grammar_lazy = false;
|
||||
};
|
||||
|
||||
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
|
||||
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
|
|
@ -12,7 +12,7 @@
|
|||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "chat-handler.hpp"
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
|||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||
common_chat_params params;
|
||||
common_chat_inputs params;
|
||||
params.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
common_chat_init(chat_template, params);
|
||||
common_chat_params_init(chat_template, params);
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||
|
@ -1803,10 +1803,10 @@ std::string common_chat_apply_template(
|
|||
for (const auto & msg : msgs) {
|
||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||
}
|
||||
common_chat_params params;
|
||||
common_chat_inputs params;
|
||||
params.messages = messages;
|
||||
params.add_generation_prompt = add_ass;
|
||||
auto data = common_chat_init(tmpl, params);
|
||||
auto data = common_chat_params_init(tmpl, params);
|
||||
return data.prompt;
|
||||
}
|
||||
|
||||
|
|
|
@ -1824,16 +1824,16 @@ struct server_context {
|
|||
|
||||
if (use_jinja) {
|
||||
auto templates = common_chat_templates_from_model(model, "");
|
||||
common_chat_params params;
|
||||
common_chat_inputs params;
|
||||
params.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
GGML_ASSERT(templates.template_default);
|
||||
try {
|
||||
common_chat_init(*templates.template_default, params);
|
||||
common_chat_params_init(*templates.template_default, params);
|
||||
if (templates.template_tool_use) {
|
||||
common_chat_init(*templates.template_tool_use, params);
|
||||
common_chat_params_init(*templates.template_tool_use, params);
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
|
@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
|
|||
std::vector<server_task> tasks;
|
||||
|
||||
try {
|
||||
common_chat_data chat_data;
|
||||
common_chat_params chat_data;
|
||||
bool add_special = false;
|
||||
if (tmpl && ctx_server.params_base.use_jinja) {
|
||||
chat_data = common_chat_init(*tmpl, {
|
||||
chat_data = common_chat_params_init(*tmpl, {
|
||||
/* .messages = */ json_value(data, "messages", json::array()),
|
||||
/* .tools = */ json_value(data, "tools", json()),
|
||||
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "minja.hpp"
|
||||
#include "chat-handler.hpp"
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <random>
|
||||
|
|
|
@ -93,7 +93,7 @@ if (NOT WIN32)
|
|||
llama_target_and_test(test-grammar-parser.cpp)
|
||||
llama_target_and_test(test-grammar-integration.cpp)
|
||||
llama_target_and_test(test-llama-grammar.cpp)
|
||||
llama_target_and_test(test-chat-handler.cpp)
|
||||
llama_target_and_test(test-chat.cpp)
|
||||
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||
llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#include "chat-handler.hpp"
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
#include "llama-grammar.h"
|
||||
#include "unicode.h"
|
||||
|
@ -169,15 +169,15 @@ struct delta_data {
|
|||
};
|
||||
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||
common_chat_params params;
|
||||
common_chat_inputs params;
|
||||
params.parallel_tool_calls = true;
|
||||
params.messages = json::array();
|
||||
params.messages.push_back(user_message);
|
||||
params.tools = tools;
|
||||
auto prefix_data = common_chat_init(tmpl, params);
|
||||
auto prefix_data = common_chat_params_init(tmpl, params);
|
||||
params.messages.push_back(delta_message);
|
||||
params.add_generation_prompt = false;
|
||||
auto full_data = common_chat_init(tmpl, params);
|
||||
auto full_data = common_chat_params_init(tmpl, params);
|
||||
|
||||
std::string prefix = prefix_data.prompt;
|
||||
std::string full = full_data.prompt;
|
||||
|
@ -220,7 +220,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
};
|
||||
|
||||
for (const auto & tool_choice : json({"auto", "required"})) {
|
||||
common_chat_params params;
|
||||
common_chat_inputs params;
|
||||
params.tool_choice = tool_choice;
|
||||
params.parallel_tool_calls = true;
|
||||
params.messages = json {user_message, test_message};
|
||||
|
@ -301,15 +301,15 @@ static void test_template_output_parsers() {
|
|||
};
|
||||
|
||||
|
||||
common_chat_params no_tools_params;
|
||||
common_chat_inputs no_tools_params;
|
||||
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
|
||||
|
||||
common_chat_params tools_params = no_tools_params;
|
||||
common_chat_inputs tools_params = no_tools_params;
|
||||
tools_params.tools = json::array();
|
||||
tools_params.tools.push_back(special_function_tool);
|
||||
|
||||
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
|
||||
auto data = common_chat_init(tmpl, params);
|
||||
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
|
||||
auto data = common_chat_params_init(tmpl, params);
|
||||
return data.format;
|
||||
};
|
||||
|
||||
|
@ -322,7 +322,7 @@ static void test_template_output_parsers() {
|
|||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
|
||||
"{\n"
|
||||
" \"response\": \"Hello, world!\"\n"
|
||||
"}"));
|
Loading…
Add table
Add a link
Reference in a new issue