Simplify parser defs (incremental parsing for streaming will need more thinking)
parent ec4aeaf18a
commit b5a74d1a24
4 changed files with 34 additions and 97 deletions
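In short: the virtual common_chat_parser class hierarchy (text_chat_parser / monolithic_chat_parser) goes away in favour of a plain std::function alias, each chat handler assigns its lambda to data.parser directly, and the server's partial-response path no longer calls parse_partial (it falls back to the raw token text). A condensed sketch of the resulting shape, pieced together from the hunks below — the surrounding handler and server context is elided:

```cpp
// new alias: a parser is just a function over the accumulated model output
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;

// handlers assign a lambda directly ...
data.parser = [](const std::string & input) {
    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
};

// ... and callers invoke it like any other callable
if (slot.params.chat_parser) {
    res->oaicompat_chat_msg = slot.params.chat_parser(slot.generated_text);
}
```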
@@ -152,50 +152,6 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
     return result;
 }
 
-class text_chat_parser : public common_chat_parser {
-  public:
-    std::optional<common_chat_msg> parse_partial(const std::string & input) override {
-        return parse_final(input);
-    }
-
-    common_chat_msg parse_final(const std::string & input) override {
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ input,
-            /* .tool_calls = */ {},
-        };
-    }
-
-    std::unique_ptr<common_chat_parser> clone() const override {
-        return std::make_unique<text_chat_parser>();
-    }
-};
-
-class monolithic_chat_parser : public common_chat_parser {
-
-    std::string input_buffer_;
-    std::function<common_chat_msg(const std::string & input)> parse_final_;
-
-  public:
-    monolithic_chat_parser(const std::function<common_chat_msg(const std::string & input)> & parse_final) : parse_final_(parse_final) {}
-
-    std::optional<common_chat_msg> parse_partial(const std::string & input) override {
-        input_buffer_ += input;
-        return std::nullopt;
-    }
-
-    common_chat_msg parse_final(const std::string & input) override {
-        input_buffer_ += input;
-        auto out = parse_final_(input_buffer_);
-        input_buffer_.clear();
-        return out;
-    }
-
-    std::unique_ptr<common_chat_parser> clone() const override {
-        return std::make_unique<monolithic_chat_parser>(parse_final_);
-    }
-};
-
 static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
     for (const auto & tool : tools) {
         if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
@@ -289,7 +245,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
 
     data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "generic tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
+    data.parser = [&](const std::string & input) {
         json data = json::parse(input);
         common_chat_msg result;
         result.role = "assistant";
@@ -312,7 +268,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
             result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
         }
         return result;
-    });
+    };
     return data;
 }
 
@@ -355,9 +311,9 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
     data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "mistral nemo tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
+    data.parser = [](const std::string & input) {
         return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
-    });
+    };
     return data;
 }
 
@@ -441,7 +397,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
     data.format = "llama 3.1 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.parser = [params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("\\{\"name\": \"([^\"]+)\", \"parameters\": ");
         static std::regex close_regex("\\}");
         static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
@@ -472,7 +428,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
             };
         }
         return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
-    });
+    };
     return data;
 }
 
@@ -505,12 +461,12 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
     data.additional_stops.push_back("<|eom_id|>");
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
     data.format = "llama 3.2 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
+    data.parser = [params](const std::string & input) {
         static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
         static std::regex close_regex("\\}");
         auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
         return res;
-    });
+    };
     return data;
 }
 
@@ -532,12 +488,12 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
     }, grammar_options);
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "deepseek r1 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
+    data.parser = [params](const std::string & input) {
         static std::regex trigger_regex("<|tool▁calls▁begin|>");
         static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
         static std::regex close_regex("```<|tool▁call▁end|>");
         return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
-    });
+    };
     return data;
 }
 
@@ -573,9 +529,9 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
     data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "firefunction v2 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
+    data.parser = [](const std::string & input) {
         return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
-    });
+    };
     return data;
 }
 
@@ -610,7 +566,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "functionary v3.2 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
+    data.parser = [params](const std::string & input) {
         static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
         static std::regex close_regex(R"($|(?=>>>))");
 
@@ -619,7 +575,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
             res.content = res.content.substr(4);
         }
         return res;
-    });
+    };
     return data;
 }
 
@@ -674,7 +630,7 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "functionary v3.1 llama 3.1 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
+    data.parser = [params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
         // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
         std::smatch match;
@@ -695,7 +651,7 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
         static std::regex function_regex(R"(<function=(\w+)>)");
         static std::regex close_regex(R"(</function>)");
         return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
-    });
+    };
     return data;
 }
 
@@ -726,7 +682,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "hermes 2 pro tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
+    data.parser = [&](const std::string & input) -> common_chat_msg {
         try {
             std::regex start_pattern(R"([\n\s]*<tool_call>)");
             std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -779,7 +735,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
                 /* .tool_calls = */ {},
             };
         }
-    });
+    };
     return data;
 }
 
@@ -787,7 +743,13 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
     common_chat_data data;
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "content-only";
-    data.parser = std::make_unique<text_chat_parser>();
+    data.parser = [](const std::string & input) -> common_chat_msg {
+        return {
+            /* .role = */ "assistant",
+            /* .content = */ input,
+            /* .tool_calls = */ {},
+        };
+    };
     data.grammar_lazy = false;
     if (!params.json_schema.is_null()) {
         if (!params.grammar.empty()) {

@@ -27,21 +27,14 @@ struct common_chat_params {
     bool add_generation_prompt = true;
 };
 
-class common_chat_parser {
-  public:
-    virtual ~common_chat_parser() = default;
-
-    virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
-    virtual common_chat_msg parse_final(const std::string & input) = 0;
-    virtual std::unique_ptr<common_chat_parser> clone() const = 0;
-};
+typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
 
 struct common_chat_data {
     json prompt;
     std::string grammar;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string> additional_stops;
-    std::unique_ptr<class common_chat_parser> parser;
+    common_chat_parser parser;
     std::string format; // For debugging and testing.
     bool grammar_lazy = false;
 };

@@ -117,7 +117,7 @@ struct slot_params {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
-    std::shared_ptr<common_chat_parser> chat_parser;
+    common_chat_parser chat_parser;
 
     json to_json() const {
         std::vector<std::string> samplers;
@@ -768,7 +768,6 @@ struct server_task_result_cmpl_partial : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
     common_chat_msg oaicompat_chat_msg;
-    std::shared_ptr<common_chat_parser> chat_parser;
 
     virtual int get_index() override {
@@ -2220,16 +2219,6 @@ struct server_context {
     }
 
     void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
-        common_chat_msg msg;
-        if (slot.params.chat_parser) {
-            if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
-                msg = *opt_msg;
-            } else {
-                return;
-            }
-        } else {
-            msg.content = tkn.text_to_send;
-        }
         auto res = std::make_unique<server_task_result_cmpl_partial>();
 
         res->id = slot.id_task;
@@ -2245,7 +2234,6 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = msg;
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -2286,18 +2274,14 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        if (!slot.params.chat_parser) {
+        if (slot.params.chat_parser) {
+            res->oaicompat_chat_msg = slot.params.chat_parser(slot.generated_text);
+        } else {
             res->oaicompat_chat_msg = {
                 /* .role = */ "assistant",
                 /* .content = */ slot.generated_text,
                 /* .tool_calls = */ {}
             };
-        } else if (slot.stop == STOP_TYPE_LIMIT) {
-            if (auto opt_msg = slot.params.chat_parser->parse_partial(slot.generated_text)) {
-                res->oaicompat_chat_msg = *opt_msg;
-            }
-        } else {
-            res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
         }
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -3835,9 +3819,7 @@ int main(int argc, char ** argv) {
                 task.params.sampling.grammar_trigger_words.push_back(trigger);
             }
             task.params.antiprompt = chat_data.additional_stops;
-            if (chat_data.parser) {
-                task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
-            }
+            task.params.chat_parser = chat_data.parser;
             if (task.params.sampling.grammar_lazy) {
                 GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
             }

@@ -397,7 +397,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools);
     std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
 
-    const auto msg = chat_data.parser->parse_final(full_delta);
+    const auto msg = chat_data.parser(full_delta);
     assert_msg_equals(expected_msg, msg);
 
     auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {