tool-call: Log tool call style name, ensure returned content not null

This commit is contained in:
ochafik 2024-10-22 23:41:47 +01:00
parent a4f12a4594
commit fc80ad20ce
4 changed files with 25 additions and 3 deletions

View file

@@ -12,6 +12,27 @@
using json = nlohmann::ordered_json;
// Maps a tool-call style enum to its human-readable display name.
// Any style without a dedicated label falls back to "Unknown".
std::string llama_tool_call_style_name(llama_tool_call_style style) {
    const char * name = "Unknown";
    switch (style) {
        case llama_tool_call_style::Generic:              name = "Generic";               break;
        case llama_tool_call_style::Llama31:              name = "Llama-3.1";             break;
        case llama_tool_call_style::Llama32:              name = "Llama-3.2";             break;
        case llama_tool_call_style::FunctionaryV3Llama3:  name = "FunctionaryV3Llama3";   break;
        case llama_tool_call_style::FunctionaryV3Llama31: name = "FunctionaryV3Llama3.1"; break;
        case llama_tool_call_style::Hermes2Pro:           name = "Hermes2Pro";            break;
        case llama_tool_call_style::CommandRPlus:         name = "CommandRPlus";          break;
        default: break; // unknown/future styles keep the fallback label
    }
    return name;
}
llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template) {
const auto & src = chat_template.source();

View file

@@ -35,6 +35,8 @@ struct llama_tool_call_handler {
std::vector<std::string> additional_stop_words;
};
std::string llama_tool_call_style_name(llama_tool_call_style style);
llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template);
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);

View file

@@ -3031,6 +3031,7 @@ int main(int argc, char ** argv) {
static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
static auto tool_call_style = llama_tool_call_style_detect(chat_template);
LOG_INF("Tool call style: %s\n", llama_tool_call_style_name(tool_call_style).c_str());
json data;
try {

View file

@@ -468,9 +468,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content);
if (!parsed_tool_calls.tool_calls.empty()) {
finish_reason = "tool_calls";
if (!parsed_tool_calls.content.empty()) { message_content = parsed_tool_calls.content;
message_content = parsed_tool_calls.content;
}
tool_calls = json::array();
for (const auto & tc : parsed_tool_calls.tool_calls) {
tool_calls.push_back({