From e41b6ceee9f0b9bbae28f5c608dff3e3f6fb4864 Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 2 May 2024 04:54:58 +0100 Subject: [PATCH] server: update tool calling, introduce system prompt for json schema --- common/json-schema-to-grammar.cpp | 99 ++++++++++++++++++++++--------- common/json-schema-to-grammar.h | 2 +- examples/server/utils.hpp | 68 ++++++++++++--------- 3 files changed, 110 insertions(+), 59 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 393153b40..4c09587a0 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -200,28 +200,28 @@ static std::string format_literal(const std::string & literal) { return "\"" + escaped + "\""; } - /* not_literal('a') -> '[^a]' not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' */ -static std::string not_literal(const std::string & literal, bool dotall = true) { - assert(literal.size() > 0); - std::stringstream out; - std::function recurse = [&](size_t i) { - const auto & c = literal[i]; - out << "[^" << c << "]"; - if (i < literal.size() - 1) { - out << " | " << format_literal(std::to_string(c)) << " ("; - recurse(i + 1); - out << ")?"; - } - }; - out << "("; - recurse(0); - out << ")" << (dotall ? DOTALL : DOT) << "*"; - return out.str(); -} +// static std::string not_literal(const std::string & literal, bool dotall = true) { +// assert(literal.size() > 0); +// std::stringstream out; +// std::function recurse = [&](size_t i) { +// const char & c = literal[i]; +// out << "[^" << c << "]"; +// out << " " << (dotall ? DOTALL : DOT) << "*"; +// if (i < literal.size() - 1) { +// out << " | " << format_literal(literal.substr(i, 1)) << " ("; +// recurse(i + 1); +// out << ")?"; +// } +// }; +// out << "("; +// recurse(0); +// out << ")"; +// return out.str(); +// } class SchemaConverter { @@ -625,17 +625,57 @@ public: visit_refs(schema); } +/* + reply ::= prefix tool-call* + + prefix ::= [^<] prefix + | "<" [^t] prefix + | "] prefix + | + +*/ + + std::string not_literal(const std::string & literal) { + auto rule_name = _find_rule_name("not" + literal, "!!!"); + std::stringstream out; + for (size_t i = 0, n = literal.size(); i < n; i++) { + out << " | "; + if (i > 0) { + out << format_literal(literal.substr(0, i)) << " "; + } + out << "[^" << literal[i] << "] " << rule_name.c_str(); + } + _rules[rule_name] = out.str(); + return rule_name; + } + + std::string _escape_name(const std::string & name) { + return regex_replace(name, INVALID_RULE_CHARS_RE, "-"); + } + std::string _find_rule_name(const std::string & name, const std::string & rule) { + auto esc_name = _escape_name(name); + int i = 0; + while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) { + i++; + } + return esc_name + std::to_string(i); + } std::string add_rule(const std::string & name, const std::string & rule) { - std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-"); + auto esc_name = _escape_name(name); if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) { _rules[esc_name] = rule; return esc_name; } else { - int i = 0; - while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) { - i++; - } - std::string key = esc_name + std::to_string(i); + auto key = _find_rule_name(esc_name, rule); _rules[key] = rule; return key; } @@ -789,7 +829,7 @@ std::string json_schema_to_grammar(const json & schema) { return converter.format_grammar(); } -std::string tool_call_grammar(const json & tools) { +std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) { SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); std::vector tool_rules; @@ -814,12 +854,13 @@ std::string tool_call_grammar(const json & tools) { converter.add_rule( "root", - not_literal("") + " | " - + converter.add_rule( + converter.not_literal("") + " " + + converter.add_rule( "tool_call", - "\"\" " + "\"\" (" + join(tool_rules.begin(), tool_rules.end(), " | ") - + " \"\"")); + + ") \"\"" + ) + (allow_parallel_calls ? "*" : "?")); converter.check_errors(); return converter.format_grammar(); diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 5594051b1..024825151 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -1,5 +1,5 @@ #pragma once #include "json.hpp" -std::string tool_call_grammar(const nlohmann::ordered_json & tools); +std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false); std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 336307a56..3bcfb2252 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -123,7 +123,7 @@ inline bool verify_custom_template(const std::string & tmpl) { } // Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages, const std::string & tools_tag) { +inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages, const std::string & extra_system_message) { size_t alloc_size = 0; // vector holding all allocated string to be passed to llama_chat_apply_template std::vector str(messages.size() * 2); @@ -138,18 +138,12 @@ inline std::string format_chat(const struct llama_model * model, const std::stri chat[i].content = str[i*2 + 1].c_str(); } - if (!tools_tag.empty()) { - alloc_size += tools_tag.size(); - if (chat.empty()) { - str.resize(2); - str[0] = "user"; - str[1] = tools_tag; - chat.push_back({str[0].c_str(), str[1].c_str()}); - } else { - auto & content = str[str.size() - 1]; - content += tools_tag; - chat[chat.size() - 1].content = content.c_str(); - } + if (!extra_system_message.empty()) { + alloc_size += extra_system_message.size(); + + llama_chat_message msg { "system", extra_system_message.c_str() }; + chat.insert(chat.begin(), msg); + // chat.push_back(msg); } const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); @@ -387,22 +381,7 @@ static json oaicompat_completion_params_parse( llama_params["temperature"] = json_value(body, "temperature", 0.0); llama_params["top_p"] = json_value(body, "top_p", 1.0); - std::string tools_tag; - if (body.contains("tools") && body["tools"].is_array()) { - const auto & tools = body["tools"]; - llama_params["grammar"] = tool_call_grammar(tools); - tools_tag = (std::stringstream() << "\n\n" << tools.dump(2) << "").str(); - } - - // Apply chat template to the list of messages - llama_params["prompt"] = format_chat(model, chat_template, body["messages"], tools_tag); - - // Handle "stop" field - if (body.contains("stop") && body["stop"].is_string()) { - llama_params["stop"] = json::array({body["stop"].get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } + std::string extra_system_message; // Handle "response_format" field if (body.contains("response_format")) { @@ -410,9 +389,40 @@ static json oaicompat_completion_params_parse( std::string response_type = json_value(response_format, "type", std::string()); if (response_type == "json_object") { llama_params["json_schema"] = json_value(response_format, "schema", json::object()); + extra_system_message = (std::stringstream() + << "You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:\n\n" + << llama_params["json_schema"].dump().c_str() + << "\n" + ).str(); } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } + } else if (body.contains("tools") && body["tools"].is_array()) { + const auto & tools = body["tools"]; + llama_params["grammar"] = tool_call_grammar(tools); + + extra_system_message = (std::stringstream() + << "You are a function calling AI model. You are provided with function signatures within XML tags. " + << "You may call one or more functions to assist with the user query. " + << "Don't make assumptions about what values to plug into functions. " + << "Here are the available tools: " + << tools.dump().c_str() + << "\n" + << "For each function call return a json object with function name and arguments within XML tags as follows:" + << "" + << "{\"arguments\": , \"name\": }" + << "" + ).str(); + } + + // Apply chat template to the list of messages + llama_params["prompt"] = format_chat(model, chat_template, body["messages"], extra_system_message); + + // Handle "stop" field + if (body.contains("stop") && body["stop"].is_string()) { + llama_params["stop"] = json::array({body["stop"].get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); } // Handle "n" field