server: add mistral chat template

commit 2ebedda3d9
parent 41f308f58e
3 changed files with 51 additions and 4 deletions
@@ -15,9 +15,14 @@
 using json = nlohmann::json;

 inline static json oaicompat_completion_params_parse(
-    const json &body /* openai api json semantics */)
+    const json &body, /* openai api json semantics */
+    const std::string &chat_template)
 {
     json llama_params;
+    bool using_chatml = chat_template == "chatml";
+    std::string formatted_prompt = using_chatml
+        ? format_chatml(body["messages"])   // OpenAI 'messages' to chatml
+        : format_mistral(body["messages"]); // OpenAI 'messages' to mistral format

     llama_params["__oaicompat"] = true;

@@ -30,7 +35,7 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"]        = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]       = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    llama_params["prompt"]       = formatted_prompt;
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"]  = json_value(body, "temperature", 0.0);
     llama_params["top_k"]        = json_value(body, "top_k", default_sparams.top_k);
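Note: the new chat_template argument only switches the prompt formatter; everything else in the parsed request is unchanged. A minimal sketch of the dispatch, using a hypothetical two-message request (format_chatml and format_mistral are the helpers added in the last two hunks below):

    // Sketch only: assumes the helpers from the hunks below are in scope.
    json body;
    body["messages"] = json::array({
        {{"role", "system"}, {"content", "You are helpful."}},
        {{"role", "user"},   {"content", "Hi!"}}
    });

    // chat_template comes from sparams.chat_template: "chatml" by default,
    // "mistral" when the server is started with --chat-template mistral.
    json chatml_params  = oaicompat_completion_params_parse(body, "chatml");
    json mistral_params = oaicompat_completion_params_parse(body, "mistral");
    // chatml_params["prompt"]  == "<|im_start|>system\nYou are helpful.<|im_end|>\n..."
    // mistral_params["prompt"] == "<s>[INST] <<SYS>>\nYou are helpful.\n<</SYS>>\n\nHi! [/INST]"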
@@ -36,6 +36,7 @@ struct server_params
     std::string hostname = "127.0.0.1";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
+    std::string chat_template = "chatml";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -1859,6 +1860,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
     printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+    printf("  --chat-template FORMAT_NAME\n");
+    printf("                            set chat template, possible values: mistral, chatml (default %s)\n", sparams.chat_template.c_str());
     printf("\n");
 }

@@ -2290,6 +2293,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             log_set_target(stdout);
             LOG_INFO("logging to file is disabled.", {});
         }
+        else if (arg == "--chat-template")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            sparams.chat_template = argv[i];
+        }
         else if (arg == "--override-kv")
         {
             if (++i >= argc) {
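With the flag parsed into sparams, the template is selected at launch time; an illustrative invocation (the model path is hypothetical):

    ./server -m models/mistral-7b-instruct.Q4_K_M.gguf --chat-template mistral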
@@ -2743,13 +2755,13 @@ int main(int argc, char **argv)


     // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 if (!validate_api_key(req, res)) {
                     return;
                 }
-                json data = oaicompat_completion_params_parse(json::parse(req.body));
+                json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);

                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
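Clients are unaffected: the OpenAI-compatible endpoint keeps the same request shape, and the template is applied server-side. An illustrative request against the default host/port from server_params:

    curl http://127.0.0.1:8080/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d '{"messages": [{"role": "user", "content": "Hi!"}]}'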
@@ -167,6 +167,34 @@ static T json_value(const json &body, const std::string &key, const T &default_value)
         : default_value;
 }

+inline std::string format_mistral(std::vector<json> messages)
+{
+    std::ostringstream output;
+    bool is_inside_turn = false;
+
+    for (auto it = messages.begin(); it != messages.end(); ++it) {
+        if (!is_inside_turn) {
+            output << "<s>[INST] ";
+        }
+        std::string role    = json_value(*it, "role", std::string("user"));
+        std::string content = json_value(*it, "content", std::string(""));
+        if (role == "system") {
+            output << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
+            is_inside_turn = true;
+        } else if (role == "user") {
+            output << content << " [/INST]";
+            is_inside_turn = true;
+        } else {
+            output << " " << content << " </s>";
+            is_inside_turn = false;
+        }
+    }
+
+    LOG_VERBOSE("format_mistral", {{"text", output.str()}});
+
+    return output.str();
+}
+
 inline std::string format_chatml(std::vector<json> messages)
 {
     std::ostringstream chatml_msgs;
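To make the turn tracking concrete, here is what format_mistral returns for a short hypothetical conversation:

    std::vector<json> messages = {
        {{"role", "user"},      {"content", "Hello"}},
        {{"role", "assistant"}, {"content", "Hi there"}},
        {{"role", "user"},      {"content", "How are you?"}}
    };
    std::string prompt = format_mistral(messages);
    // is_inside_turn starts false, so the first user message opens a turn:
    //   "<s>[INST] Hello [/INST]"
    // the assistant reply closes it:
    //   " Hi there </s>"
    // and the following user message opens a new one:
    //   "<s>[INST] How are you? [/INST]"
    // prompt == "<s>[INST] Hello [/INST] Hi there </s><s>[INST] How are you? [/INST]"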
@@ -180,6 +208,8 @@ inline std::string format_chatml(std::vector<json> messages)

     chatml_msgs << "<|im_start|>assistant" << '\n';

+    LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+
     return chatml_msgs.str();
 }