diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6c7fcd176..54e8448a4 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -45,6 +45,12 @@ struct server_params
     int32_t write_timeout = 600;
 };
 
+struct chat_formatter
+{
+    std::function<std::string(std::vector<json>)> conversion;
+    std::vector<std::string> stop_keys;
+};
+
 static bool server_verbose = false;
 
 #if SERVER_VERBOSE != 1
@@ -64,7 +70,7 @@ static bool server_verbose = false;
 #define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
 #define LOG_INFO(   MSG, ...) server_log("INFO",    __func__, __LINE__, MSG, __VA_ARGS__)
 
-json oaicompat_completion_params_parse(const json &body);
+json oaicompat_completion_params_parse(const json &body, const chat_formatter formatter);
 
 std::string format_chatml(std::vector<json> messages);
 
@@ -524,6 +530,7 @@ struct llama_server_context
     clip_ctx *clp_ctx = nullptr;
 
     gpt_params params;
+    json formatter_params = nullptr;
 
     llama_batch batch;
 
@@ -2376,6 +2383,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             );
             llama.process_system_prompt_data(json::parse(systm_content));
         }
+        else if (arg == "-cfmt" || arg == "--chat-format-file"){
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            std::ifstream file(argv[i]);
+            if (!file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            llama.formatter_params = json::parse(file);
+            fprintf(stdout, "Loaded chat template '%s'\n", llama.formatter_params.dump().c_str());
+        }
         else if(arg == "--mmproj")
         {
             if (++i >= argc)
@@ -2477,6 +2499,25 @@ static std::string gen_chatcmplid()
     return chatcmplid.str();
 }
 
+std::string format_chat(std::vector<json> messages, json &formatter)
+{
+    std::ostringstream prompt;
+
+    for (auto it = messages.begin(); it != messages.end(); ++it) {
+        auto role = json_value(*it, "role", std::string("unknown"));
+        LOG_TEE("\n Message received: %s\nWith role: %s", (*it).dump().c_str(), role.c_str());
+        auto prefix = formatter.at(role).value("prefix", "");
+        auto suffix = formatter.at(role).value("suffix", "");
+        prompt << prefix
+               << json_value(*it, "content", std::string(""))
+               << suffix
+               << json_value(formatter, "delimiter", std::string(" "));
+    }
+    prompt << formatter.at("assistant").value("prefix", "");
+
+    return prompt.str();
+}
+
 std::string format_chatml(std::vector<json> messages)
 {
     std::ostringstream chatml_msgs;
@@ -2495,7 +2536,9 @@ std::string format_chatml(std::vector<json> messages)
 
 /* llama.cpp completion api semantics */
 json oaicompat_completion_params_parse(
-    const json &body /* openai api json semantics */)
+    const json &body, /* openai api json semantics */
+    const chat_formatter formatter
+)
 {
     json llama_params;
 
@@ -2510,7 +2553,8 @@ json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("uknown"));
-    llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    // llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    llama_params["prompt"] = formatter.conversion(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
     llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
@@ -2542,7 +2586,9 @@ json oaicompat_completion_params_parse(
     }
 
     // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    for(auto& stop : formatter.stop_keys){
+        llama_params["stop"].push_back(stop);
+    }
 
     return llama_params;
 }
@@ -2964,7 +3010,45 @@ int main(int argc, char **argv)
                 if (!validate_api_key(req, res)) {
                     return;
                 }
-                json data = oaicompat_completion_params_parse(json::parse(req.body));
+
+                auto formatter_json = llama.formatter_params;
+                chat_formatter formatter;
+                if (formatter_json == nullptr){
+                    formatter.conversion = [](std::vector<json> messages){return format_chatml(messages);};
+                    formatter.stop_keys = std::vector<std::string>(1, "<|im_end|>");
+                }
+                else
+                {
+                    formatter.conversion = [&formatter_json](std::vector<json> messages){return format_chat(messages, formatter_json);};
+                    // walk the formatter JSON and derive stop words from the per-role suffixes
+                    auto items = formatter_json.items();
+                    formatter.stop_keys = std::vector<std::string>();
+                    for(auto it = items.begin(); it != items.end(); ++it){
+                        if((*it).key() == "delimiter"){
+                            continue;
+                        }
+                        // keep every non-empty suffix as a stop token, skipping duplicates
+                        auto stop_token = (*it).value().value("suffix", "");
+                        if(stop_token != "")
+                        {
+                            if(std::find(formatter.stop_keys.begin(), formatter.stop_keys.end(), stop_token) == formatter.stop_keys.end())
+                            {
+                                formatter.stop_keys.push_back(stop_token);
+                            }
+                        }
+                    }
+                    if (!formatter.stop_keys.empty()) {
+                        // human-readable, comma-separated list of the collected stop keys
+                        auto r = std::accumulate(
+                            formatter.stop_keys.begin() + 1,
+                            formatter.stop_keys.end(),
+                            formatter.stop_keys[0],
+                            [](const std::string a, const std::string b){ return a + ", " + b; }
+                        );
+                    }
+                }
+
+                json data = oaicompat_completion_params_parse(json::parse(req.body), formatter);
 
                 const int task_id = llama.request_completion(data, false, false, -1);
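
The patch never shows the JSON document that -cfmt / --chat-format-file is expected to point at. Judging from how format_chat and the stop-key collection read it (one object per role carrying "prefix" and "suffix", plus an optional top-level "delimiter"), a file that roughly reproduces the built-in ChatML behaviour could look like the sketch below; the role set and the exact strings are illustrative, not something the patch prescribes:

    {
        "delimiter": "\n",
        "system":    { "prefix": "<|im_start|>system\n",    "suffix": "<|im_end|>" },
        "user":      { "prefix": "<|im_start|>user\n",      "suffix": "<|im_end|>" },
        "assistant": { "prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>" }
    }

Loaded with something like ./server -m model.gguf --chat-format-file chatml.json (both file names are placeholders), the stop-key loop would collect a single "<|im_end|>" entry, since duplicate suffixes are skipped, which matches the stop word the old hard-coded ChatML path pushed.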
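
For a request whose messages consist of one system line followed by a single user turn, format_chat driven by that file would render the prompt as (again purely illustrative):

    <|im_start|>system
    You are a helpful assistant.<|im_end|>
    <|im_start|>user
    Hello!<|im_end|>
    <|im_start|>assistant

that is, prefix + content + suffix + delimiter for every message, closed by the assistant prefix so the model continues in the assistant role.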