diff --git a/examples/server/functionary-test.cpp b/examples/server/functionary-test.cpp
index 15763fcc1..15100860a 100644
--- a/examples/server/functionary-test.cpp
+++ b/examples/server/functionary-test.cpp
@@ -93,11 +93,7 @@ std::string test_oai_input_json = R"(
 
 )";
 
-std::string test_response = R"(<|from|>assistant
-<|recipient|>all
-<|content|>I will get the price of 2 cars and compare
-<|from|>assistant
-<|recipient|>get_car_price
+std::string test_response = R"(get_car_price
 <|content|>{"car_name": "Song"}
 <|from|>assistant
 <|recipient|>get_car_price
@@ -105,7 +101,7 @@ std::string test_response = R"(<|from|>assistant
 
 int main() {
     auto test_oai_input = json::parse(test_oai_input_json);
-    auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input);
+    auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input, true);
     std::cout << "\n" << prompt << "\n";
 
     std::cout << "\n" << llama_functionary::convert_response_to_oai_choices(test_response) << "\n";
diff --git a/examples/server/functionary.hpp b/examples/server/functionary.hpp
index 338d8b1f9..95e771d40 100644
--- a/examples/server/functionary.hpp
+++ b/examples/server/functionary.hpp
@@ -9,6 +9,7 @@ using json = nlohmann::json;
 
 /**
  * Integration with functionary model: https://github.com/MeetKai/functionary
+ * Based on my research: https://github.com/ggerganov/llama.cpp/issues/5588
  *
  * A typical flow is:
  * - Step 1: user send request to model
@@ -201,7 +202,7 @@ inline std::string serialize_function(function_def & fn) {
 ///////////////////////////////////////////
 // Main hooks, to be called in oai.hpp
 
-inline std::string convert_oai_to_prompt(const json & body) {
+inline std::string convert_oai_to_prompt(const json & body, bool add_ass, bool allow_tool = true) {
     std::stringstream ss;
     // convert function definitions
     std::vector<json> tools = json_value(body, "tools", json::array());
@@ -224,30 +225,48 @@ inline std::string convert_oai_to_prompt(const json & body) {
     // convert history
     std::vector<json> messages = json_value(body, "messages", json::array());
     for (auto & msg_json : messages) {
+        // TODO: how to detect where to put "<|stop|>"?
         if (msg_json.count("tool_calls")) {
             // assistant request to function call, now re-passed to history
             std::vector<json> tool_calls = msg_json["tool_calls"];
-            for (auto & tc : tool_calls) {
+            for (size_t i = 0; i < tool_calls.size(); i++) {
+                auto & tc = tool_calls[i];
                 message msg;
                 msg.from = tc["function"]["name"];
                 msg.content = tc["function"]["arguments"];
+                msg.has_stop = i == tool_calls.size() - 1; // only the last tool call carries the stop token
                 ss << msg.to_prompt();
             }
         } else {
             // all other types of message
             message msg(msg_json);
+            msg.has_stop = msg.from == "assistant"; // add stop if this is a plain text message from the assistant (no tool_calls)
             ss << msg.to_prompt();
         }
     }
+    // add trailing assistant prompt
+    if (add_ass) {
+        ss << "<|from|>assistant\n<|recipient|>";
+        if (!allow_tool) {
+            ss << FUNCTIONARY_RECIP_NONE;
+        }
+    }
     return ss.str();
 }
 
-// be careful, the assistant output does not have "<|from|>assistant", you need to add it yourself!
 inline json convert_response_to_oai_choices(const std::string & content) {
+    std::string input_full = content;
     std::string text_response;
     json tool_calls = json::array();
     // parse all turns
-    std::vector<std::string> turns = str_split(content, "<|from|>");
+    std::vector<std::string> turns = str_split(input_full, "<|from|>");
+    if (!turns.empty()) {
+        // the first turn may not have the assistant tag (it was already part of the prompt added by "add_ass"); put it back so the message can be parsed
+        // the "<|from|>" prefix is added back in the loop below
+        if (turns[0].find("<|recipient|>") == std::string::npos) {
+            turns[0] = "assistant\n<|recipient|>" + turns[0];
+        }
+    }
     for (auto & turn : turns) {
         std::string turn_full = "<|from|>" + turn;
         message msg(turn_full);
diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index 1a2792702..0cc06c2f1 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -35,7 +35,7 @@ inline static json oaicompat_completion_params_parse(
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
     llama_params["prompt"] = enable_tool_calls
-        ? llama_functionary::convert_oai_to_prompt(body)
+        ? llama_functionary::convert_oai_to_prompt(body, true)
         : format_chat(model, chat_template, body["messages"]);
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
@@ -67,8 +67,10 @@ inline static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    llama_params["stop"].push_back(enable_tool_calls
+        ? "<|stop|>"   // functionary-specific: this model uses "<|stop|>" instead of "</s>"
+        : "<|im_end|>" // Ensure there is a ChatML-specific end sequence among the stop words
+    );
 
     return llama_params;
 }
@@ -104,7 +106,7 @@ inline static json format_final_response_oaicompat(
         : json::array({json{{"finish_reason", finish_reason},
                             {"index", 0},
                             {"message", json{{"content", content},
-                                        {"role", "assistant"}}}}});
+                                        {"role", "assistant"}}}}});
     }
 
     std::time_t t = std::time(0);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6626b5fdd..c7d9a8ef2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2773,15 +2773,16 @@ int main(int argc, char **argv)
         LOG_INFO("model loaded", {});
     }
 
-    if (sparams.chat_template.empty()) { // custom chat template is not supplied
-        // check if the template comes with the model is supported by us
-        llama.validate_model_chat_template(sparams);
-    }
-
     // Check tool_call ability
     sparams.enable_tool_calls = check_model_support_tool_calls(llama.model);
    if (sparams.enable_tool_calls) {
-        LOG_VERBOSE("Current model supports functionary tool_calls", {});
+        LOG_INFO("Current model supports functionary tool_calls", {});
+    }
+
+    // custom chat template is not supplied
+    if (sparams.chat_template.empty() && !sparams.enable_tool_calls) {
+        // check if the template comes with the model is supported by us
+        llama.validate_model_chat_template(sparams);
     }
 
     // Middleware for API key validation
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index ef59fe63a..2c0bb18b6 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -219,6 +219,7 @@ inline bool check_model_support_tool_calls(const struct llama_model * model) {
     if (res < 0) {
         return false; // no template in model
     } else {
+        model_template.resize(res);
         std::string tmpl(model_template.data(), model_template.size());
         return tmpl.find("<|recipient|>") != std::string::npos;
     }
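
Note on the prompt layout: the sketch below (standalone, not taken from the patch) prints the shape of the prompt that convert_oai_to_prompt(body, /*add_ass=*/true) is expected to build for one user turn, one assistant tool call, and the tool result. It assumes only the control tokens visible in this patch (<|from|>, <|recipient|>, <|content|>, <|stop|>); the question, the arguments, the tool result, and the exact spot where message::to_prompt() places "<|stop|>" are illustrative guesses rather than the patch's actual output.

    // Standalone sketch of the functionary prompt layout; conversation content is invented.
    #include <iostream>
    #include <sstream>
    #include <string>

    int main() {
        std::stringstream ss;
        // user turn (plain text messages use recipient "all")
        ss << "<|from|>user\n<|recipient|>all\n<|content|>How much is a Song?\n";
        // assistant tool call re-passed as history; the last call carries the stop token
        ss << "<|from|>assistant\n<|recipient|>get_car_price\n"
              "<|content|>{\"car_name\": \"Song\"}<|stop|>\n";
        // tool result fed back to the model
        ss << "<|from|>get_car_price\n<|recipient|>all\n<|content|>{\"price\": 25000}\n";
        // trailing assistant prompt appended when add_ass == true; with allow_tool == false,
        // FUNCTIONARY_RECIP_NONE would be appended here to force a plain text reply
        ss << "<|from|>assistant\n<|recipient|>";
        std::cout << ss.str() << std::endl;
        return 0;
    }

Because generation continues from that trailing "<|from|>assistant\n<|recipient|>", the model ends its answer with "<|stop|>", which is why the stop-word change in oai.hpp registers "<|stop|>" instead of "<|im_end|>" when tool calls are enabled.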
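
Note on response parsing: the raw completion does not begin with "<|from|>assistant\n<|recipient|>" because that prefix was already emitted as part of the prompt by add_ass, which is what the first-turn fix-up in convert_response_to_oai_choices() compensates for. The standalone illustration below shows that step in isolation; str_split here is a local stand-in for the helper used in functionary.hpp, and the second call's arguments are made up (the hunk in functionary-test.cpp only shows the first "Song" call).

    // Standalone illustration of restoring the first turn's missing prefix
    // before splitting a raw functionary completion into turns.
    #include <iostream>
    #include <string>
    #include <vector>

    // local stand-in for the str_split() helper used in functionary.hpp
    static std::vector<std::string> str_split(const std::string & s, const std::string & delim) {
        std::vector<std::string> out;
        std::string::size_type start = 0, pos;
        while ((pos = s.find(delim, start)) != std::string::npos) {
            out.push_back(s.substr(start, pos - start));
            start = pos + delim.size();
        }
        out.push_back(s.substr(start));
        return out;
    }

    int main() {
        // raw output: the first turn starts directly with the recipient name
        std::string content =
            "get_car_price\n<|content|>{\"car_name\": \"Song\"}\n"
            "<|from|>assistant\n<|recipient|>get_car_price\n"
            "<|content|>{\"car_name\": \"Tang\"}";
        std::vector<std::string> turns = str_split(content, "<|from|>");
        // same fix-up as in convert_response_to_oai_choices()
        if (!turns.empty() && turns[0].find("<|recipient|>") == std::string::npos) {
            turns[0] = "assistant\n<|recipient|>" + turns[0];
        }
        for (auto & turn : turns) {
            std::cout << "turn:\n<|from|>" << turn << "\n\n";
        }
        return 0;
    }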