From 61b63705dcf39997f2c934ce59b8bde5d1adfbd6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 5 Mar 2024 19:55:19 +0200
Subject: [PATCH] server : refactor system prompt update at start

---
 examples/server/server.cpp | 38 +++++++++++++++++---------------------
 examples/server/utils.hpp  | 11 +++++++++++
 2 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a092ee61c..77b5321e9 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -115,6 +115,7 @@ struct server_params {
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "examples/server/public";
     std::string chat_template = "";
+    std::string system_prompt = "";
 
     std::vector<std::string> api_keys;
 
@@ -1024,7 +1025,7 @@ struct llama_server_context {
         system_need_update = false;
     }
 
-    void system_prompt_process(const json & sys_props) {
+    void system_prompt_set(const json & sys_props) {
         system_prompt  = sys_props.value("prompt", "");
         name_user      = sys_props.value("anti_prompt", "");
         name_assistant = sys_props.value("assistant_name", "");
@@ -1418,7 +1419,7 @@ struct llama_server_context {
         }
 
         if (task.data.contains("system_prompt")) {
-            system_prompt_process(task.data["system_prompt"]);
+            system_prompt_set(task.data["system_prompt"]);
 
             // reset cache_tokens for all slots
             for (server_slot & slot : slots) {
@@ -1974,7 +1975,10 @@ struct llama_server_context {
             }
 
             for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
-                result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
+                result.probs.push_back({
+                    cur_p.data[i].id,
+                    cur_p.data[i].p
+                });
             }
 
             if (!process_token(result, slot)) {
@@ -2088,7 +2092,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("\n");
 }
 
-static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params, llama_server_context & llama) {
+static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params) {
     gpt_params default_params;
     server_params default_sparams;
 
@@ -2402,13 +2406,13 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 invalid_param = true;
                 break;
             }
-            std::string systm_content;
+            std::string system_prompt;
             std::copy(
                 std::istreambuf_iterator<char>(file),
                 std::istreambuf_iterator<char>(),
-                std::back_inserter(systm_content)
+                std::back_inserter(system_prompt)
             );
-            llama.system_prompt_process(json::parse(systm_content));
+            sparams.system_prompt = system_prompt;
         } else if (arg == "-ctk" || arg == "--cache-type-k") {
             params.cache_type_k = argv[++i];
         } else if (arg == "-ctv" || arg == "--cache-type-v") {
@@ -2506,18 +2510,6 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
     }
 }
 
-static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
-    return json {
-        {"tokens", tokens}
-    };
-}
-
-static json format_detokenized_response(const std::string & content) {
-    return json {
-        {"content", content}
-    };
-}
-
 static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
     // skip GH copilot requests when using default port
     if (req.path == "/v1/health" || req.path == "/v1/completions") {
@@ -2558,13 +2550,17 @@ int main(int argc, char ** argv) {
     log_disable();
 #endif
     // own arguments required by this example
-    gpt_params params;
+    gpt_params    params;
     server_params sparams;
 
     // struct that contains llama context and inference
     llama_server_context llama;
 
-    server_params_parse(argc, argv, sparams, params, llama);
+    server_params_parse(argc, argv, sparams, params);
+
+    if (!sparams.system_prompt.empty()) {
+        llama.system_prompt_set(json::parse(sparams.system_prompt));
+    }
 
     if (params.model_alias == "unknown") {
         params.model_alias = params.model;
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 7e08209fc..df0a27782 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -531,3 +531,14 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     return res;
 }
 
+static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
+    return json {
+        {"tokens", tokens}
+    };
+}
+
+static json format_detokenized_response(const std::string & content) {
+    return json {
+        {"content", content}
+    };
+}
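Illustrative sketch (not part of the patch): after this change, server_params_parse() only stores the raw system prompt JSON in sparams.system_prompt, and main() applies it afterwards through llama_server_context::system_prompt_set(). The standalone program below reproduces that flow under stated assumptions: it includes nlohmann/json directly instead of the json alias used inside the server sources, and the *_sketch struct and function are hypothetical stand-ins rather than symbols from this tree. The keys "prompt", "anti_prompt" and "assistant_name" are the ones system_prompt_set() reads.

    // sketch.cpp (illustrative only): mirrors the refactored startup path
    #include <iostream>
    #include <string>

    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    // stand-in for the system prompt field of server_params
    struct server_params_sketch {
        std::string system_prompt; // raw JSON text, filled by the argument parser in the real server
    };

    // stand-in for llama_server_context::system_prompt_set(): reads the same keys
    static void system_prompt_set_sketch(const json & sys_props) {
        const std::string system_prompt  = sys_props.value("prompt", "");
        const std::string name_user      = sys_props.value("anti_prompt", "");
        const std::string name_assistant = sys_props.value("assistant_name", "");

        std::cout << "system prompt : " << system_prompt  << "\n"
                  << "user name     : " << name_user      << "\n"
                  << "assistant name: " << name_assistant << "\n";
    }

    int main() {
        server_params_sketch sparams;

        // what the system prompt file argument would load from disk
        sparams.system_prompt = R"({
            "prompt":         "You are a helpful assistant.",
            "anti_prompt":    "User:",
            "assistant_name": "Assistant:"
        })";

        // same pattern as the refactored main(): parse first, then apply only if a prompt was given
        if (!sparams.system_prompt.empty()) {
            system_prompt_set_sketch(json::parse(sparams.system_prompt));
        }

        return 0;
    }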