server : refactor system prompt update at start

Georgi Gerganov 2024-03-05 19:55:19 +02:00
parent 4a2d5f63f2
commit 61b63705dc
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 28 additions and 21 deletions
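In short: server_params_parse() no longer takes the llama_server_context, the system prompt read during argument parsing is stored in the new server_params::system_prompt field, and system_prompt_process() is renamed to system_prompt_set(), which main() now calls once after parsing. A condensed sketch of the resulting startup flow, pieced together from the hunks below (not a complete program; model loading and the HTTP endpoints are omitted):

// Condensed from this commit: the system prompt is captured while parsing
// arguments and applied to the context exactly once, before the server starts.
int main(int argc, char ** argv) {
    gpt_params           params;
    server_params        sparams;
    llama_server_context llama;

    // argument parsing no longer needs the llama context
    server_params_parse(argc, argv, sparams, params);

    // apply the system prompt at start, if one was provided
    if (!sparams.system_prompt.empty()) {
        llama.system_prompt_set(json::parse(sparams.system_prompt));
    }

    // ... model loading and HTTP server setup continue as before ...
}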

examples/server/server.cpp

@@ -115,6 +115,7 @@ struct server_params {
     std::string hostname = "127.0.0.1";
     std::string public_path = "examples/server/public";
     std::string chat_template = "";
+    std::string system_prompt = "";
     std::vector<std::string> api_keys;
@@ -1024,7 +1025,7 @@ struct llama_server_context {
         system_need_update = false;
     }

-    void system_prompt_process(const json & sys_props) {
+    void system_prompt_set(const json & sys_props) {
         system_prompt = sys_props.value("prompt", "");
         name_user = sys_props.value("anti_prompt", "");
         name_assistant = sys_props.value("assistant_name", "");
@@ -1418,7 +1419,7 @@ struct llama_server_context {
            }

            if (task.data.contains("system_prompt")) {
-               system_prompt_process(task.data["system_prompt"]);
+               system_prompt_set(task.data["system_prompt"]);

                // reset cache_tokens for all slots
                for (server_slot & slot : slots) {
@@ -1974,7 +1975,10 @@ struct llama_server_context {
            }

            for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
-               result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
+               result.probs.push_back({
+                   cur_p.data[i].id,
+                   cur_p.data[i].p
+               });
            }

            if (!process_token(result, slot)) {
@@ -2088,7 +2092,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("\n");
 }

-static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params, llama_server_context & llama) {
+static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params) {
     gpt_params default_params;
     server_params default_sparams;
@@ -2402,13 +2406,13 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                invalid_param = true;
                break;
            }
-           std::string systm_content;
+           std::string system_prompt;
            std::copy(
                std::istreambuf_iterator<char>(file),
                std::istreambuf_iterator<char>(),
-               std::back_inserter(systm_content)
+               std::back_inserter(system_prompt)
            );
-           llama.system_prompt_process(json::parse(systm_content));
+           sparams.system_prompt = system_prompt;
        } else if (arg == "-ctk" || arg == "--cache-type-k") {
            params.cache_type_k = argv[++i];
        } else if (arg == "-ctv" || arg == "--cache-type-v") {
@@ -2506,18 +2510,6 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
     }
 }

-static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
-    return json {
-        {"tokens", tokens}
-    };
-}
-
-static json format_detokenized_response(const std::string & content) {
-    return json {
-        {"content", content}
-    };
-}
-
 static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
     // skip GH copilot requests when using default port
     if (req.path == "/v1/health" || req.path == "/v1/completions") {
@@ -2558,13 +2550,17 @@ int main(int argc, char ** argv) {
     log_disable();
 #endif
     // own arguments required by this example
     gpt_params params;
     server_params sparams;

     // struct that contains llama context and inference
     llama_server_context llama;

-    server_params_parse(argc, argv, sparams, params, llama);
+    server_params_parse(argc, argv, sparams, params);
+
+    if (!sparams.system_prompt.empty()) {
+        llama.system_prompt_set(json::parse(sparams.system_prompt));
+    }

     if (params.model_alias == "unknown") {
         params.model_alias = params.model;
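For reference, a self-contained illustration of the JSON shape that system_prompt_set() accepts; the keys come from the hunk at line 1025 above, while the concrete values and the standalone program around them are only assumptions for the example:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp> // the JSON library used by the server

using json = nlohmann::json;

int main() {
    // example contents of a system prompt file (keys match system_prompt_set)
    const std::string file_contents = R"({
        "prompt":         "You are a helpful assistant.",
        "anti_prompt":    "User:",
        "assistant_name": "Assistant:"
    })";

    const json sys_props = json::parse(file_contents);

    // the same .value() lookups performed by system_prompt_set()
    std::cout << sys_props.value("prompt", "")         << "\n";
    std::cout << sys_props.value("anti_prompt", "")    << "\n";
    std::cout << sys_props.value("assistant_name", "") << "\n";
    return 0;
}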

examples/server/utils.hpp

@@ -531,3 +531,14 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     return res;
 }

+static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
+    return json {
+        {"tokens", tokens}
+    };
+}
+
+static json format_detokenized_response(const std::string & content) {
+    return json {
+        {"content", content}
+    };
+}
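The two helpers added here were moved from server.cpp and build the JSON bodies for the tokenize/detokenize responses. A minimal usage sketch; the surrounding endpoint handlers are not part of this diff, so the call sites below are only illustrative:

// Illustrative call sites; only the two helpers above come from the diff.
std::vector<llama_token> tokens = { 1, 15043, 3186 }; // example token ids
const json tok_res = format_tokenizer_response(tokens);
// -> {"tokens":[1,15043,3186]}

const json detok_res = format_detokenized_response("Hello world");
// -> {"content":"Hello world"}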