Rename: common/chat.*, common_chat_{inputs -> params}
This commit is contained in:
parent
6e676c8030
commit
ed7c622d78
9 changed files with 117 additions and 123 deletions
10
Makefile
10
Makefile
|
@ -52,7 +52,7 @@ TEST_TARGETS = \
|
||||||
tests/test-arg-parser \
|
tests/test-arg-parser \
|
||||||
tests/test-autorelease \
|
tests/test-autorelease \
|
||||||
tests/test-backend-ops \
|
tests/test-backend-ops \
|
||||||
tests/test-chat-handler \
|
tests/test-chat \
|
||||||
tests/test-chat-template \
|
tests/test-chat-template \
|
||||||
tests/test-double-float \
|
tests/test-double-float \
|
||||||
tests/test-grammar-integration \
|
tests/test-grammar-integration \
|
||||||
|
@ -984,7 +984,7 @@ OBJ_COMMON = \
|
||||||
$(DIR_COMMON)/ngram-cache.o \
|
$(DIR_COMMON)/ngram-cache.o \
|
||||||
$(DIR_COMMON)/sampling.o \
|
$(DIR_COMMON)/sampling.o \
|
||||||
$(DIR_COMMON)/speculative.o \
|
$(DIR_COMMON)/speculative.o \
|
||||||
$(DIR_COMMON)/chat-handler.o \
|
$(DIR_COMMON)/chat.o \
|
||||||
$(DIR_COMMON)/build-info.o \
|
$(DIR_COMMON)/build-info.o \
|
||||||
$(DIR_COMMON)/json-schema-to-grammar.o
|
$(DIR_COMMON)/json-schema-to-grammar.o
|
||||||
|
|
||||||
|
@ -1363,8 +1363,8 @@ llama-server: \
|
||||||
examples/server/httplib.h \
|
examples/server/httplib.h \
|
||||||
examples/server/index.html.hpp \
|
examples/server/index.html.hpp \
|
||||||
examples/server/loading.html.hpp \
|
examples/server/loading.html.hpp \
|
||||||
common/chat-handler.cpp \
|
common/chat.cpp \
|
||||||
common/chat-handler.hpp \
|
common/chat.hpp \
|
||||||
common/chat-template.hpp \
|
common/chat-template.hpp \
|
||||||
common/json.hpp \
|
common/json.hpp \
|
||||||
common/minja.hpp \
|
common/minja.hpp \
|
||||||
|
@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
|
||||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
tests/test-chat-handler: tests/test-chat-handler.cpp \
|
tests/test-chat: tests/test-chat.cpp \
|
||||||
$(OBJ_ALL)
|
$(OBJ_ALL)
|
||||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||||
|
|
|
@ -56,8 +56,8 @@ add_library(${TARGET} STATIC
|
||||||
arg.cpp
|
arg.cpp
|
||||||
arg.h
|
arg.h
|
||||||
base64.hpp
|
base64.hpp
|
||||||
chat-handler.cpp
|
chat.cpp
|
||||||
chat-handler.hpp
|
chat.hpp
|
||||||
chat-template.hpp
|
chat-template.hpp
|
||||||
common.cpp
|
common.cpp
|
||||||
common.h
|
common.h
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
#include "chat-handler.hpp"
|
#include "chat.hpp"
|
||||||
#include "chat-template.hpp"
|
#include "chat-template.hpp"
|
||||||
#include "json-schema-to-grammar.h"
|
#include "json-schema-to-grammar.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
|
@ -170,11 +170,11 @@ static common_chat_msg no_op_text_parser(const std::string & input) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_params common_chat_params_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
|
|
||||||
auto tool_call_schemas = json::array();
|
auto tool_call_schemas = json::array();
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
auto tool_schema = json {
|
auto tool_schema = json {
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -190,7 +190,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
||||||
if (function.contains("description")) {
|
if (function.contains("description")) {
|
||||||
tool_schema["description"] = function["description"];
|
tool_schema["description"] = function["description"];
|
||||||
}
|
}
|
||||||
if (params.parallel_tool_calls) {
|
if (inputs.parallel_tool_calls) {
|
||||||
tool_schema["properties"]["id"] = {
|
tool_schema["properties"]["id"] = {
|
||||||
{"type", "string"},
|
{"type", "string"},
|
||||||
{"minLength", 4},
|
{"minLength", 4},
|
||||||
|
@ -200,7 +200,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
||||||
tool_call_schemas.emplace_back(tool_schema);
|
tool_call_schemas.emplace_back(tool_schema);
|
||||||
});
|
});
|
||||||
const auto tool_call =
|
const auto tool_call =
|
||||||
params.parallel_tool_calls
|
inputs.parallel_tool_calls
|
||||||
? json {
|
? json {
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
{"properties", {
|
{"properties", {
|
||||||
|
@ -224,16 +224,16 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
||||||
{"required", json::array({"tool_call"})},
|
{"required", json::array({"tool_call"})},
|
||||||
};
|
};
|
||||||
const auto schema =
|
const auto schema =
|
||||||
params.tool_choice != "required"
|
inputs.tool_choice != "required"
|
||||||
? json {
|
? json {
|
||||||
{"anyOf", json::array({
|
{"anyOf", json::array({
|
||||||
tool_call,
|
tool_call,
|
||||||
{
|
{
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
{"properties", {
|
{"properties", {
|
||||||
{"response", params.json_schema.is_null()
|
{"response", inputs.json_schema.is_null()
|
||||||
? json {{"type", "string"}}
|
? json {{"type", "string"}}
|
||||||
: params.json_schema
|
: inputs.json_schema
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
{"required", json::array({"response"})},
|
{"required", json::array({"response"})},
|
||||||
|
@ -248,10 +248,10 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
|
||||||
auto tweaked_messages = common_chat_template::add_system(
|
auto tweaked_messages = common_chat_template::add_system(
|
||||||
params.messages,
|
inputs.messages,
|
||||||
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
|
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
|
||||||
|
|
||||||
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
data.format = "generic tool calls";
|
data.format = "generic tool calls";
|
||||||
data.parser = [&](const std::string & input) {
|
data.parser = [&](const std::string & input) {
|
||||||
json data = json::parse(input);
|
json data = json::parse(input);
|
||||||
|
@ -280,12 +280,12 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = inputs.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
auto schemas = json::array();
|
auto schemas = json::array();
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
schemas.push_back({
|
schemas.push_back({
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -311,13 +311,13 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
|
||||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||||
{"minItems", 1},
|
{"minItems", 1},
|
||||||
};
|
};
|
||||||
if (!params.parallel_tool_calls) {
|
if (!inputs.parallel_tool_calls) {
|
||||||
schema["maxItems"] = 1;
|
schema["maxItems"] = 1;
|
||||||
}
|
}
|
||||||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
data.format = "mistral nemo tool calls";
|
data.format = "mistral nemo tool calls";
|
||||||
data.parser = [](const std::string & input) {
|
data.parser = [](const std::string & input) {
|
||||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||||
|
@ -344,10 +344,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool allow_python_tag_builtin_tools) {
|
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
|
||||||
auto builtin_tools = json::array();
|
auto builtin_tools = json::array();
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = inputs.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
|
@ -380,7 +380,7 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
|
||||||
};
|
};
|
||||||
|
|
||||||
auto has_function = false;
|
auto has_function = false;
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
auto parameters = function["parameters"];
|
auto parameters = function["parameters"];
|
||||||
|
@ -410,12 +410,12 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
|
||||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
data.additional_stops.push_back("<|eom_id|>");
|
data.additional_stops.push_back("<|eom_id|>");
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
|
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||||
{"tools_in_user_message", false},
|
{"tools_in_user_message", false},
|
||||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||||
});
|
});
|
||||||
data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
|
data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
|
||||||
data.parser = [params, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
|
data.parser = [inputs, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
|
||||||
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||||
static std::regex close_regex("\\}");
|
static std::regex close_regex("\\}");
|
||||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||||
|
@ -447,17 +447,17 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
|
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
|
||||||
};
|
};
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = inputs.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
auto parameters = function["parameters"];
|
auto parameters = function["parameters"];
|
||||||
|
@ -466,27 +466,27 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
|
||||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||||
});
|
});
|
||||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
|
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
data.format = "deepseek r1 tool calls";
|
data.format = "deepseek r1 tool calls";
|
||||||
data.parser = [params](const std::string & input) {
|
data.parser = [inputs](const std::string & input) {
|
||||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||||
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
|
return parse_json_tool_calls(inputs.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
|
||||||
};
|
};
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_params common_chat_params_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
fprintf(stderr, "%s\n", __func__);
|
fprintf(stderr, "%s\n", __func__);
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
if (!params.tools.is_null() && !params.tools.empty()) {
|
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = inputs.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
auto schemas = json::array();
|
auto schemas = json::array();
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
schemas.push_back({
|
schemas.push_back({
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -505,7 +505,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
|
||||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||||
{"minItems", 1},
|
{"minItems", 1},
|
||||||
};
|
};
|
||||||
if (!params.parallel_tool_calls) {
|
if (!inputs.parallel_tool_calls) {
|
||||||
schema["maxItems"] = 1;
|
schema["maxItems"] = 1;
|
||||||
}
|
}
|
||||||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||||
|
@ -519,24 +519,24 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
|
||||||
data.parser = no_op_text_parser;
|
data.parser = no_op_text_parser;
|
||||||
data.format = "firefunction v2 text-only";
|
data.format = "firefunction v2 text-only";
|
||||||
}
|
}
|
||||||
data.prompt = tmpl.apply(params.messages, /* tools= */ nullptr, params.add_generation_prompt, {
|
data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||||
{"functions", json(params.tools.empty() ? "" : params.tools.dump(2))},
|
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||||
}, /* adjust_inputs= */ false);
|
}, /* adjust_inputs= */ false);
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
|
|
||||||
if (!params.tools.is_null() && !params.tools.empty()) {
|
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = inputs.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> first_tool_rules;
|
std::vector<std::string> first_tool_rules;
|
||||||
std::vector<std::string> subsequent_tool_rules;
|
std::vector<std::string> subsequent_tool_rules;
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
auto parameters = function["parameters"];
|
auto parameters = function["parameters"];
|
||||||
|
@ -547,7 +547,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
||||||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||||
});
|
});
|
||||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||||
if (params.parallel_tool_calls) {
|
if (inputs.parallel_tool_calls) {
|
||||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||||
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||||
} else {
|
} else {
|
||||||
|
@ -560,12 +560,12 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
||||||
data.format = "functionary v3.2 content-only";
|
data.format = "functionary v3.2 content-only";
|
||||||
}
|
}
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
data.parser = [params](const std::string & input) {
|
data.parser = [inputs](const std::string & input) {
|
||||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||||
static std::regex close_regex(R"($|(?=>>>))");
|
static std::regex close_regex(R"($|(?=>>>))");
|
||||||
|
|
||||||
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
|
auto res = parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
|
||||||
if (res.content.find("all\n") == 0) {
|
if (res.content.find("all\n") == 0) {
|
||||||
res.content = res.content.substr(4);
|
res.content = res.content.substr(4);
|
||||||
}
|
}
|
||||||
|
@ -574,17 +574,17 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
json tools = params.tools.is_null() ? params.tools : json::array();
|
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||||
std::string python_code_argument_name;
|
std::string python_code_argument_name;
|
||||||
auto has_raw_python = false;
|
auto has_raw_python = false;
|
||||||
|
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = inputs.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
const auto & parameters = function["parameters"];
|
const auto & parameters = function["parameters"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
|
@ -618,13 +618,13 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
|
||||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||||
}
|
}
|
||||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||||
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
|
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
data.format = "functionary v3.1 llama 3.1 tool calls";
|
data.format = "functionary v3.1 llama 3.1 tool calls";
|
||||||
data.parser = [params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
|
data.parser = [inputs, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
|
||||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
|
@ -644,18 +644,18 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
|
||||||
}
|
}
|
||||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||||
static std::regex close_regex(R"(</function>)");
|
static std::regex close_regex(R"(</function>)");
|
||||||
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
|
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
|
||||||
};
|
};
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = inputs.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
auto parameters = function["parameters"];
|
auto parameters = function["parameters"];
|
||||||
|
@ -670,11 +670,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
||||||
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
|
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
data.format = "hermes 2 pro tool calls";
|
data.format = "hermes 2 pro tool calls";
|
||||||
data.parser = [&](const std::string & input) -> common_chat_msg {
|
data.parser = [&](const std::string & input) -> common_chat_msg {
|
||||||
try {
|
try {
|
||||||
|
@ -733,60 +733,60 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
common_chat_data data;
|
common_chat_params data;
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
data.format = "content-only";
|
data.format = "content-only";
|
||||||
data.parser = no_op_text_parser;
|
data.parser = no_op_text_parser;
|
||||||
data.grammar_lazy = false;
|
data.grammar_lazy = false;
|
||||||
if (!params.json_schema.is_null()) {
|
if (!inputs.json_schema.is_null()) {
|
||||||
if (!params.grammar.empty()) {
|
if (!inputs.grammar.empty()) {
|
||||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||||
}
|
}
|
||||||
data.grammar = json_schema_to_grammar(params.json_schema);
|
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
||||||
} else {
|
} else {
|
||||||
data.grammar = params.grammar.empty();
|
data.grammar = inputs.grammar.empty();
|
||||||
}
|
}
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
auto has_tools = params.tools.is_null() || params.tool_choice == "none";
|
auto has_tools = inputs.tools.is_null() || inputs.tool_choice == "none";
|
||||||
if (has_tools && !params.grammar.empty()) {
|
if (has_tools && !inputs.grammar.empty()) {
|
||||||
throw std::runtime_error("Cannot specify grammar with tools");
|
throw std::runtime_error("Cannot specify grammar with tools");
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & src = tmpl.source();
|
const auto & src = tmpl.source();
|
||||||
if (src.find(">>>all") != std::string::npos) {
|
if (src.find(">>>all") != std::string::npos) {
|
||||||
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
||||||
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
|
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
|
||||||
}
|
}
|
||||||
if (src.find(" functools[") != std::string::npos) {
|
if (src.find(" functools[") != std::string::npos) {
|
||||||
// Firefunction v2 requires datetime and functions in the context
|
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
|
||||||
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
|
return common_chat_params_init_firefunction_v2_tool_call(tmpl, inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (has_tools) {
|
if (has_tools) {
|
||||||
return common_chat_init_without_tools(tmpl, params);
|
return common_chat_params_init_without_tools(tmpl, inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src.find("<tool_call>") != std::string::npos) {
|
if (src.find("<tool_call>") != std::string::npos) {
|
||||||
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
|
return common_chat_params_init_hermes_2_pro_tool_call(tmpl, inputs);
|
||||||
}
|
}
|
||||||
if (src.find("<|start_header_id|>") != std::string::npos
|
if (src.find("<|start_header_id|>") != std::string::npos
|
||||||
&& src.find("<function=") != std::string::npos) {
|
&& src.find("<function=") != std::string::npos) {
|
||||||
return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
|
return common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(tmpl, inputs);
|
||||||
}
|
}
|
||||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||||
return common_chat_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
||||||
}
|
}
|
||||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
||||||
return common_chat_init_deepseek_r1_tool_call(tmpl, params);
|
return common_chat_params_init_deepseek_r1_tool_call(tmpl, inputs);
|
||||||
}
|
}
|
||||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||||
return common_chat_init_mistral_nemo_tool_call(tmpl, params);
|
return common_chat_params_init_mistral_nemo_tool_call(tmpl, inputs);
|
||||||
}
|
}
|
||||||
return common_chat_init_generic_tool_call(tmpl, params);
|
return common_chat_params_init_generic_tool_call(tmpl, inputs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,5 @@
|
||||||
/*
|
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
||||||
Copyright 2024 Google LLC
|
|
||||||
|
|
||||||
Use of this source code is governed by an MIT-style
|
|
||||||
license that can be found in the LICENSE file or at
|
|
||||||
https://opensource.org/licenses/MIT.
|
|
||||||
*/
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
@ -16,7 +10,7 @@
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
struct common_chat_params {
|
struct common_chat_inputs {
|
||||||
json messages;
|
json messages;
|
||||||
json tools;
|
json tools;
|
||||||
json tool_choice;
|
json tool_choice;
|
||||||
|
@ -29,7 +23,7 @@ struct common_chat_params {
|
||||||
|
|
||||||
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
|
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
|
||||||
|
|
||||||
struct common_chat_data {
|
struct common_chat_params {
|
||||||
json prompt;
|
json prompt;
|
||||||
std::string grammar;
|
std::string grammar;
|
||||||
std::vector<common_grammar_trigger> grammar_triggers;
|
std::vector<common_grammar_trigger> grammar_triggers;
|
||||||
|
@ -39,4 +33,4 @@ struct common_chat_data {
|
||||||
bool grammar_lazy = false;
|
bool grammar_lazy = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
|
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
|
|
@ -12,7 +12,7 @@
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
#include "json-schema-to-grammar.h"
|
#include "json-schema-to-grammar.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "chat-handler.hpp"
|
#include "chat.hpp"
|
||||||
#include "chat-template.hpp"
|
#include "chat-template.hpp"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
try {
|
try {
|
||||||
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||||
common_chat_params params;
|
common_chat_inputs params;
|
||||||
params.messages = json::array({{
|
params.messages = json::array({{
|
||||||
{"role", "user"},
|
{"role", "user"},
|
||||||
{"content", "test"},
|
{"content", "test"},
|
||||||
}});
|
}});
|
||||||
common_chat_init(chat_template, params);
|
common_chat_params_init(chat_template, params);
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception & e) {
|
} catch (const std::exception & e) {
|
||||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||||
|
@ -1803,10 +1803,10 @@ std::string common_chat_apply_template(
|
||||||
for (const auto & msg : msgs) {
|
for (const auto & msg : msgs) {
|
||||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||||
}
|
}
|
||||||
common_chat_params params;
|
common_chat_inputs params;
|
||||||
params.messages = messages;
|
params.messages = messages;
|
||||||
params.add_generation_prompt = add_ass;
|
params.add_generation_prompt = add_ass;
|
||||||
auto data = common_chat_init(tmpl, params);
|
auto data = common_chat_params_init(tmpl, params);
|
||||||
return data.prompt;
|
return data.prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1824,16 +1824,16 @@ struct server_context {
|
||||||
|
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
auto templates = common_chat_templates_from_model(model, "");
|
auto templates = common_chat_templates_from_model(model, "");
|
||||||
common_chat_params params;
|
common_chat_inputs params;
|
||||||
params.messages = json::array({{
|
params.messages = json::array({{
|
||||||
{"role", "user"},
|
{"role", "user"},
|
||||||
{"content", "test"},
|
{"content", "test"},
|
||||||
}});
|
}});
|
||||||
GGML_ASSERT(templates.template_default);
|
GGML_ASSERT(templates.template_default);
|
||||||
try {
|
try {
|
||||||
common_chat_init(*templates.template_default, params);
|
common_chat_params_init(*templates.template_default, params);
|
||||||
if (templates.template_tool_use) {
|
if (templates.template_tool_use) {
|
||||||
common_chat_init(*templates.template_tool_use, params);
|
common_chat_params_init(*templates.template_tool_use, params);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception & e) {
|
} catch (const std::exception & e) {
|
||||||
|
@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
common_chat_data chat_data;
|
common_chat_params chat_data;
|
||||||
bool add_special = false;
|
bool add_special = false;
|
||||||
if (tmpl && ctx_server.params_base.use_jinja) {
|
if (tmpl && ctx_server.params_base.use_jinja) {
|
||||||
chat_data = common_chat_init(*tmpl, {
|
chat_data = common_chat_params_init(*tmpl, {
|
||||||
/* .messages = */ json_value(data, "messages", json::array()),
|
/* .messages = */ json_value(data, "messages", json::array()),
|
||||||
/* .tools = */ json_value(data, "tools", json()),
|
/* .tools = */ json_value(data, "tools", json()),
|
||||||
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#define JSON_ASSERT GGML_ASSERT
|
#define JSON_ASSERT GGML_ASSERT
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
#include "minja.hpp"
|
#include "minja.hpp"
|
||||||
#include "chat-handler.hpp"
|
#include "chat.hpp"
|
||||||
#include "chat-template.hpp"
|
#include "chat-template.hpp"
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
|
@ -93,7 +93,7 @@ if (NOT WIN32)
|
||||||
llama_target_and_test(test-grammar-parser.cpp)
|
llama_target_and_test(test-grammar-parser.cpp)
|
||||||
llama_target_and_test(test-grammar-integration.cpp)
|
llama_target_and_test(test-grammar-integration.cpp)
|
||||||
llama_target_and_test(test-llama-grammar.cpp)
|
llama_target_and_test(test-llama-grammar.cpp)
|
||||||
llama_target_and_test(test-chat-handler.cpp)
|
llama_target_and_test(test-chat.cpp)
|
||||||
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
|
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
|
||||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||||
llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
#include "chat-handler.hpp"
|
#include "chat.hpp"
|
||||||
#include "chat-template.hpp"
|
#include "chat-template.hpp"
|
||||||
#include "llama-grammar.h"
|
#include "llama-grammar.h"
|
||||||
#include "unicode.h"
|
#include "unicode.h"
|
||||||
|
@ -169,15 +169,15 @@ struct delta_data {
|
||||||
};
|
};
|
||||||
|
|
||||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||||
common_chat_params params;
|
common_chat_inputs params;
|
||||||
params.parallel_tool_calls = true;
|
params.parallel_tool_calls = true;
|
||||||
params.messages = json::array();
|
params.messages = json::array();
|
||||||
params.messages.push_back(user_message);
|
params.messages.push_back(user_message);
|
||||||
params.tools = tools;
|
params.tools = tools;
|
||||||
auto prefix_data = common_chat_init(tmpl, params);
|
auto prefix_data = common_chat_params_init(tmpl, params);
|
||||||
params.messages.push_back(delta_message);
|
params.messages.push_back(delta_message);
|
||||||
params.add_generation_prompt = false;
|
params.add_generation_prompt = false;
|
||||||
auto full_data = common_chat_init(tmpl, params);
|
auto full_data = common_chat_params_init(tmpl, params);
|
||||||
|
|
||||||
std::string prefix = prefix_data.prompt;
|
std::string prefix = prefix_data.prompt;
|
||||||
std::string full = full_data.prompt;
|
std::string full = full_data.prompt;
|
||||||
|
@ -220,7 +220,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
};
|
};
|
||||||
|
|
||||||
for (const auto & tool_choice : json({"auto", "required"})) {
|
for (const auto & tool_choice : json({"auto", "required"})) {
|
||||||
common_chat_params params;
|
common_chat_inputs params;
|
||||||
params.tool_choice = tool_choice;
|
params.tool_choice = tool_choice;
|
||||||
params.parallel_tool_calls = true;
|
params.parallel_tool_calls = true;
|
||||||
params.messages = json {user_message, test_message};
|
params.messages = json {user_message, test_message};
|
||||||
|
@ -301,15 +301,15 @@ static void test_template_output_parsers() {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
common_chat_params no_tools_params;
|
common_chat_inputs no_tools_params;
|
||||||
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
|
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
|
||||||
|
|
||||||
common_chat_params tools_params = no_tools_params;
|
common_chat_inputs tools_params = no_tools_params;
|
||||||
tools_params.tools = json::array();
|
tools_params.tools = json::array();
|
||||||
tools_params.tools.push_back(special_function_tool);
|
tools_params.tools.push_back(special_function_tool);
|
||||||
|
|
||||||
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
|
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
|
||||||
auto data = common_chat_init(tmpl, params);
|
auto data = common_chat_params_init(tmpl, params);
|
||||||
return data.format;
|
return data.format;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -322,7 +322,7 @@ static void test_template_output_parsers() {
|
||||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
|
||||||
"{\n"
|
"{\n"
|
||||||
" \"response\": \"Hello, world!\"\n"
|
" \"response\": \"Hello, world!\"\n"
|
||||||
"}"));
|
"}"));
|
Loading…
Add table
Add a link
Reference in a new issue