tool-calls: basic Nemo support, default parallel to true if template mentions tool_call_id
This commit is contained in:
parent
fc80ad20ce
commit
3e12b9b38e
6 changed files with 227 additions and 77 deletions
|
@ -26,6 +26,7 @@ class chat_template {
|
||||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||||
bool _requires_object_arguments = false;
|
bool _requires_object_arguments = false;
|
||||||
bool _supports_system_role = true;
|
bool _supports_system_role = true;
|
||||||
|
bool _supports_parallel_tool_calls = false;
|
||||||
std::string _source;
|
std::string _source;
|
||||||
std::string _bos_token;
|
std::string _bos_token;
|
||||||
std::string _eos_token;
|
std::string _eos_token;
|
||||||
|
@ -40,6 +41,7 @@ class chat_template {
|
||||||
source.find("tool_call.arguments | items") != std::string::npos
|
source.find("tool_call.arguments | items") != std::string::npos
|
||||||
|| source.find("tool_call.arguments | tojson") != std::string::npos;
|
|| source.find("tool_call.arguments | tojson") != std::string::npos;
|
||||||
_supports_system_role = source.find("System role not supported") == std::string::npos;
|
_supports_system_role = source.find("System role not supported") == std::string::npos;
|
||||||
|
_supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;
|
||||||
|
|
||||||
_template_root = minja::Parser::parse(_source, {
|
_template_root = minja::Parser::parse(_source, {
|
||||||
/* .trim_blocks = */ true,
|
/* .trim_blocks = */ true,
|
||||||
|
@ -50,6 +52,7 @@ class chat_template {
|
||||||
|
|
||||||
const std::string & source() const { return _source; }
|
const std::string & source() const { return _source; }
|
||||||
bool supports_tools() const { return _supports_tools; }
|
bool supports_tools() const { return _supports_tools; }
|
||||||
|
bool supports_parallel_tool_calls() const { return _supports_parallel_tool_calls; }
|
||||||
|
|
||||||
std::string apply(
|
std::string apply(
|
||||||
const nlohmann::ordered_json & messages,
|
const nlohmann::ordered_json & messages,
|
||||||
|
|
|
@ -14,6 +14,8 @@ using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
std::string llama_tool_call_style_name(llama_tool_call_style style) {
|
std::string llama_tool_call_style_name(llama_tool_call_style style) {
|
||||||
switch (style) {
|
switch (style) {
|
||||||
|
case llama_tool_call_style::None:
|
||||||
|
return "None";
|
||||||
case llama_tool_call_style::Generic:
|
case llama_tool_call_style::Generic:
|
||||||
return "Generic";
|
return "Generic";
|
||||||
case llama_tool_call_style::Llama31:
|
case llama_tool_call_style::Llama31:
|
||||||
|
@ -28,6 +30,8 @@ std::string llama_tool_call_style_name(llama_tool_call_style style) {
|
||||||
return "Hermes2Pro";
|
return "Hermes2Pro";
|
||||||
case llama_tool_call_style::CommandRPlus:
|
case llama_tool_call_style::CommandRPlus:
|
||||||
return "CommandRPlus";
|
return "CommandRPlus";
|
||||||
|
case llama_tool_call_style::MistralNemo:
|
||||||
|
return "MistralNemo";
|
||||||
default:
|
default:
|
||||||
return "Unknown";
|
return "Unknown";
|
||||||
}
|
}
|
||||||
|
@ -51,6 +55,8 @@ llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template &
|
||||||
}
|
}
|
||||||
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
||||||
return CommandRPlus;
|
return CommandRPlus;
|
||||||
|
} else if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||||
|
return MistralNemo;
|
||||||
} else {
|
} else {
|
||||||
return Generic;
|
return Generic;
|
||||||
}
|
}
|
||||||
|
@ -146,7 +152,7 @@ static llama_tool_calls parse_json_tool_calls(const json & tools, const std::str
|
||||||
throw std::runtime_error("Malformed input, missing closing pattern");
|
throw std::runtime_error("Malformed input, missing closing pattern");
|
||||||
}
|
}
|
||||||
it = match.suffix().first;
|
it = match.suffix().first;
|
||||||
result.tool_calls.push_back({name, arguments.dump()});
|
result.tool_calls.push_back({name, arguments.dump(), /* id= */ ""});
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -176,6 +182,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
|
||||||
result.tool_calls.push_back({
|
result.tool_calls.push_back({
|
||||||
call["name"],
|
call["name"],
|
||||||
call["arguments"].dump(),
|
call["arguments"].dump(),
|
||||||
|
/* id= */ "",
|
||||||
});
|
});
|
||||||
rit = {it, end, middle_pattern};
|
rit = {it, end, middle_pattern};
|
||||||
if (rit != rend) {
|
if (rit != rend) {
|
||||||
|
@ -241,12 +248,14 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
|
||||||
result.tool_calls.push_back({
|
result.tool_calls.push_back({
|
||||||
tool_call["name"],
|
tool_call["name"],
|
||||||
tool_call["arguments"].dump(),
|
tool_call["arguments"].dump(),
|
||||||
|
/* id= */ "",
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else if (data.contains("tool_call")) {
|
} else if (data.contains("tool_call")) {
|
||||||
result.tool_calls.push_back({
|
result.tool_calls.push_back({
|
||||||
data["tool_call"]["name"],
|
data["tool_call"]["name"],
|
||||||
data["tool_call"]["arguments"].dump(),
|
data["tool_call"]["arguments"].dump(),
|
||||||
|
/* id= */ "",
|
||||||
});
|
});
|
||||||
} else if (data.contains("response")) {
|
} else if (data.contains("response")) {
|
||||||
const auto & response = data["response"];
|
const auto & response = data["response"];
|
||||||
|
@ -255,8 +264,38 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parses Mistral Nemo style model output.
//
// Expected shape: optional free-form content, then the `[TOOL_CALLS]` marker, then a JSON
// array of tool call objects with "name", "arguments" and (normally) "id" keys.
// If the marker is absent, the input is still treated as a tool call list when it starts
// with `[{"`; otherwise the whole input is returned as plain content with no tool calls.
//
// "arguments" may be either an already-stringified JSON value or a plain JSON value;
// in both cases the resulting llama_tool_call carries it as a string.
//
// Throws whatever json::parse throws on malformed JSON after the marker, and
// json::out_of_range if a tool call lacks "name" or "arguments".
static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
    static const std::string tool_calls_marker = "[TOOL_CALLS]";

    auto content_end = input.find(tool_calls_marker);
    size_t tc_start = std::string::npos;
    if (content_end != std::string::npos) {
        // The JSON payload starts right after the marker.
        tc_start = content_end + tool_calls_marker.size();
    } else {
        // Somehow not getting [TOOL_CALLS] in the output. Oh well, just do without it.
        content_end = input.find("[{\"");
        // Only accept the marker-less form when the array opens the output.
        if (content_end != 0) {
            return {input, {}};
        }
        tc_start = content_end;
    }

    llama_tool_calls result;
    result.content = input.substr(0, content_end);

    auto tool_calls = json::parse(input.substr(tc_start));
    for (const auto & tool_call : tool_calls) {
        // Use at() rather than operator[]: on a const json, operator[] with a missing key
        // is undefined behavior, whereas at() throws a well-defined exception.
        const auto & arguments = tool_call.at("arguments");
        result.tool_calls.push_back({
            tool_call.at("name").get<std::string>(),
            arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
            // A missing id falls back to the empty string, as the other parsers in this file do.
            tool_call.contains("id") ? tool_call.at("id").get<std::string>() : std::string(),
        });
    }
    return result;
}
|
||||||
|
|
||||||
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
|
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
|
||||||
|
fprintf(stderr, "# parse_tool_calls:\n\n%s\n\n", input.c_str());
|
||||||
switch (style) {
|
switch (style) {
|
||||||
|
case llama_tool_call_style::None:
|
||||||
|
return {input, {}};
|
||||||
case llama_tool_call_style::Generic:
|
case llama_tool_call_style::Generic:
|
||||||
return parse_generic_tool_calls(input);
|
return parse_generic_tool_calls(input);
|
||||||
case llama_tool_call_style::Llama31:
|
case llama_tool_call_style::Llama31:
|
||||||
|
@ -269,23 +308,43 @@ llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tool
|
||||||
return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
|
return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
|
||||||
case llama_tool_call_style::Hermes2Pro:
|
case llama_tool_call_style::Hermes2Pro:
|
||||||
return parse_hermes_tool_calls(input);
|
return parse_hermes_tool_calls(input);
|
||||||
|
case llama_tool_call_style::MistralNemo:
|
||||||
|
return parse_mistral_nemo_tool_calls(input);
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unsupported tool call style");
|
throw std::runtime_error("Unsupported tool call style");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a copy of `messages` with `system_prompt` injected as system instructions.
//
// - If the first message already has role "system", `system_prompt` is appended to its
//   content on a new line (other fields of that message are left untouched).
// - Otherwise a new {"role": "system", "content": system_prompt} message is inserted
//   at the front.
//
// Throws a json exception if the existing system message has no "content" key or if
// its content is not a JSON string.
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
    json messages_with_system = messages;

    if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
        // NOTE: `json += value` is push_back(), which throws type_error.308 when the target
        // is a JSON string, so the text has to be extracted and concatenated explicitly.
        auto & content = messages_with_system.at(0).at("content");
        content = content.get<std::string>() + "\n" + system_prompt;
    } else {
        messages_with_system.insert(messages_with_system.begin(), json {
            {"role", "system"},
            {"content", system_prompt},
        });
    }
    return messages_with_system;
}
|
||||||
|
|
||||||
llama_tool_call_handler llama_tool_call_handler_init(
|
llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
llama_tool_call_style style,
|
llama_tool_call_style style,
|
||||||
const minja::chat_template & tmpl,
|
const minja::chat_template & tmpl,
|
||||||
bool allow_content,
|
bool allow_content,
|
||||||
bool parallel_tool_calls,
|
const nlohmann::ordered_json & parallel_tool_calls,
|
||||||
const nlohmann::ordered_json & messages,
|
const nlohmann::ordered_json & messages,
|
||||||
const nlohmann::ordered_json & tools,
|
const nlohmann::ordered_json & tools,
|
||||||
const nlohmann::ordered_json & json_schema)
|
const nlohmann::ordered_json & json_schema)
|
||||||
{
|
{
|
||||||
llama_tool_call_handler handler;
|
llama_tool_call_handler handler;
|
||||||
|
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
|
||||||
|
|
||||||
switch (style) {
|
switch (style) {
|
||||||
|
case llama_tool_call_style::None:
|
||||||
|
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
|
||||||
|
break;
|
||||||
case llama_tool_call_style::Generic: {
|
case llama_tool_call_style::Generic: {
|
||||||
auto tool_call_schemas = json::array();
|
auto tool_call_schemas = json::array();
|
||||||
for (const auto & tool : tools) {
|
for (const auto & tool : tools) {
|
||||||
|
@ -307,43 +366,98 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
{"required", json::array({"name", "arguments"})},
|
{"required", json::array({"name", "arguments"})},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
const auto tool_call = json {{"anyOf", tool_call_schemas}};
|
const auto tool_call =
|
||||||
const auto schema = json {
|
parallel
|
||||||
{"anyOf", json::array({
|
? json {
|
||||||
parallel_tool_calls
|
|
||||||
? json {
|
|
||||||
{"type", "object"},
|
|
||||||
{"properties", {
|
|
||||||
{"tool_calls", {
|
|
||||||
{"type", "array"},
|
|
||||||
{"items", tool_call}
|
|
||||||
}},
|
|
||||||
}},
|
|
||||||
{"required", json::array({"tool_calls"})},
|
|
||||||
}
|
|
||||||
: json {
|
|
||||||
{"type", "object"},
|
|
||||||
{"properties", {
|
|
||||||
{"tool_call", tool_call},
|
|
||||||
}},
|
|
||||||
{"required", json::array({"tool_call"})},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
{"properties", {
|
{"properties", {
|
||||||
{"response", json_schema.is_null()
|
{"tool_calls", {
|
||||||
? json {{"type", "string"}}
|
{"type", "array"},
|
||||||
: json_schema
|
{"items", json {{"anyOf", tool_call_schemas}}}
|
||||||
},
|
}},
|
||||||
}},
|
}},
|
||||||
},
|
{"required", json::array({"tool_calls"})},
|
||||||
})}
|
}
|
||||||
};
|
: json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"tool_call", json {{"anyOf", tool_call_schemas}}},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"tool_call"})},
|
||||||
|
};
|
||||||
|
const auto schema =
|
||||||
|
allow_content
|
||||||
|
? json {
|
||||||
|
{"anyOf", json::array({
|
||||||
|
tool_call,
|
||||||
|
{
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"response", json_schema.is_null()
|
||||||
|
? json {{"type", "string"}}
|
||||||
|
: json_schema
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
: tool_call;
|
||||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||||
builder.add_schema("", schema);
|
builder.add_schema("", schema);
|
||||||
});
|
});
|
||||||
// TODO: add schema to system prompt.
|
// TODO: add schema to system prompt.
|
||||||
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
|
auto tweaked_messages = add_system(
|
||||||
|
messages,
|
||||||
|
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
|
||||||
|
handler.prompt = tmpl.apply(tweaked_messages, tools, /* add_generation_prompt= */ true);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case llama_tool_call_style::MistralNemo: {
|
||||||
|
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||||
|
auto schemas = json::array();
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (tool["type"] != "function") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
auto schema = json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
// Important note: the model is probably trained to take a JSON stringified arguments value.
|
||||||
|
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
|
||||||
|
{"arguments", parameters},
|
||||||
|
{"name", {
|
||||||
|
{"type", "string"},
|
||||||
|
{"const", name},
|
||||||
|
}},
|
||||||
|
{"id", {
|
||||||
|
{"type", "string"},
|
||||||
|
// Nemo's template expects a 9-character alphanumeric ID.
|
||||||
|
{"pattern", "^[a-zA-Z0-9]{9}$"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"arguments", "id", "name"})},
|
||||||
|
};
|
||||||
|
schemas.push_back(schema);
|
||||||
|
}
|
||||||
|
auto schema = json {
|
||||||
|
{"type", "array"},
|
||||||
|
{"items", json {{"anyOf", schemas}}},
|
||||||
|
{"minItems", 1},
|
||||||
|
};
|
||||||
|
if (!parallel) {
|
||||||
|
schema["maxItems"] = 1;
|
||||||
|
}
|
||||||
|
builder.add_schema("", schema);
|
||||||
|
});
|
||||||
|
if (allow_content) {
|
||||||
|
handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
|
||||||
|
handler.grammar_trigger_words.push_back("[{\"");
|
||||||
|
}
|
||||||
|
auto tweaked_messages = add_system(messages, "Prefix any tool calls with [TOOL_CALLS]");
|
||||||
|
handler.prompt = tmpl.apply(tweaked_messages, tools, /* add_generation_prompt= */ true);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case llama_tool_call_style::Llama31:
|
case llama_tool_call_style::Llama31:
|
||||||
|
@ -427,7 +541,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space";
|
auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space";
|
||||||
if (parallel_tool_calls) {
|
if (parallel) {
|
||||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space";
|
auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space";
|
||||||
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||||
} else {
|
} else {
|
||||||
|
@ -459,7 +573,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
||||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
|
||||||
if (allow_content) {
|
if (allow_content) {
|
||||||
handler.grammar_trigger_words.push_back("<function=");
|
handler.grammar_trigger_words.push_back("<function=");
|
||||||
}
|
}
|
||||||
|
@ -489,7 +603,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
|
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
|
||||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
|
||||||
if (allow_content) {
|
if (allow_content) {
|
||||||
handler.grammar_trigger_words.push_back("<tool_call>");
|
handler.grammar_trigger_words.push_back("<tool_call>");
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
|
|
||||||
enum llama_tool_call_style {
|
enum llama_tool_call_style {
|
||||||
UnknownToolCallStyle,
|
UnknownToolCallStyle,
|
||||||
|
None,
|
||||||
Generic,
|
Generic,
|
||||||
Llama31,
|
Llama31,
|
||||||
Llama32,
|
Llama32,
|
||||||
|
@ -16,11 +17,13 @@ enum llama_tool_call_style {
|
||||||
FunctionaryV3Llama31,
|
FunctionaryV3Llama31,
|
||||||
Hermes2Pro,
|
Hermes2Pro,
|
||||||
CommandRPlus,
|
CommandRPlus,
|
||||||
|
MistralNemo,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_tool_call {
|
struct llama_tool_call {
|
||||||
std::string name;
|
std::string name;
|
||||||
std::string arguments;
|
std::string arguments;
|
||||||
|
std::string id;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_tool_calls {
|
struct llama_tool_calls {
|
||||||
|
@ -45,7 +48,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
llama_tool_call_style style,
|
llama_tool_call_style style,
|
||||||
const minja::chat_template & tmpl,
|
const minja::chat_template & tmpl,
|
||||||
bool allow_content,
|
bool allow_content,
|
||||||
bool parallel_tool_calls,
|
const nlohmann::ordered_json & parallel_tool_calls,
|
||||||
const nlohmann::ordered_json & messages,
|
const nlohmann::ordered_json & messages,
|
||||||
const nlohmann::ordered_json & tools,
|
const nlohmann::ordered_json & tools,
|
||||||
const nlohmann::ordered_json & json_schema = {});
|
const nlohmann::ordered_json & json_schema = {});
|
||||||
|
|
|
@ -7,6 +7,11 @@
|
||||||
```bash
|
```bash
|
||||||
make -j LLAMA_CURL=1 llama-server
|
make -j LLAMA_CURL=1 llama-server
|
||||||
|
|
||||||
|
# Mistral NeMo
|
||||||
|
./llama-server --jinja -fa --verbose \
|
||||||
|
-hfr bartowski/Mistral-Nemo-Instruct-2407-GGUF -hff Mistral-Nemo-Instruct-2407-Q8_0.gguf \
|
||||||
|
--chat-template "$( python scripts/get_hf_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 )"
|
||||||
|
|
||||||
# Nous Hermes 2 Pro Llama 3 8B
|
# Nous Hermes 2 Pro Llama 3 8B
|
||||||
./llama-server --jinja -fa --verbose \
|
./llama-server --jinja -fa --verbose \
|
||||||
-hfr NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF -hff Hermes-2-Pro-Llama-3-8B-Q8_0.gguf \
|
-hfr NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF -hff Hermes-2-Pro-Llama-3-8B-Q8_0.gguf \
|
||||||
|
@ -27,7 +32,7 @@
|
||||||
|
|
||||||
# Llama 3.2 3B (poor adherence)
|
# Llama 3.2 3B (poor adherence)
|
||||||
./llama-server --jinja -fa --verbose \
|
./llama-server --jinja -fa --verbose \
|
||||||
-hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K_L.gguf \
|
-hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K.gguf \
|
||||||
--chat-template "$( python scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct )"
|
--chat-template "$( python scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct )"
|
||||||
|
|
||||||
# Llama 3.2 1B (very poor adherence)
|
# Llama 3.2 1B (very poor adherence)
|
||||||
|
@ -39,12 +44,8 @@
|
||||||
- Run the tools in [examples/agent/tools](./examples/agent/tools) inside a docker container (check http://localhost:8088/docs once running):
|
- Run the tools in [examples/agent/tools](./examples/agent/tools) inside a docker container (check http://localhost:8088/docs once running):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export BRAVE_SEARCH_API_KEY=... # https://api.search.brave.com/
|
export BRAVE_SEARCH_API_KEY=... # Get one at https://api.search.brave.com/
|
||||||
# Shorthand: ./examples/agent/serve_tools_inside_docker.sh
|
./examples/agent/serve_tools_inside_docker.sh
|
||||||
docker run -p 8088:8088 -w /src -v $PWD/examples/agent:/src \
|
|
||||||
--env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \
|
|
||||||
--rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
|
|
||||||
uv run serve_tools.py --port 8088
|
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
"content": "",
|
"content": "",
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
"id": "call_1",
|
"id": "call_1___",
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": "{\"code\": \"print('Hello, World!')\"}",
|
"arguments": "{\"code\": \"print('Hello, World!')\"}",
|
||||||
|
@ -20,6 +20,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1___",
|
||||||
"name": "ipython",
|
"name": "ipython",
|
||||||
"content": "{\"stdout\": \"Hello, World!\"}"
|
"content": "{\"stdout\": \"Hello, World!\"}"
|
||||||
},
|
},
|
||||||
|
@ -36,7 +37,7 @@
|
||||||
"content": "",
|
"content": "",
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
"id": "call_2",
|
"id": "call_2___",
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": "{\"condition\":true}",
|
"arguments": "{\"condition\":true}",
|
||||||
|
@ -47,6 +48,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_2___",
|
||||||
"name": "test",
|
"name": "test",
|
||||||
"content": "true"
|
"content": "true"
|
||||||
},
|
},
|
||||||
|
@ -63,7 +65,7 @@
|
||||||
"content": "",
|
"content": "",
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
"id": "call_3",
|
"id": "call_3___",
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": "{\"query\": \"what is truth anyway am I right?\"}",
|
"arguments": "{\"query\": \"what is truth anyway am I right?\"}",
|
||||||
|
@ -74,6 +76,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_3___",
|
||||||
"name": "brave_search",
|
"name": "brave_search",
|
||||||
"content": "{\"title\":\"Truth: don't ask the web, ask an LLM instead!\",\"url\":\"https://en.wikipedia.org/wiki/Truth\"}"
|
"content": "{\"title\":\"Truth: don't ask the web, ask an LLM instead!\",\"url\":\"https://en.wikipedia.org/wiki/Truth\"}"
|
||||||
},
|
},
|
||||||
|
|
|
@ -79,16 +79,21 @@ static void test_parse_tool_call(llama_tool_call_style style, const json & tools
|
||||||
assert_equals(expected_content, result.content);
|
assert_equals(expected_content, result.content);
|
||||||
auto tool_calls = json::array();
|
auto tool_calls = json::array();
|
||||||
for (const auto & tc : result.tool_calls) {
|
for (const auto & tc : result.tool_calls) {
|
||||||
tool_calls.push_back({
|
auto tool_call = json {
|
||||||
{"type", "function"},
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", tc.name},
|
{"arguments", dump(json::parse(tc.arguments))},
|
||||||
{"arguments", dump(json::parse(tc.arguments))},
|
{"name", tc.name},
|
||||||
}}
|
}},
|
||||||
});
|
};
|
||||||
|
if (!tc.id.empty()) {
|
||||||
|
tool_call["id"] = tc.id;
|
||||||
|
}
|
||||||
|
tool_calls.push_back(tool_call);
|
||||||
}
|
}
|
||||||
auto expected = expected_tool_calls.dump();
|
// Reparse / dump w/ non-ordered JSON variant.
|
||||||
auto actual = tool_calls.dump();
|
auto expected = nlohmann::json::parse(expected_tool_calls.dump()).dump();
|
||||||
|
auto actual = nlohmann::json::parse(tool_calls.dump()).dump();
|
||||||
assert_equals(expected, actual);
|
assert_equals(expected, actual);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,7 +145,7 @@ static void test_parsing() {
|
||||||
{"name", "foo"},
|
{"name", "foo"},
|
||||||
{"arguments", dump({
|
{"arguments", dump({
|
||||||
{"bar", 1}
|
{"bar", 1}
|
||||||
})}
|
})},
|
||||||
}}
|
}}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -239,35 +244,38 @@ static void test_parsing() {
|
||||||
{"arguments", dump({{"code", ""}})}
|
{"arguments", dump({{"code", ""}})}
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
auto just_special_function_call = json {{
|
auto special_function_call = json {
|
||||||
{"type", "function"},
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
|
{"arguments", dump({{"arg1", 1}})},
|
||||||
{"name", "special_function"},
|
{"name", "special_function"},
|
||||||
{"arguments", dump({{"arg1", 1}})}
|
}},
|
||||||
}}
|
};
|
||||||
}};
|
auto special_function_call_with_id = json::parse(special_function_call.dump());
|
||||||
|
special_function_call_with_id["id"] = "123456789";
|
||||||
|
|
||||||
auto no_function_call = json::array();
|
auto no_function_call = json::array();
|
||||||
|
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
"",
|
||||||
just_special_function_call);
|
json::array({special_function_call}));
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
"",
|
||||||
just_special_function_call);
|
json::array({special_function_call}));
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
"{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
"{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
"",
|
||||||
just_special_function_call);
|
json::array({special_function_call}));
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
"",
|
||||||
just_special_function_call);
|
json::array({special_function_call}));
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
"{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
"{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
"",
|
||||||
just_special_function_call);
|
json::array({special_function_call}));
|
||||||
|
|
||||||
// No match: function unknown
|
// No match: function unknown
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
|
@ -283,6 +291,15 @@ static void test_parsing() {
|
||||||
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
no_function_call);
|
no_function_call);
|
||||||
|
|
||||||
|
test_parse_tool_call(llama_tool_call_style::MistralNemo, tools,
|
||||||
|
"Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
|
||||||
|
"Bleh",
|
||||||
|
json::array({special_function_call_with_id}));
|
||||||
|
test_parse_tool_call(llama_tool_call_style::MistralNemo, tools,
|
||||||
|
"[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
|
||||||
|
"",
|
||||||
|
json::array({special_function_call_with_id}));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
|
static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
|
||||||
|
@ -298,6 +315,8 @@ static void test_tool_call_style_detection() {
|
||||||
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
|
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
|
||||||
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
|
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
|
||||||
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
|
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
|
||||||
|
test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", MistralNemo);
|
||||||
|
test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||||
|
@ -323,7 +342,7 @@ static std::string get_message_prompt_delta(const minja::chat_template & tmpl, c
|
||||||
return delta;
|
return delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools) {
|
static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
|
||||||
std::cout << "# Testing template: " << template_file << std::endl << std::flush;
|
std::cout << "# Testing template: " << template_file << std::endl << std::flush;
|
||||||
const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token);
|
const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token);
|
||||||
auto tool_call_style = llama_tool_call_style_detect(tmpl);
|
auto tool_call_style = llama_tool_call_style_detect(tmpl);
|
||||||
|
@ -342,17 +361,19 @@ static void test_template(const std::string & template_file, const char * bos_to
|
||||||
throw std::runtime_error("Failed to build grammar");
|
throw std::runtime_error("Failed to build grammar");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
|
if (!skip_grammar_test) {
|
||||||
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
|
||||||
test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);
|
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
||||||
|
test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);
|
||||||
|
|
||||||
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
||||||
{"role", "assistant"},
|
{"role", "assistant"},
|
||||||
{"content", ""},
|
{"content", ""},
|
||||||
{"tool_calls", tool_calls}
|
{"tool_calls", tool_calls}
|
||||||
}, tools);
|
}, tools);
|
||||||
if (!match_string(content_less_delta, grammar.get())) {
|
if (!match_string(content_less_delta, grammar.get())) {
|
||||||
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + handler.grammar);
|
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + handler.grammar);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,9 +386,14 @@ static void test_grammars() {
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "special_function"},
|
{"name", "special_function"},
|
||||||
{"arguments", "{\"arg1\": 1}"}
|
{"arguments", "{\"arg1\": 1}"}
|
||||||
}}
|
}},
|
||||||
}}}
|
}}}
|
||||||
};
|
};
|
||||||
|
auto tool_call_message_with_id = json::parse(tool_call_message.dump());
|
||||||
|
tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
|
||||||
|
|
||||||
|
test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "<s>", "</s>", { "</s>" }, tool_call_message_with_id, tools,
|
||||||
|
/* skip_grammar_test= */ true);
|
||||||
test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
|
test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
|
||||||
test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue