Simplify parser defs (incremental parsing for streaming will need more thinking)

ochafik 2025-01-28 10:48:11 +00:00
parent ec4aeaf18a
commit b5a74d1a24
4 changed files with 34 additions and 97 deletions
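In essence, the commit replaces the virtual common_chat_parser interface (parse_partial / parse_final / clone) with a plain std::function typedef that each chat handler fills with a lambda. A minimal, hedged sketch of that target shape follows; the *_sketch names are stand-ins, not the repository's real common_chat_msg / common_chat_data types.

// Sketch only: simplified stand-ins for the real chat types.
#include <functional>
#include <string>
#include <vector>

struct common_chat_msg_sketch {
    std::string role;
    std::string content;
    std::vector<std::string> tool_calls; // simplified; the real struct holds structured calls
};

// After this commit, a parser is just a callable from raw model output to a parsed message.
typedef std::function<common_chat_msg_sketch(const std::string & input)> chat_parser_sketch;

struct chat_data_sketch {
    chat_parser_sketch parser;
};

static chat_data_sketch make_content_only_sketch() {
    chat_data_sketch data;
    data.parser = [](const std::string & input) -> common_chat_msg_sketch {
        return { /* .role = */ "assistant", /* .content = */ input, /* .tool_calls = */ {} };
    };
    return data;
}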

View file

@@ -152,50 +152,6 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
return result;
}
class text_chat_parser : public common_chat_parser {
public:
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
return parse_final(input);
}
common_chat_msg parse_final(const std::string & input) override {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
std::unique_ptr<common_chat_parser> clone() const override {
return std::make_unique<text_chat_parser>();
}
};
class monolithic_chat_parser : public common_chat_parser {
std::string input_buffer_;
std::function<common_chat_msg(const std::string & input)> parse_final_;
public:
monolithic_chat_parser(const std::function<common_chat_msg(const std::string & input)> & parse_final) : parse_final_(parse_final) {}
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
input_buffer_ += input;
return std::nullopt;
}
common_chat_msg parse_final(const std::string & input) override {
input_buffer_ += input;
auto out = parse_final_(input_buffer_);
input_buffer_.clear();
return out;
}
std::unique_ptr<common_chat_parser> clone() const override {
return std::make_unique<monolithic_chat_parser>(parse_final_);
}
};
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
for (const auto & tool : tools) {
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
@@ -289,7 +245,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "generic tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
data.parser = [&](const std::string & input) {
json data = json::parse(input);
common_chat_msg result;
result.role = "assistant";
@@ -312,7 +268,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
}
return result;
});
};
return data;
}
@@ -355,9 +311,9 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "mistral nemo tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
data.parser = [](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
});
};
return data;
}
@@ -441,7 +397,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
});
data.format = "llama 3.1 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
data.parser = [params](const std::string & input) -> common_chat_msg {
static std::regex function_regex("\\{\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
@@ -472,7 +428,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
};
}
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
});
};
return data;
}
@@ -505,12 +461,12 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
data.format = "llama 3.2 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
data.parser = [params](const std::string & input) {
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
static std::regex close_regex("\\}");
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
return res;
});
};
return data;
}
@@ -532,12 +488,12 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "deepseek r1 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
data.parser = [params](const std::string & input) {
static std::regex trigger_regex("<tool▁calls▁begin>");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```<tool▁call▁end>");
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
});
};
return data;
}
@@ -573,9 +529,9 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "firefunction v2 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
data.parser = [](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
});
};
return data;
}
@@ -610,7 +566,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "functionary v3.2 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
data.parser = [params](const std::string & input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
@@ -619,7 +575,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
res.content = res.content.substr(4);
}
return res;
});
};
return data;
}
@@ -674,7 +630,7 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "functionary v3.1 llama 3.1 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
data.parser = [params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
@@ -695,7 +651,7 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
});
};
return data;
}
@@ -726,7 +682,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "hermes 2 pro tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
data.parser = [&](const std::string & input) -> common_chat_msg {
try {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -779,7 +735,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
/* .tool_calls = */ {},
};
}
});
};
return data;
}
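The Hermes 2 Pro lambda above splits the generated text on <tool_call>...</tool_call> markers and JSON-parses each block. A hedged sketch of just that splitting step, with a single simplified regex rather than the exact start/middle/end patterns used in the diff:

// Sketch: pull the JSON payloads out of "<tool_call>{...}</tool_call>" blocks.
#include <nlohmann/json.hpp>
#include <regex>
#include <string>
#include <vector>

using json = nlohmann::json;

static std::vector<json> extract_tool_call_blocks_sketch(const std::string & input) {
    static const std::regex block_regex(R"(<tool_call>([\s\S]*?)</tool_call>)");
    std::vector<json> calls;
    for (auto it = std::sregex_iterator(input.begin(), input.end(), block_regex);
         it != std::sregex_iterator(); ++it) {
        calls.push_back(json::parse((*it)[1].str())); // each block is expected to hold one JSON object
    }
    return calls;
}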
@@ -787,7 +743,13 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "content-only";
data.parser = std::make_unique<text_chat_parser>();
data.parser = [](const std::string & input) -> common_chat_msg {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
};
data.grammar_lazy = false;
if (!params.json_schema.is_null()) {
if (!params.grammar.empty()) {
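Several of the handlers above (Mistral Nemo, FireFunction v2) delegate to parse_prefixed_json_tool_call_array: find a fixed trigger prefix, then JSON-parse the array of calls that follows it. A rough, self-contained sketch of that idea, using stand-in types and a hypothetical helper name rather than the repository's actual implementation:

// Sketch: parse output shaped like "...content...[TOOL_CALLS][{"name": ..., "arguments": {...}}, ...]".
#include <nlohmann/json.hpp>
#include <string>
#include <vector>

using json = nlohmann::json;

struct tool_call_sketch { std::string name; std::string arguments; };
struct msg_sketch { std::string role; std::string content; std::vector<tool_call_sketch> tool_calls; };

static msg_sketch parse_prefixed_array_sketch(const std::string & input, const std::string & prefix) {
    msg_sketch result;
    result.role = "assistant";
    auto pos = input.find(prefix);
    if (pos == std::string::npos) {
        result.content = input; // no trigger prefix: treat everything as plain content
        return result;
    }
    result.content = input.substr(0, pos);
    auto calls = json::parse(input.substr(pos + prefix.size()));
    for (const auto & call : calls) {
        result.tool_calls.push_back({
            call.at("name").get<std::string>(),
            call.at("arguments").dump() // arguments kept as a JSON string here
        });
    }
    return result;
}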

View file

@@ -27,21 +27,14 @@ struct common_chat_params {
bool add_generation_prompt = true;
};
class common_chat_parser {
public:
virtual ~common_chat_parser() = default;
virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
virtual common_chat_msg parse_final(const std::string & input) = 0;
virtual std::unique_ptr<common_chat_parser> clone() const = 0;
};
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
struct common_chat_data {
json prompt;
std::string grammar;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> additional_stops;
std::unique_ptr<class common_chat_parser> parser;
std::vector<std::string> additional_stops;// std::unique_ptr<class common_chat_parser> parser;
common_chat_parser parser;
std::string format; // For debugging and testing.
bool grammar_lazy = false;
};
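Since common_chat_parser is now a std::function, a default-constructed parser is empty and converts to false, so callers can branch on the object itself instead of a pointer. A small usage sketch with a stand-in message type, mirroring the fallback logic in the server changes below:

// Sketch: call the parser if one was installed, otherwise fall back to content-only.
#include <functional>
#include <string>

struct msg_sketch { std::string role, content; };
typedef std::function<msg_sketch(const std::string & input)> parser_sketch;

static msg_sketch finalize_sketch(const parser_sketch & parser, const std::string & generated_text) {
    if (parser) {                       // empty std::function is falsy
        return parser(generated_text);  // format-specific parsing (tool calls, etc.)
    }
    return { "assistant", generated_text };
}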

View file

@@ -117,7 +117,7 @@ struct slot_params {
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::shared_ptr<common_chat_parser> chat_parser;
common_chat_parser chat_parser;
json to_json() const {
std::vector<std::string> samplers;
@@ -768,7 +768,6 @@ struct server_task_result_cmpl_partial : server_task_result {
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_chat_msg;
std::shared_ptr<common_chat_parser> chat_parser;
virtual int get_index() override {
@@ -2220,16 +2219,6 @@ struct server_context {
}
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
common_chat_msg msg;
if (slot.params.chat_parser) {
if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
msg = *opt_msg;
} else {
return;
}
} else {
msg.content = tkn.text_to_send;
}
auto res = std::make_unique<server_task_result_cmpl_partial>();
res->id = slot.id_task;
@@ -2245,7 +2234,6 @@ struct server_context {
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_msg = msg;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@@ -2286,18 +2274,14 @@ struct server_context {
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
if (!slot.params.chat_parser) {
if (slot.params.chat_parser) {
res->oaicompat_chat_msg = slot.params.chat_parser(slot.generated_text);
} else {
res->oaicompat_chat_msg = {
/* .role = */ "assistant",
/* .content = */ slot.generated_text,
/* .tool_calls = */ {}
};
} else if (slot.stop == STOP_TYPE_LIMIT) {
if (auto opt_msg = slot.params.chat_parser->parse_partial(slot.generated_text)) {
res->oaicompat_chat_msg = *opt_msg;
}
} else {
res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
}
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@@ -3835,9 +3819,7 @@ int main(int argc, char ** argv) {
task.params.sampling.grammar_trigger_words.push_back(trigger);
}
task.params.antiprompt = chat_data.additional_stops;
if (chat_data.parser) {
task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
}
task.params.chat_parser = chat_data.parser;
if (task.params.sampling.grammar_lazy) {
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
}
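One side effect visible in the last hunk: because std::function is copyable, the previous unique_ptr move / clone() handling per task is no longer needed and the same parser can simply be copied into every task's params. A minimal sketch with hypothetical task names:

// Sketch: std::function parsers copy freely, unlike the old unique_ptr<common_chat_parser>.
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

typedef std::function<std::string(const std::string &)> parser_sketch;

struct task_sketch { parser_sketch chat_parser; };

static std::vector<task_sketch> fan_out_sketch(const parser_sketch & parser, std::size_t n_prompts) {
    std::vector<task_sketch> tasks(n_prompts);
    for (auto & task : tasks) {
        task.chat_parser = parser; // plain copy; no std::move / clone() per task
    }
    return tasks;
}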

View file

@@ -397,7 +397,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools);
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
const auto msg = chat_data.parser->parse_final(full_delta);
const auto msg = chat_data.parser(full_delta);
assert_msg_equals(expected_msg, msg);
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {