server : refactor system prompt update at start

Georgi Gerganov 2024-03-05 19:55:19 +02:00
parent 4a2d5f63f2
commit 61b63705dc
GPG key ID: 449E073F9DC10735
2 changed files with 28 additions and 21 deletions
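
In short: the argument parser no longer touches the llama_server_context. The system prompt read from a file is stored in server_params and applied once in main() before serving, and system_prompt_process() is renamed to system_prompt_set(). A condensed before/after sketch of the changed flow (all identifiers taken from the diff below):

    // before: the parser mutated the context directly while parsing arguments
    server_params_parse(argc, argv, sparams, params, llama);
    //   ... and, inside the parser:
    //   llama.system_prompt_process(json::parse(systm_content));

    // after: the parser only fills sparams; main() applies it once at startup
    server_params_parse(argc, argv, sparams, params);
    if (!sparams.system_prompt.empty()) {
        llama.system_prompt_set(json::parse(sparams.system_prompt));
    }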


@@ -115,6 +115,7 @@ struct server_params {
     std::string hostname = "127.0.0.1";
     std::string public_path = "examples/server/public";
     std::string chat_template = "";
+    std::string system_prompt = "";
     std::vector<std::string> api_keys;
@@ -1024,7 +1025,7 @@ struct llama_server_context {
         system_need_update = false;
     }
-    void system_prompt_process(const json & sys_props) {
+    void system_prompt_set(const json & sys_props) {
         system_prompt = sys_props.value("prompt", "");
         name_user = sys_props.value("anti_prompt", "");
         name_assistant = sys_props.value("assistant_name", "");
@@ -1418,7 +1419,7 @@ struct llama_server_context {
             }
             if (task.data.contains("system_prompt")) {
-                system_prompt_process(task.data["system_prompt"]);
+                system_prompt_set(task.data["system_prompt"]);
                 // reset cache_tokens for all slots
                 for (server_slot & slot : slots) {
@@ -1974,7 +1975,10 @@ struct llama_server_context {
             }
             for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
-                result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
+                result.probs.push_back({
+                    cur_p.data[i].id,
+                    cur_p.data[i].p
+                });
             }
             if (!process_token(result, slot)) {
@@ -2088,7 +2092,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("\n");
 }
-static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params, llama_server_context & llama) {
+static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params) {
     gpt_params default_params;
     server_params default_sparams;
@@ -2402,13 +2406,13 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 invalid_param = true;
                 break;
             }
-            std::string systm_content;
+            std::string system_prompt;
             std::copy(
                 std::istreambuf_iterator<char>(file),
                 std::istreambuf_iterator<char>(),
-                std::back_inserter(systm_content)
+                std::back_inserter(system_prompt)
             );
-            llama.system_prompt_process(json::parse(systm_content));
+            sparams.system_prompt = system_prompt;
         } else if (arg == "-ctk" || arg == "--cache-type-k") {
             params.cache_type_k = argv[++i];
         } else if (arg == "-ctv" || arg == "--cache-type-v") {
@@ -2506,18 +2510,6 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
     }
 }
-static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
-    return json {
-        {"tokens", tokens}
-    };
-}
-static json format_detokenized_response(const std::string & content) {
-    return json {
-        {"content", content}
-    };
-}
 static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
     // skip GH copilot requests when using default port
     if (req.path == "/v1/health" || req.path == "/v1/completions") {
@@ -2558,13 +2550,17 @@ int main(int argc, char ** argv) {
     log_disable();
 #endif
     // own arguments required by this example
-    gpt_params params;
+    gpt_params    params;
+    server_params sparams;
     // struct that contains llama context and inference
     llama_server_context llama;
-    server_params_parse(argc, argv, sparams, params, llama);
+    server_params_parse(argc, argv, sparams, params);
+    if (!sparams.system_prompt.empty()) {
+        llama.system_prompt_set(json::parse(sparams.system_prompt));
+    }
     if (params.model_alias == "unknown") {
         params.model_alias = params.model;
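
system_prompt_set() above reads three fields from the parsed JSON. A minimal sketch of building such an object with nlohmann::json (field values are placeholders, not part of this commit; the include path assumes the standard single-header distribution rather than the copy vendored with the server example):

    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        json sys_props = {
            {"prompt",         "You are a helpful assistant."}, // stored as system_prompt
            {"anti_prompt",    "User:"},                        // stored as name_user
            {"assistant_name", "Assistant:"}                    // stored as name_assistant
        };

        // sparams.system_prompt holds this text after the system-prompt file argument
        // is read, and main() passes it through json::parse() to system_prompt_set().
        std::string serialized = sys_props.dump(2);
        json parsed = json::parse(serialized);
        return parsed.value("prompt", "").empty() ? 1 : 0;
    }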


@@ -531,3 +531,14 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     return res;
 }
+static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
+    return json {
+        {"tokens", tokens}
+    };
+}
+static json format_detokenized_response(const std::string & content) {
+    return json {
+        {"content", content}
+    };
+}
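
And a tiny usage sketch for the two helpers moved into the second changed file; it assumes it is compiled inside the server example, where llama_token, json and both helpers are already declared, and the token ids are made up:

    static void tokenizer_response_example() {
        const std::vector<llama_token> tokens = { 1, 15043 };
        const json tok = format_tokenizer_response(tokens);      // {"tokens":[1,15043]}
        const json det = format_detokenized_response("Hello");   // {"content":"Hello"}
        (void) tok;
        (void) det;
    }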