server: clean up oai parsing function
commit faaec65fdb
parent ea279d5609

2 changed files with 82 additions and 36 deletions
@@ -847,9 +847,16 @@ struct server_context {
         slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot.params.n_keep       = json_value(data, "n_keep",      slot.params.n_keep);
         slot.params.seed         = json_value(data, "seed",        default_params.seed);
-        if (data.contains("json_schema") && !data.contains("grammar")) {
+        slot.sparams.n_probs     = json_value(data, "n_probs",     default_sparams.n_probs);
+        slot.sparams.min_keep    = json_value(data, "min_keep",    default_sparams.min_keep);
+
+        // process "json_schema" and "grammar"
+        if (data.contains("json_schema") && data.contains("grammar")) {
+            send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
+            return false;
+        } else if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema          = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
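Note: the hunk above makes "json_schema" and "grammar" mutually exclusive in launch_slot_with_task. Previously a request carrying both fields fell through to the grammar branch and the schema was silently ignored; now it is rejected up front. A minimal standalone sketch of just that check, assuming nlohmann::json (which the server uses for request bodies); accepts() is a hypothetical helper, not part of the commit:

#include <nlohmann/json.hpp>
#include <iostream>
using json = nlohmann::json;

// Reduced version of the new check: a request may constrain output with a
// JSON schema or with a raw GBNF grammar, but not both at once.
static bool accepts(const json & data) {
    return !(data.contains("json_schema") && data.contains("grammar"));
}

int main() {
    json schema_only = { {"json_schema", { {"type", "object"} }} };
    json both        = { {"json_schema", json::object()}, {"grammar", "root ::= [0-9]+"} };
    std::cout << accepts(schema_only) << "\n"; // 1 -- accepted; schema is compiled to a grammar
    std::cout << accepts(both)        << "\n"; // 0 -- server replies with ERROR_TYPE_INVALID_REQUEST
}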
@@ -858,8 +865,6 @@ struct server_context {
         } else {
             slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
         }
-        slot.sparams.n_probs  = json_value(data, "n_probs",  default_sparams.n_probs);
-        slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
         if (slot.params.cache_prompt && slot.ga_n != 1) {
             LOG_WARNING("cache_prompt is not supported with group-attention", {});
@ -49,6 +49,34 @@ extern bool server_log_json;
|
||||||
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
|
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||||
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||||
|
|
||||||
|
// GRAMMAR_JSON is used by OAI "response_format" field
|
||||||
|
static const std::string GRAMMAR_JSON = R"(root ::= object
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\\x7F\x00-\x1F] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
|
)* "\"" ws
|
||||||
|
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
|
ws ::= ([ \t\n] ws)?
|
||||||
|
)";
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T json_value(const json &body, const std::string &key, const T &default_value) {
|
static T json_value(const json &body, const std::string &key, const T &default_value) {
|
||||||
// Fallback null to default value
|
// Fallback null to default value
|
||||||
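Note: GRAMMAR_JSON is a GBNF grammar whose root rule is object, so a model constrained by it must emit exactly one JSON object; whitespace is limited to space, tab, and newline via the ws rule. Baking it in as a constant gives the "json_object" response format a fixed grammar instead of the old path, which forwarded an optional "schema" field from the request. An illustration of what the grammar admits and rejects, hand-checked against the rules above (not part of the commit):

// Strings as judged by GRAMMAR_JSON (illustration only):
static const char * accepted_by_grammar[] = {
    "{}",                        // empty object
    "{\"a\": [1, true, null]}",  // arrays and literals nest via the value rule
    "{\"s\": \"\\u00e9\"}",      // \u escapes take exactly four hex digits
};
static const char * rejected_by_grammar[] = {
    "[1, 2, 3]",                 // root must be an object, not a bare array
    "{'a': 1}",                  // strings require double quotes
    "{\"n\": 01}",               // leading zero: number is [0-9] | [1-9] [0-9]*
};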
@@ -352,52 +380,65 @@ static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
-    llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
-    llama_params["temperature"] = json_value(body, "temperature", 0.0);
-    llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
-    llama_params["top_p"] = json_value(body, "top_p", 1.0);
-    llama_params["n_predict"] = json_value(body, "max_tokens", -1);
-    llama_params["logit_bias"] = json_value(body, "logit_bias", json::object());
     llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
+    llama_params["logit_bias"] = json_value(body, "logit_bias", json::object());
+    llama_params["n_predict"] = json_value(body, "max_tokens", -1);
     llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
     llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED);
     llama_params["stream"] = json_value(body, "stream", false);
-    llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat);
-    llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
-    llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
-    llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p);
-    llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
-    llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
-    llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
-    llama_params["n_keep"] = json_value(body, "n_keep", 0);
+    llama_params["temperature"] = json_value(body, "temperature", 0.0);
+    llama_params["top_p"] = json_value(body, "top_p", 1.0);
 
-    if (body.contains("grammar")) {
-        llama_params["grammar"] = json_value(body, "grammar", json::object());
-    }
+    // Apply chat template to the list of messages
+    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
 
-    if (body.contains("response_format")) {
-        auto response_format = json_value(body, "response_format", json::object());
-        if (response_format.contains("type")) {
-            if (response_format["type"] == "json_object") {
-                llama_params["json_schema"] = json_value(response_format, "schema", json::object());
-            } else {
-                throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
-            }
-        }
-    }
-
-    // Handle 'stop' field
+    // Handle "stop" field
     if (body.contains("stop") && body["stop"].is_string()) {
         llama_params["stop"] = json::array({body["stop"].get<std::string>()});
     } else {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
     // Ensure there is ChatML-specific end sequence among stop words
     llama_params["stop"].push_back("<|im_end|>");
 
+    // Handle "response_format" field
+    if (body.contains("response_format")) {
+        json response_format      = json_value(body, "response_format", json::object());
+        std::string response_type = json_value(response_format, "type", std::string());
+        if (response_type == "json_object") {
+            // "json_object" guarantees the message the model generates is valid JSON.
+            llama_params["grammar"] = GRAMMAR_JSON;
+        } else {
+            throw std::runtime_error("response_format type not supported: " + response_type);
+        }
+    }
+
+    // Handle "n" field
+    int n_choices = json_value(body, "n", 1);
+    if (n_choices != 1) {
+        throw std::runtime_error("Only one completion choice is supported");
+    }
+
+    // Params supported by OAI but unsupported by llama.cpp
+    static const std::vector<std::string> unsupported_params{
+        "logprobs", "top_logprobs", "tools", "tool_choice"
+    };
+    for (auto & param : unsupported_params) {
+        if (llama_params.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
+    }
+
+    // Copy remaining properties to llama_params
+    // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
+    // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
+    for (const auto & item : body.items()) {
+        // Exception: if "n_predict" is present, we overwrite the value specified by "max_tokens"
+        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+            llama_params[item.key()] = item.value();
+        }
+    }
+
     return llama_params;
 }
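Note: two of the steps above are easy to get wrong when re-implementing a client, so here is a reduced, runnable sketch of just the "max_tokens" rename and the "stop" normalization. It assumes nlohmann::json; parse_stop_and_max_tokens() is a hypothetical stand-in for that slice of oaicompat_completion_params_parse, not the real function:

#include <nlohmann/json.hpp>
#include <iostream>
using json = nlohmann::json;

static json parse_stop_and_max_tokens(const json & body) {
    json llama_params;
    // OAI's "max_tokens" becomes llama.cpp's "n_predict"; -1 means no limit
    llama_params["n_predict"] = body.value("max_tokens", -1);
    // "stop" may arrive as a single string or an array; normalize to an array
    if (body.contains("stop") && body["stop"].is_string()) {
        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
    } else {
        llama_params["stop"] = body.value("stop", json::array());
    }
    // the ChatML end tag is always appended as a stop word
    llama_params["stop"].push_back("<|im_end|>");
    return llama_params;
}

int main() {
    json body = { {"max_tokens", 64}, {"stop", "\n"} };
    std::cout << parse_stop_and_max_tokens(body).dump() << "\n";
    // prints: {"n_predict":64,"stop":["\n","<|im_end|>"]}
}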
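One subtlety in the final copy loop deserves a worked example: an explicit "n_predict" in the request body deliberately overwrites the value derived from "max_tokens", while every other key the parser already set is protected from being clobbered. A self-contained sketch mirroring only that loop (hypothetical, not part of the commit):

#include <nlohmann/json.hpp>
#include <cassert>
using json = nlohmann::json;

int main() {
    json body = { {"max_tokens", 64}, {"n_predict", 128}, {"mirostat", 2} };
    json llama_params;
    llama_params["n_predict"] = body.value("max_tokens", -1); // mapped first, as above
    for (const auto & item : body.items()) {
        // unknown keys (e.g. "mirostat") pass through verbatim; "n_predict" is
        // the one key allowed to overwrite an already-set value
        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
            llama_params[item.key()] = item.value();
        }
    }
    assert(llama_params["n_predict"] == 128); // explicit n_predict beats max_tokens
    assert(llama_params["mirostat"] == 2);    // llama.cpp-specific param survives
}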