minor fixes

This commit is contained in:
ngxson 2024-03-24 21:28:12 +01:00
parent f73c470980
commit 950db2bd77

View file

@ -370,8 +370,10 @@ static json oaicompat_completion_params_parse(
} else {
llama_params["stop"] = json_value(body, "stop", json::array());
}
// Ensure there is ChatML-specific end sequence among stop words
llama_params["stop"].push_back("<|im_end|>");
// Some chat templates don't use EOS token to stop generation
// We must add their end sequences among stop words
llama_params["stop"].push_back("<|im_end|>"); // chatml
llama_params["stop"].push_back("<end_of_turn>"); // gemma
// Handle "response_format" field
if (body.contains("response_format")) {
@ -379,23 +381,28 @@ static json oaicompat_completion_params_parse(
std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") {
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
} else if (!response_type.empty()) {
throw std::runtime_error("response_format type not supported: " + response_type);
} else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
}
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::runtime_error("Only one completion choice is supported");
throw std::runtime_error("Only one completion choice is allowed");
}
// Handle "logprobs" field
if (body.contains("logprobs")) {
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
} else if (body.contains("top_logprobs")) {
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params{
"logprobs", "top_logprobs", "tools", "tool_choice"
};
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
for (auto & param : unsupported_params) {
if (llama_params.contains(param)) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
}
}
@ -404,7 +411,7 @@ static json oaicompat_completion_params_parse(
// This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
// See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
for (const auto & item : body.items()) {
// Exception: if "n_predict" is present, we overwrite the value specified by "max_tokens"
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
llama_params[item.key()] = item.value();
}