preliminary method for general chat format handling in server

author Peter Nagymathe 2023-12-29 12:13:42 +00:00
parent 65e5f6dadb
commit 33b8159ef0


@@ -44,6 +44,12 @@ struct server_params
int32_t write_timeout = 600;
};
struct chat_formatter
{
    // converts OpenAI-style 'messages' into a single prompt string
    std::function<std::string(std::vector<json>)> conversion;
    // sequences that should end generation (typically the role suffixes)
    std::vector<std::string> stop_keys;
};
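When no format file is supplied, main() below falls back to the existing ChatML behaviour; conceptually the default formatter is:

    chat_formatter fmt;
    fmt.conversion = [](std::vector<json> messages){ return format_chatml(messages); };
    fmt.stop_keys  = std::vector<std::string>(1, "<|im_end|>");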
static bool server_verbose = false;
#if SERVER_VERBOSE != 1
@@ -63,7 +69,7 @@ static bool server_verbose = false
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
json oaicompat_completion_params_parse(const json &body);
json oaicompat_completion_params_parse(const json &body, const chat_formatter & formatter);
std::string format_chatml(std::vector<json> messages);
@@ -514,6 +520,7 @@ struct llama_server_context
clip_ctx *clp_ctx = nullptr;
gpt_params params;
json formatter_params = nullptr; // parsed --chat-format-file contents; null when no format file was given
llama_batch batch;
@@ -2358,6 +2365,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
);
llama.process_system_prompt_data(json::parse(systm_content));
}
else if (arg == "-cfmt" || arg == "--chat-format-file"){
if (++i >= argc)
{
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
llama.formatter_params = json::parse(file);
fprintf(stdout, "Loaded chat format '%s'\n", llama.formatter_params.dump().c_str());
}
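As an illustration, a file passed via --chat-format-file could look like the following. The keys mirror what format_chat below reads: one prefix/suffix object per role name, plus an optional top-level delimiter. The concrete role names and template strings here are only an example, not a shipped format:

    {
        "system":    { "prefix": "### System:\n",    "suffix": "\n" },
        "user":      { "prefix": "### User:\n",      "suffix": "\n" },
        "assistant": { "prefix": "### Assistant:\n", "suffix": "</s>\n" },
        "delimiter": "\n"
    }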
else if(arg == "--mmproj")
{
if (++i >= argc)
@@ -2412,6 +2434,25 @@ static std::string gen_chatcmplid()
return chatcmplid.str();
}
std::string format_chat(std::vector<json> messages, json &formatter)
{
    std::ostringstream prompt;
    for (auto it = messages.begin(); it != messages.end(); ++it) {
        auto role = json_value(*it, "role", std::string("unknown"));
        LOG_TEE("\nMessage received: %s\nWith role: %s", (*it).dump().c_str(), role.c_str());
        // roles missing from the format file fall back to empty prefix/suffix
        // instead of throwing out of formatter.at()
        const json role_fmt = formatter.value(role, json::object());
        prompt << role_fmt.value("prefix", "")
               << json_value(*it, "content", std::string(""))
               << role_fmt.value("suffix", "")
               << json_value(formatter, "delimiter", std::string(" "));
    }
    // close with the assistant prefix so the model continues as the assistant
    prompt << formatter.value("assistant", json::object()).value("prefix", "");
    return prompt.str();
}
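A minimal usage sketch, assuming the example format file from above (the expected output can be traced through format_chat line by line):

    json fmt = json::parse(R"({
        "user":      { "prefix": "### User:\n",      "suffix": "\n" },
        "assistant": { "prefix": "### Assistant:\n", "suffix": "</s>\n" },
        "delimiter": "\n"
    })");
    std::vector<json> msgs = { json{{"role", "user"}, {"content", "Hi"}} };
    // format_chat(msgs, fmt) returns "### User:\nHi\n\n### Assistant:\n":
    // per message prefix + content + suffix + delimiter, then the assistant
    // prefix so the model continues in the assistant's voice.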
std::string format_chatml(std::vector<json> messages)
{
std::ostringstream chatml_msgs;
@@ -2430,7 +2471,9 @@ std::string format_chatml(std::vector<json> messages)
/* llama.cpp completion api semantics */
json oaicompat_completion_params_parse(
const json &body /* openai api json semantics */)
const json &body, /* openai api json semantics */
const chat_formatter & formatter
)
{
json llama_params;
@@ -2445,7 +2488,8 @@ json oaicompat_completion_params_parse(
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("uknown"));
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
// llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
llama_params["prompt"] = formatter.conversion(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
@@ -2899,7 +2943,48 @@ int main(int argc, char **argv)
if (!validate_api_key(req, res)) {
return;
}
json data = oaicompat_completion_params_parse(json::parse(req.body));
auto formatter_json = llama.formatter_params;
chat_formatter formatter;
if (formatter_json == nullptr){
formatter.conversion = [](std::vector<json> messages){return format_chatml(messages);};
formatter.stop_keys = std::vector<std::string>(1, "<|im_end|>");
}
else
{
formatter.conversion = [&formatter_json](std::vector<json> messages){return format_chat(messages, formatter_json);};
                // Derive stop sequences from the role suffixes in the format file
                formatter.stop_keys = std::vector<std::string>();
                for (auto & el : formatter_json.items()) {
                    // skip the top-level delimiter and any non-object entry
                    if (el.key() == "delimiter" || !el.value().is_object()) {
                        continue;
                    }
                    // collect each role's non-empty suffix, skipping duplicates
                    auto stop_token = el.value().value("suffix", "");
                    if (stop_token != "" &&
                        std::find(formatter.stop_keys.begin(), formatter.stop_keys.end(), stop_token) == formatter.stop_keys.end())
                    {
                        formatter.stop_keys.push_back(stop_token);
                    }
                }
                if (!formatter.stop_keys.empty()) {
                    auto r = std::accumulate(
                        formatter.stop_keys.begin() + 1,
                        formatter.stop_keys.end(),
                        formatter.stop_keys[0],
                        [](const std::string & a, const std::string & b){ return a + ", " + b; }
                    );
                    fprintf(stdout, "\nStop keys: %s", r.c_str());
                }
}
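Note that with the example format file shown earlier, stop_keys would come out as { "\n", "</s>\n" }: every distinct non-empty suffix is collected as a stop key, so suffixes in a format file should be chosen with stopping behaviour in mind.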
json data = oaicompat_completion_params_parse(json::parse(req.body), formatter);
// json data = oaicompat_completion_params_parse(json::parse(req.body));
const int task_id = llama.request_completion(data, false, false, -1);