preliminary method for general chat format handling in server
This commit is contained in:
parent
65e5f6dadb
commit
33b8159ef0
1 changed files with 89 additions and 4 deletions
|
@ -44,6 +44,12 @@ struct server_params
|
|||
int32_t write_timeout = 600;
|
||||
};
|
||||
|
||||
struct chat_formatter
|
||||
{
|
||||
std::function<std::string(std::vector<json>)> conversion;
|
||||
std::vector<std::string> stop_keys;
|
||||
};
|
||||
|
||||
static bool server_verbose = false;
|
||||
|
||||
#if SERVER_VERBOSE != 1
|
||||
|
@ -63,7 +69,7 @@ static bool server_verbose = false;
|
|||
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
|
||||
json oaicompat_completion_params_parse(const json &body);
|
||||
json oaicompat_completion_params_parse(const json &body, const chat_formatter formatter);
|
||||
std::string format_chatml(std::vector<json> messages);
|
||||
|
||||
|
||||
|
@ -514,6 +520,7 @@ struct llama_server_context
|
|||
clip_ctx *clp_ctx = nullptr;
|
||||
|
||||
gpt_params params;
|
||||
json formatter_params = nullptr;
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
|
@ -2358,6 +2365,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
);
|
||||
llama.process_system_prompt_data(json::parse(systm_content));
|
||||
}
|
||||
else if (arg == "-cfmt" || arg == "--chat-format-file"){
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::ifstream file(argv[i]);
|
||||
if (!file) {
|
||||
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
llama.formatter_params = json::parse(file);
|
||||
fprintf(stdout, "Loaded chat format '%s'\n", llama.formatter_params.dump().c_str());
|
||||
}
|
||||
else if(arg == "--mmproj")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -2412,6 +2434,25 @@ static std::string gen_chatcmplid()
|
|||
return chatcmplid.str();
|
||||
}
|
||||
|
||||
std::string format_chat(std::vector<json> messages, json &formatter)
|
||||
{
|
||||
std::ostringstream prompt;
|
||||
|
||||
for (auto it = messages.begin(); it != messages.end(); ++it) {
|
||||
auto role = json_value(*it, "role", std::string("unknown"));
|
||||
LOG_TEE("\n Message received: %s\nWith role: %s", (*it).dump().c_str(), role.c_str());
|
||||
auto prefix = formatter.at(role).value("prefix", "");
|
||||
auto suffix = formatter.at(role).value("suffix", "");
|
||||
prompt << prefix
|
||||
<< json_value(*it, "content", std::string(""))
|
||||
<< suffix
|
||||
<< json_value(formatter, "delimiter", std::string(" "));
|
||||
}
|
||||
prompt << formatter.at("assistant").value("prefix", "");
|
||||
|
||||
return prompt.str();
|
||||
}
|
||||
|
||||
std::string format_chatml(std::vector<json> messages)
|
||||
{
|
||||
std::ostringstream chatml_msgs;
|
||||
|
@ -2430,7 +2471,9 @@ std::string format_chatml(std::vector<json> messages)
|
|||
|
||||
/* llama.cpp completion api semantics */
|
||||
json oaicompat_completion_params_parse(
|
||||
const json &body /* openai api json semantics */)
|
||||
const json &body, /* openai api json semantics */
|
||||
const chat_formatter formatter
|
||||
)
|
||||
{
|
||||
json llama_params;
|
||||
|
||||
|
@ -2445,7 +2488,8 @@ json oaicompat_completion_params_parse(
|
|||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
llama_sampling_params default_sparams;
|
||||
llama_params["model"] = json_value(body, "model", std::string("uknown"));
|
||||
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
||||
// llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
||||
llama_params["prompt"] = formatter.conversion(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
||||
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
|
||||
llama_params["temperature"] = json_value(body, "temperature", 0.0);
|
||||
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
|
||||
|
@ -2899,7 +2943,48 @@ int main(int argc, char **argv)
|
|||
if (!validate_api_key(req, res)) {
|
||||
return;
|
||||
}
|
||||
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||
|
||||
auto formatter_json = llama.formatter_params;
|
||||
chat_formatter formatter;
|
||||
if (formatter_json == nullptr){
|
||||
formatter.conversion = [](std::vector<json> messages){return format_chatml(messages);};
|
||||
formatter.stop_keys = std::vector<std::string>(1, "<|im_end|>");
|
||||
}
|
||||
else
|
||||
{
|
||||
formatter.conversion = [&formatter_json](std::vector<json> messages){return format_chat(messages, formatter_json);};
|
||||
// need json parsing the keys
|
||||
// fprintf(stdout, "\n attempt to get json keys");
|
||||
auto items = formatter_json.items();
|
||||
formatter.stop_keys = std::vector<std::string>(2);
|
||||
for(auto it = items.begin(); it != items.end(); ++it){
|
||||
if((*it).key() == "delimiter"){
|
||||
continue;
|
||||
// fprintf(stdout, "\n skipped delimiter");
|
||||
}
|
||||
// We try to get stop token and store non-duplicate ones
|
||||
auto stop_token = (*it).value().value("suffix", "");
|
||||
if(stop_token != "")
|
||||
{
|
||||
// fprintf(stdout, "\n try add stop token %s", stop_token.c_str());
|
||||
if(std::find(formatter.stop_keys.begin(), formatter.stop_keys.end(), stop_token) == formatter.stop_keys.end())
|
||||
{
|
||||
// fprintf(stdout, "\n add stop token %s", stop_token.c_str());
|
||||
formatter.stop_keys.push_back(stop_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto r = std::accumulate(
|
||||
formatter.stop_keys.begin() + 1,
|
||||
formatter.stop_keys.end(),
|
||||
formatter.stop_keys[0],
|
||||
[](const std::string a, const std::string b){ return a + ", " + b; }
|
||||
);
|
||||
fprintf(stdout, "\n Stop Keys: %s", r.c_str());
|
||||
}
|
||||
|
||||
json data = oaicompat_completion_params_parse(json::parse(req.body), formatter);
|
||||
// json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||
|
||||
const int task_id = llama.request_completion(data, false, false, -1);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue