From 33b8159ef0c475e241fa069000f05628a8eb79ac Mon Sep 17 00:00:00 2001 From: Peter Nagymathe Date: Fri, 29 Dec 2023 12:13:42 +0000 Subject: [PATCH 1/3] preliminary method for general chat format handling in server --- examples/server/server.cpp | 93 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 4 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c5035e202..2f0a85f82 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -44,6 +44,12 @@ struct server_params int32_t write_timeout = 600; }; +struct chat_formatter +{ + std::function<std::string(std::vector<json>)> conversion; + std::vector<std::string> stop_keys; +}; + static bool server_verbose = false; #if SERVER_VERBOSE != 1 @@ -63,7 +69,7 @@ static bool server_verbose = false; #define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) -json oaicompat_completion_params_parse(const json &body); +json oaicompat_completion_params_parse(const json &body, const chat_formatter formatter); std::string format_chatml(std::vector<json> messages); @@ -514,6 +520,7 @@ struct llama_server_context clip_ctx *clp_ctx = nullptr; gpt_params params; + json formatter_params = nullptr; llama_batch batch; @@ -2358,6 +2365,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, ); llama.process_system_prompt_data(json::parse(systm_content)); } + else if (arg == "-cfmt" || arg == "--chat-format-file"){ + if (++i >= argc) + { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + llama.formatter_params = json::parse(file); + fprintf(stdout, "Loaded chat format '%s'\n", llama.formatter_params.dump().c_str()); + } else if(arg == "--mmproj") { if (++i >= argc) @@ -2412,6 +2434,25 @@ static std::string gen_chatcmplid() return 
chatcmplid.str(); } +std::string format_chat(std::vector<json> messages, json &formatter) +{ + std::ostringstream prompt; + + for (auto it = messages.begin(); it != messages.end(); ++it) { + auto role = json_value(*it, "role", std::string("unknown")); + LOG_TEE("\n Message received: %s\nWith role: %s", (*it).dump().c_str(), role.c_str()); + auto prefix = formatter.at(role).value("prefix", ""); + auto suffix = formatter.at(role).value("suffix", ""); + prompt << prefix + << json_value(*it, "content", std::string("")) + << suffix + << json_value(formatter, "delimiter", std::string(" ")); + } + prompt << formatter.at("assistant").value("prefix", ""); + + return prompt.str(); +} + std::string format_chatml(std::vector<json> messages) { std::ostringstream chatml_msgs; @@ -2430,7 +2471,9 @@ std::string format_chatml(std::vector<json> messages) /* llama.cpp completion api semantics */ json oaicompat_completion_params_parse( - const json &body /* openai api json semantics */) + const json &body, /* openai api json semantics */ + const chat_formatter formatter +) { json llama_params; @@ -2445,7 +2488,8 @@ json oaicompat_completion_params_parse( // https://platform.openai.com/docs/api-reference/chat/create llama_sampling_params default_sparams; llama_params["model"] = json_value(body, "model", std::string("uknown")); - llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' + // llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' + llama_params["prompt"] = formatter.conversion(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); llama_params["temperature"] = json_value(body, "temperature", 0.0); llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); @@ -2899,7 +2943,48 @@ int main(int argc, char **argv) if (!validate_api_key(req, res)) { return; } - json data = 
oaicompat_completion_params_parse(json::parse(req.body)); + + auto formatter_json = llama.formatter_params; + chat_formatter formatter; + if (formatter_json == nullptr){ + formatter.conversion = [](std::vector<json> messages){return format_chatml(messages);}; + formatter.stop_keys = std::vector<std::string>(1, "<|im_end|>"); + } + else + { + formatter.conversion = [&formatter_json](std::vector<json> messages){return format_chat(messages, formatter_json);}; + // need json parsing the keys + // fprintf(stdout, "\n attempt to get json keys"); + auto items = formatter_json.items(); + formatter.stop_keys = std::vector<std::string>(2); + for(auto it = items.begin(); it != items.end(); ++it){ + if((*it).key() == "delimiter"){ + continue; + // fprintf(stdout, "\n skipped delimiter"); + } + // We try to get stop token and store non-duplicate ones + auto stop_token = (*it).value().value("suffix", ""); + if(stop_token != "") + { + // fprintf(stdout, "\n try add stop token %s", stop_token.c_str()); + if(std::find(formatter.stop_keys.begin(), formatter.stop_keys.end(), stop_token) == formatter.stop_keys.end()) + { + // fprintf(stdout, "\n add stop token %s", stop_token.c_str()); + formatter.stop_keys.push_back(stop_token); + } + } + } + auto r = std::accumulate( + formatter.stop_keys.begin() + 1, + formatter.stop_keys.end(), + formatter.stop_keys[0], + [](const std::string a, const std::string b){ return a + ", " + b; } + ); + fprintf(stdout, "\n Stop Keys: %s", r.c_str()); + } + + json data = oaicompat_completion_params_parse(json::parse(req.body), formatter); + // json data = oaicompat_completion_params_parse(json::parse(req.body)); const int task_id = llama.request_completion(data, false, false, -1); From fe6630c23e1415d6840698444986c6817d714ada Mon Sep 17 00:00:00 2001 From: Peter Nagymathe Date: Sat, 30 Dec 2023 02:43:24 +0000 Subject: [PATCH 2/3] remove prints --- examples/server/server.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/examples/server/server.cpp 
b/examples/server/server.cpp index 2f0a85f82..cebb2842d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2378,7 +2378,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } llama.formatter_params = json::parse(file); - fprintf(stdout, "Loaded chat format '%s'\n", llama.formatter_params.dump().c_str()); + fprintf(stdout, "Loaded chat template '%s'\n", llama.formatter_params.dump().c_str()); } else if(arg == "--mmproj") { @@ -2954,22 +2954,18 @@ int main(int argc, char **argv) { formatter.conversion = [&formatter_json](std::vector<json> messages){return format_chat(messages, formatter_json);}; // need json parsing the keys - // fprintf(stdout, "\n attempt to get json keys"); auto items = formatter_json.items(); formatter.stop_keys = std::vector<std::string>(2); for(auto it = items.begin(); it != items.end(); ++it){ if((*it).key() == "delimiter"){ continue; - // fprintf(stdout, "\n skipped delimiter"); } // We try to get stop token and store non-duplicate ones auto stop_token = (*it).value().value("suffix", ""); if(stop_token != "") { - // fprintf(stdout, "\n try add stop token %s", stop_token.c_str()); if(std::find(formatter.stop_keys.begin(), formatter.stop_keys.end(), stop_token) == formatter.stop_keys.end()) { - // fprintf(stdout, "\n add stop token %s", stop_token.c_str()); formatter.stop_keys.push_back(stop_token); } } @@ -2980,11 +2976,9 @@ int main(int argc, char **argv) formatter.stop_keys[0], [](const std::string a, const std::string b){ return a + ", " + b; } ); - fprintf(stdout, "\n Stop Keys: %s", r.c_str()); } json data = oaicompat_completion_params_parse(json::parse(req.body), formatter); - // json data = oaicompat_completion_params_parse(json::parse(req.body)); const int task_id = llama.request_completion(data, false, false, -1); From 9c3e7e0c771aca641fc21843d6af54e995f2436a Mon Sep 17 00:00:00 2001 From: Peter Nagymathe Date: Sat, 30 Dec 2023 02:57:23 +0000 Subject: [PATCH 3/3] apply custom stop tokens 
--- examples/server/server.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index cebb2842d..fb6c2d1bd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2521,7 +2521,9 @@ json oaicompat_completion_params_parse( } // Ensure there is ChatML-specific end sequence among stop words - llama_params["stop"].push_back("<|im_end|>"); + for(auto& stop : formatter.stop_keys){ + llama_params["stop"].push_back(stop); + } return llama_params; }