diff --git a/common/chat-template.hpp b/common/chat-template.hpp
index 58e119a3b..190095073 100644
--- a/common/chat-template.hpp
+++ b/common/chat-template.hpp
@@ -41,6 +41,7 @@ class chat_template {
     std::string bos_token_;
     std::string eos_token_;
     std::shared_ptr template_root_;
+    std::string tool_call_example_;
 
     std::string try_raw_render(
         const nlohmann::ordered_json & messages,
@@ -176,6 +177,52 @@ class chat_template {
             caps_.supports_tool_responses = contains(out, "Some response!");
             caps_.supports_tool_call_id = contains(out, "call_911_");
         }
+
+        // When caps_.supports_tools is false, the tool list is injected into a system message at
+        // render time. Render a sample exchange once to capture how this template prints a tool
+        // call, so that message can include a concrete example (stored in tool_call_example_).
+        if (!caps_.supports_tools) {
+            const json user_msg {
+                {"role", "user"},
+                {"content", "Hey"},
+            };
+            const json tool_call_msg {
+                {"role", "assistant"},
+                {"content", nullptr},
+                {"tool_calls", json::array({
+                    {
+                        // TODO: detect if requires numerical id or fixed length == 6 like Nemo
+                        {"id", "call_1___"},
+                        {"type", "function"},
+                        {"function", {
+                            {"name", "tool_name"},
+                            {"arguments", (json {
+                                {"arg1", "some_value"},
+                            }).dump()},
+                        }},
+                    },
+                })},
+            };
+            const json tools; // left null: the probe renders are made without any tools
+            auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true);
+            const auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false);
+            // The example is the part of `full` that follows `prefix`. When `prefix` does not open
+            // `full`, retry once with a trailing eos_token_ dropped from `prefix`.
+            if (full.find(prefix) != 0) {
+                const auto eos_len = eos_token_.size();
+                // Bounds-check first: `prefix.size() - eos_len` is unsigned and would wrap when
+                // eos_token_ is longer than prefix; an empty eos_token_ gives nothing to trim.
+                if (eos_len > 0 && prefix.size() >= eos_len &&
+                        prefix.compare(prefix.size() - eos_len, eos_len, eos_token_) == 0) {
+                    prefix.resize(prefix.size() - eos_len);
+                }
+                // Check again after trimming so substr() below only runs on a verified prefix.
+                if (full.find(prefix) != 0) {
+                    throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full);
+                }
+            }
+            tool_call_example_ = full.substr(prefix.size());
+        }
     }
 
     const std::string & source() const { return source_; }
@@ -229,7 +276,19 @@ class chat_template {
         };
         auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
 
-        for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
+        // If tools were supplied but caps_.supports_tools is false, describe them through
+        // add_system() together with the stored tool_call_example_; otherwise use messages as is.
+        json adjusted_messages;
+        if (needs_tools_in_system) {
+            adjusted_messages = add_system(messages,
+                "\n\n"
+                "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n"
+                "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n");
+        } else {
+            adjusted_messages = messages;
+        }
+
+        for (const auto & message_ : adjusted_messages) {
             auto message = message_;
             if (!message.contains("role") || !message.contains("content")) {
                 throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());