server: tool call grammar-constraints
fix
This commit is contained in:
parent
312e20b54a
commit
ca1a640da2
4 changed files with 127 additions and 42 deletions
|
@ -11,6 +11,9 @@
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
const char * DOTALL = "[\\U00000000-\\U0010FFFF]";
|
||||||
|
const char * DOT = "[^\\x0A\\x0D]";
|
||||||
|
|
||||||
template <typename Iterator>
|
template <typename Iterator>
|
||||||
static std::string join(Iterator begin, Iterator end, const std::string & separator);
|
static std::string join(Iterator begin, Iterator end, const std::string & separator);
|
||||||
|
|
||||||
|
@ -198,6 +201,29 @@ static std::string format_literal(const std::string & literal) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
not_literal('a') -> '[^a]'
|
||||||
|
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
|
||||||
|
*/
|
||||||
|
static std::string not_literal(const std::string & literal, bool dotall = true) {
|
||||||
|
assert(literal.size() > 0);
|
||||||
|
std::stringstream out;
|
||||||
|
std::function<void(int)> recurse = [&](size_t i) {
|
||||||
|
const auto & c = literal[i];
|
||||||
|
out << "[^" << c << "]";
|
||||||
|
if (i < literal.size() - 1) {
|
||||||
|
out << " | " << format_literal(std::to_string(c)) << " (";
|
||||||
|
recurse(i + 1);
|
||||||
|
out << ")?";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
out << "(";
|
||||||
|
recurse(0);
|
||||||
|
out << ")" << (dotall ? DOTALL : DOT) << "*";
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class SchemaConverter {
|
class SchemaConverter {
|
||||||
private:
|
private:
|
||||||
std::function<json(const std::string &)> _fetch_json;
|
std::function<json(const std::string &)> _fetch_json;
|
||||||
|
@ -208,22 +234,6 @@ private:
|
||||||
std::vector<std::string> _errors;
|
std::vector<std::string> _errors;
|
||||||
std::vector<std::string> _warnings;
|
std::vector<std::string> _warnings;
|
||||||
|
|
||||||
std::string _add_rule(const std::string & name, const std::string & rule) {
|
|
||||||
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
|
|
||||||
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
|
|
||||||
_rules[esc_name] = rule;
|
|
||||||
return esc_name;
|
|
||||||
} else {
|
|
||||||
int i = 0;
|
|
||||||
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
std::string key = esc_name + std::to_string(i);
|
|
||||||
_rules[key] = rule;
|
|
||||||
return key;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
|
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
|
||||||
std::vector<std::string> rules;
|
std::vector<std::string> rules;
|
||||||
for (size_t i = 0; i < alt_schemas.size(); i++) {
|
for (size_t i = 0; i < alt_schemas.size(); i++) {
|
||||||
|
@ -256,11 +266,11 @@ private:
|
||||||
auto get_dot = [&]() {
|
auto get_dot = [&]() {
|
||||||
std::string rule;
|
std::string rule;
|
||||||
if (_dotall) {
|
if (_dotall) {
|
||||||
rule = "[\\U00000000-\\U0010FFFF]";
|
rule = DOTALL;
|
||||||
} else {
|
} else {
|
||||||
rule = "[^\\x0A\\x0D]";
|
rule = DOT;
|
||||||
}
|
}
|
||||||
return _add_rule("dot", rule);
|
return add_rule("dot", rule);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Joins the sequence, merging consecutive literals together.
|
// Joins the sequence, merging consecutive literals together.
|
||||||
|
@ -377,7 +387,7 @@ private:
|
||||||
if (!sub_is_literal) {
|
if (!sub_is_literal) {
|
||||||
std::string & sub_id = sub_rule_ids[sub];
|
std::string & sub_id = sub_rule_ids[sub];
|
||||||
if (sub_id.empty()) {
|
if (sub_id.empty()) {
|
||||||
sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
|
sub_id = add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
|
||||||
}
|
}
|
||||||
sub = sub_id;
|
sub = sub_id;
|
||||||
}
|
}
|
||||||
|
@ -423,7 +433,7 @@ private:
|
||||||
}
|
}
|
||||||
return join_seq();
|
return join_seq();
|
||||||
};
|
};
|
||||||
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
|
return add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string _resolve_ref(const std::string & ref) {
|
std::string _resolve_ref(const std::string & ref) {
|
||||||
|
@ -451,7 +461,7 @@ private:
|
||||||
const auto &prop_schema = kv.second;
|
const auto &prop_schema = kv.second;
|
||||||
|
|
||||||
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
|
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
|
||||||
prop_kv_rule_names[prop_name] = _add_rule(
|
prop_kv_rule_names[prop_name] = add_rule(
|
||||||
name + (name.empty() ? "" : "-") + prop_name + "-kv",
|
name + (name.empty() ? "" : "-") + prop_name + "-kv",
|
||||||
format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
|
format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
|
||||||
);
|
);
|
||||||
|
@ -464,7 +474,7 @@ private:
|
||||||
if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
|
if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
|
||||||
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
|
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
|
||||||
std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
|
std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
|
||||||
std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
|
std::string kv_rule = add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
|
||||||
prop_kv_rule_names["*"] = kv_rule;
|
prop_kv_rule_names["*"] = kv_rule;
|
||||||
optional_props.push_back("*");
|
optional_props.push_back("*");
|
||||||
}
|
}
|
||||||
|
@ -491,7 +501,7 @@ private:
|
||||||
std::string k = ks[0];
|
std::string k = ks[0];
|
||||||
std::string kv_rule_name = prop_kv_rule_names[k];
|
std::string kv_rule_name = prop_kv_rule_names[k];
|
||||||
if (k == "*") {
|
if (k == "*") {
|
||||||
res = _add_rule(
|
res = add_rule(
|
||||||
name + (name.empty() ? "" : "-") + "additional-kvs",
|
name + (name.empty() ? "" : "-") + "additional-kvs",
|
||||||
kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
|
kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
|
||||||
);
|
);
|
||||||
|
@ -501,7 +511,7 @@ private:
|
||||||
res = kv_rule_name;
|
res = kv_rule_name;
|
||||||
}
|
}
|
||||||
if (ks.size() > 1) {
|
if (ks.size() > 1) {
|
||||||
res += " " + _add_rule(
|
res += " " + add_rule(
|
||||||
name + (name.empty() ? "" : "-") + k + "-rest",
|
name + (name.empty() ? "" : "-") + k + "-rest",
|
||||||
get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
|
get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
|
||||||
);
|
);
|
||||||
|
@ -527,7 +537,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string _add_primitive(const std::string & name, const BuiltinRule & rule) {
|
std::string _add_primitive(const std::string & name, const BuiltinRule & rule) {
|
||||||
auto n = _add_rule(name, rule.content);
|
auto n = add_rule(name, rule.content);
|
||||||
for (const auto & dep : rule.deps) {
|
for (const auto & dep : rule.deps) {
|
||||||
BuiltinRule dep_rule;
|
BuiltinRule dep_rule;
|
||||||
auto it = PRIMITIVE_RULES.find(dep);
|
auto it = PRIMITIVE_RULES.find(dep);
|
||||||
|
@ -615,6 +625,22 @@ public:
|
||||||
visit_refs(schema);
|
visit_refs(schema);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string add_rule(const std::string & name, const std::string & rule) {
|
||||||
|
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
|
||||||
|
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
|
||||||
|
_rules[esc_name] = rule;
|
||||||
|
return esc_name;
|
||||||
|
} else {
|
||||||
|
int i = 0;
|
||||||
|
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
std::string key = esc_name + std::to_string(i);
|
||||||
|
_rules[key] = rule;
|
||||||
|
return key;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string _generate_constant_rule(const json & value) {
|
std::string _generate_constant_rule(const json & value) {
|
||||||
return format_literal(value.dump());
|
return format_literal(value.dump());
|
||||||
}
|
}
|
||||||
|
@ -625,24 +651,24 @@ public:
|
||||||
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
|
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
|
||||||
|
|
||||||
if (schema.contains("$ref")) {
|
if (schema.contains("$ref")) {
|
||||||
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
|
return add_rule(rule_name, _resolve_ref(schema["$ref"]));
|
||||||
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
|
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
|
||||||
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
|
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
|
||||||
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
|
return add_rule(rule_name, _generate_union_rule(name, alt_schemas));
|
||||||
} else if (schema_type.is_array()) {
|
} else if (schema_type.is_array()) {
|
||||||
std::vector<json> schema_types;
|
std::vector<json> schema_types;
|
||||||
for (const auto & t : schema_type) {
|
for (const auto & t : schema_type) {
|
||||||
schema_types.push_back({{"type", t}});
|
schema_types.push_back({{"type", t}});
|
||||||
}
|
}
|
||||||
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
|
return add_rule(rule_name, _generate_union_rule(name, schema_types));
|
||||||
} else if (schema.contains("const")) {
|
} else if (schema.contains("const")) {
|
||||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
|
return add_rule(rule_name, _generate_constant_rule(schema["const"]));
|
||||||
} else if (schema.contains("enum")) {
|
} else if (schema.contains("enum")) {
|
||||||
std::vector<std::string> enum_values;
|
std::vector<std::string> enum_values;
|
||||||
for (const auto & v : schema["enum"]) {
|
for (const auto & v : schema["enum"]) {
|
||||||
enum_values.push_back(_generate_constant_rule(v));
|
enum_values.push_back(_generate_constant_rule(v));
|
||||||
}
|
}
|
||||||
return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
|
return add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
|
||||||
} else if ((schema_type.is_null() || schema_type == "object")
|
} else if ((schema_type.is_null() || schema_type == "object")
|
||||||
&& (schema.contains("properties") ||
|
&& (schema.contains("properties") ||
|
||||||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
|
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
|
||||||
|
@ -660,7 +686,7 @@ public:
|
||||||
properties.emplace_back(prop.key(), prop.value());
|
properties.emplace_back(prop.key(), prop.value());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return _add_rule(rule_name,
|
return add_rule(rule_name,
|
||||||
_build_object_rule(
|
_build_object_rule(
|
||||||
properties, required, name,
|
properties, required, name,
|
||||||
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
||||||
|
@ -691,7 +717,7 @@ public:
|
||||||
add_component(t, true);
|
add_component(t, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
return add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||||
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
||||||
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
||||||
if (items.is_array()) {
|
if (items.is_array()) {
|
||||||
|
@ -703,14 +729,14 @@ public:
|
||||||
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
|
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
|
||||||
}
|
}
|
||||||
rule += " \"]\" space";
|
rule += " \"]\" space";
|
||||||
return _add_rule(rule_name, rule);
|
return add_rule(rule_name, rule);
|
||||||
} else {
|
} else {
|
||||||
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
||||||
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
|
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
|
||||||
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
||||||
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
||||||
|
|
||||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
return add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
||||||
}
|
}
|
||||||
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
||||||
return _visit_pattern(schema["pattern"], rule_name);
|
return _visit_pattern(schema["pattern"], rule_name);
|
||||||
|
@ -718,14 +744,14 @@ public:
|
||||||
return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
|
return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
|
||||||
} else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
|
} else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
|
||||||
auto prim_name = schema_format + "-string";
|
auto prim_name = schema_format + "-string";
|
||||||
return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
|
return add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
|
||||||
} else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
|
} else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
|
||||||
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
||||||
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
|
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
|
||||||
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
||||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
return add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
||||||
} else if (schema.empty() || schema_type == "object") {
|
} else if (schema.empty() || schema_type == "object") {
|
||||||
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
|
return add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
|
||||||
} else {
|
} else {
|
||||||
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
|
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
|
||||||
_errors.push_back("Unrecognized schema: " + schema.dump());
|
_errors.push_back("Unrecognized schema: " + schema.dump());
|
||||||
|
@ -762,3 +788,39 @@ std::string json_schema_to_grammar(const json & schema) {
|
||||||
converter.check_errors();
|
converter.check_errors();
|
||||||
return converter.format_grammar();
|
return converter.format_grammar();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string tool_call_grammar(const json & tools) {
|
||||||
|
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
|
||||||
|
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
std::string description = function["description"];
|
||||||
|
auto parameters_copy = function["parameters"];
|
||||||
|
converter.resolve_refs(parameters_copy, name);
|
||||||
|
|
||||||
|
tool_rules.push_back(converter.visit(json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"description", description},
|
||||||
|
{"properties", json {
|
||||||
|
{"name", json {{"const", name}}},
|
||||||
|
{"arguments", parameters_copy},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"name", "arguments"})},
|
||||||
|
}, name + "-tool-call"));
|
||||||
|
}
|
||||||
|
|
||||||
|
converter.add_rule(
|
||||||
|
"root",
|
||||||
|
not_literal("<tool_call>") + " | "
|
||||||
|
+ converter.add_rule(
|
||||||
|
"tool_call",
|
||||||
|
"\"<tool_call>\" "
|
||||||
|
+ join(tool_rules.begin(), tool_rules.end(), " | ")
|
||||||
|
+ " \"</tool_call>\""));
|
||||||
|
|
||||||
|
converter.check_errors();
|
||||||
|
return converter.format_grammar();
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
|
|
||||||
|
std::string tool_call_grammar(const nlohmann::ordered_json & tools);
|
||||||
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
|
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
|
||||||
|
|
|
@ -3031,7 +3031,7 @@ int main(int argc, char ** argv) {
|
||||||
chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
|
chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
|
||||||
chat.push_back({{"role", "user"}, {"content", "How are you?"}});
|
chat.push_back({{"role", "user"}, {"content", "How are you?"}});
|
||||||
|
|
||||||
const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
|
const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat, "");
|
||||||
|
|
||||||
LOG_INFO("chat template", {
|
LOG_INFO("chat template", {
|
||||||
{"chat_example", chat_example},
|
{"chat_example", chat_example},
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
|
#include "json-schema-to-grammar.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -122,7 +123,7 @@ inline bool verify_custom_template(const std::string & tmpl) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||||
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
|
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages, const std::string & tools_tag) {
|
||||||
size_t alloc_size = 0;
|
size_t alloc_size = 0;
|
||||||
// vector holding all allocated string to be passed to llama_chat_apply_template
|
// vector holding all allocated string to be passed to llama_chat_apply_template
|
||||||
std::vector<std::string> str(messages.size() * 2);
|
std::vector<std::string> str(messages.size() * 2);
|
||||||
|
@ -137,6 +138,20 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
|
||||||
chat[i].content = str[i*2 + 1].c_str();
|
chat[i].content = str[i*2 + 1].c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!tools_tag.empty()) {
|
||||||
|
alloc_size += tools_tag.size();
|
||||||
|
if (chat.empty()) {
|
||||||
|
str.resize(2);
|
||||||
|
str[0] = "user";
|
||||||
|
str[1] = tools_tag;
|
||||||
|
chat.push_back({str[0].c_str(), str[1].c_str()});
|
||||||
|
} else {
|
||||||
|
auto & content = str[str.size() - 1];
|
||||||
|
content += tools_tag;
|
||||||
|
chat[chat.size() - 1].content = content.c_str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
|
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
|
||||||
std::vector<char> buf(alloc_size * 2);
|
std::vector<char> buf(alloc_size * 2);
|
||||||
|
|
||||||
|
@ -372,8 +387,15 @@ static json oaicompat_completion_params_parse(
|
||||||
llama_params["temperature"] = json_value(body, "temperature", 0.0);
|
llama_params["temperature"] = json_value(body, "temperature", 0.0);
|
||||||
llama_params["top_p"] = json_value(body, "top_p", 1.0);
|
llama_params["top_p"] = json_value(body, "top_p", 1.0);
|
||||||
|
|
||||||
|
std::string tools_tag;
|
||||||
|
if (body.contains("tools") && body["tools"].is_array()) {
|
||||||
|
const auto & tools = body["tools"];
|
||||||
|
llama_params["grammar"] = tool_call_grammar(tools);
|
||||||
|
tools_tag = (std::stringstream() << "\n\n<tools>" << tools.dump(2) << "</tools>").str();
|
||||||
|
}
|
||||||
|
|
||||||
// Apply chat template to the list of messages
|
// Apply chat template to the list of messages
|
||||||
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
|
llama_params["prompt"] = format_chat(model, chat_template, body["messages"], tools_tag);
|
||||||
|
|
||||||
// Handle "stop" field
|
// Handle "stop" field
|
||||||
if (body.contains("stop") && body["stop"].is_string()) {
|
if (body.contains("stop") && body["stop"].is_string()) {
|
||||||
|
@ -408,7 +430,7 @@ static json oaicompat_completion_params_parse(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Params supported by OAI but unsupported by llama.cpp
|
// Params supported by OAI but unsupported by llama.cpp
|
||||||
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
|
static const std::vector<std::string> unsupported_params { "tool_choice" };
|
||||||
for (auto & param : unsupported_params) {
|
for (auto & param : unsupported_params) {
|
||||||
if (body.contains(param)) {
|
if (body.contains(param)) {
|
||||||
throw std::runtime_error("Unsupported param: " + param);
|
throw std::runtime_error("Unsupported param: " + param);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue