Merge 9c3e7e0c77
into 4f56458d34
This commit is contained in:
commit
9b59fb66e6
1 changed file with 86 additions and 5 deletions
|
@ -45,6 +45,12 @@ struct server_params
|
|||
int32_t write_timeout = 600;
|
||||
};
|
||||
|
||||
struct chat_formatter
|
||||
{
|
||||
std::function<std::string(std::vector<json>)> conversion;
|
||||
std::vector<std::string> stop_keys;
|
||||
};
|
||||
|
||||
static bool server_verbose = false;
|
||||
|
||||
#if SERVER_VERBOSE != 1
|
||||
|
@ -64,7 +70,7 @@ static bool server_verbose = false;
|
|||
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
|
||||
json oaicompat_completion_params_parse(const json &body);
|
||||
json oaicompat_completion_params_parse(const json &body, const chat_formatter formatter);
|
||||
std::string format_chatml(std::vector<json> messages);
|
||||
|
||||
|
||||
|
@ -524,6 +530,7 @@ struct llama_server_context
|
|||
clip_ctx *clp_ctx = nullptr;
|
||||
|
||||
gpt_params params;
|
||||
json formatter_params = nullptr;
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
|
@ -2376,6 +2383,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
);
|
||||
llama.process_system_prompt_data(json::parse(systm_content));
|
||||
}
|
||||
else if (arg == "-cfmt" || arg == "--chat-format-file"){
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::ifstream file(argv[i]);
|
||||
if (!file) {
|
||||
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
llama.formatter_params = json::parse(file);
|
||||
fprintf(stdout, "Loaded chat template '%s'\n", llama.formatter_params.dump().c_str());
|
||||
}
|
||||
else if(arg == "--mmproj")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -2477,6 +2499,25 @@ static std::string gen_chatcmplid()
|
|||
return chatcmplid.str();
|
||||
}
|
||||
|
||||
std::string format_chat(std::vector<json> messages, json &formatter)
|
||||
{
|
||||
std::ostringstream prompt;
|
||||
|
||||
for (auto it = messages.begin(); it != messages.end(); ++it) {
|
||||
auto role = json_value(*it, "role", std::string("unknown"));
|
||||
LOG_TEE("\n Message received: %s\nWith role: %s", (*it).dump().c_str(), role.c_str());
|
||||
auto prefix = formatter.at(role).value("prefix", "");
|
||||
auto suffix = formatter.at(role).value("suffix", "");
|
||||
prompt << prefix
|
||||
<< json_value(*it, "content", std::string(""))
|
||||
<< suffix
|
||||
<< json_value(formatter, "delimiter", std::string(" "));
|
||||
}
|
||||
prompt << formatter.at("assistant").value("prefix", "");
|
||||
|
||||
return prompt.str();
|
||||
}
|
||||
|
||||
std::string format_chatml(std::vector<json> messages)
|
||||
{
|
||||
std::ostringstream chatml_msgs;
|
||||
|
@ -2495,7 +2536,9 @@ std::string format_chatml(std::vector<json> messages)
|
|||
|
||||
/* llama.cpp completion api semantics */
|
||||
json oaicompat_completion_params_parse(
|
||||
const json &body /* openai api json semantics */)
|
||||
const json &body, /* openai api json semantics */
|
||||
const chat_formatter formatter
|
||||
)
|
||||
{
|
||||
json llama_params;
|
||||
|
||||
|
@ -2510,7 +2553,8 @@ json oaicompat_completion_params_parse(
|
|||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
llama_sampling_params default_sparams;
|
||||
llama_params["model"] = json_value(body, "model", std::string("uknown"));
|
||||
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
||||
// llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
||||
llama_params["prompt"] = formatter.conversion(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
||||
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
|
||||
llama_params["temperature"] = json_value(body, "temperature", 0.0);
|
||||
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
|
||||
|
@ -2542,7 +2586,9 @@ json oaicompat_completion_params_parse(
|
|||
}
|
||||
|
||||
// Ensure there is ChatML-specific end sequence among stop words
|
||||
llama_params["stop"].push_back("<|im_end|>");
|
||||
for(auto& stop : formatter.stop_keys){
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
|
||||
return llama_params;
|
||||
}
|
||||
|
@ -2964,7 +3010,42 @@ int main(int argc, char **argv)
|
|||
if (!validate_api_key(req, res)) {
|
||||
return;
|
||||
}
|
||||
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||
|
||||
auto formatter_json = llama.formatter_params;
|
||||
chat_formatter formatter;
|
||||
if (formatter_json == nullptr){
|
||||
formatter.conversion = [](std::vector<json> messages){return format_chatml(messages);};
|
||||
formatter.stop_keys = std::vector<std::string>(1, "<|im_end|>");
|
||||
}
|
||||
else
|
||||
{
|
||||
formatter.conversion = [&formatter_json](std::vector<json> messages){return format_chat(messages, formatter_json);};
|
||||
// need json parsing the keys
|
||||
auto items = formatter_json.items();
|
||||
formatter.stop_keys = std::vector<std::string>(2);
|
||||
for(auto it = items.begin(); it != items.end(); ++it){
|
||||
if((*it).key() == "delimiter"){
|
||||
continue;
|
||||
}
|
||||
// We try to get stop token and store non-duplicate ones
|
||||
auto stop_token = (*it).value().value("suffix", "");
|
||||
if(stop_token != "")
|
||||
{
|
||||
if(std::find(formatter.stop_keys.begin(), formatter.stop_keys.end(), stop_token) == formatter.stop_keys.end())
|
||||
{
|
||||
formatter.stop_keys.push_back(stop_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto r = std::accumulate(
|
||||
formatter.stop_keys.begin() + 1,
|
||||
formatter.stop_keys.end(),
|
||||
formatter.stop_keys[0],
|
||||
[](const std::string a, const std::string b){ return a + ", " + b; }
|
||||
);
|
||||
}
|
||||
|
||||
json data = oaicompat_completion_params_parse(json::parse(req.body), formatter);
|
||||
|
||||
const int task_id = llama.request_completion(data, false, false, -1);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue