minor fixes

ngxson 2024-03-24 21:28:12 +01:00
parent f73c470980
commit 950db2bd77


@@ -370,8 +370,10 @@ static json oaicompat_completion_params_parse(
     } else {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    // Some chat templates don't use EOS token to stop generation
+    // We must add their end sequences among stop words
+    llama_params["stop"].push_back("<|im_end|>");    // chatml
+    llama_params["stop"].push_back("<end_of_turn>"); // gemma

     // Handle "response_format" field
     if (body.contains("response_format")) {
@@ -379,23 +381,28 @@ static json oaicompat_completion_params_parse(
         std::string response_type = json_value(response_format, "type", std::string());
         if (response_type == "json_object") {
             llama_params["json_schema"] = json_value(response_format, "schema", json::object());
-        } else if (!response_type.empty()) {
-            throw std::runtime_error("response_format type not supported: " + response_type);
+        } else if (!response_type.empty() && response_type != "text") {
+            throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
         }
     }

     // Handle "n" field
     int n_choices = json_value(body, "n", 1);
     if (n_choices != 1) {
-        throw std::runtime_error("Only one completion choice is supported");
+        throw std::runtime_error("Only one completion choice is allowed");
+    }
+
+    // Handle "logprobs" field
+    if (body.contains("logprobs")) {
+        llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
+    } else if (body.contains("top_logprobs")) {
+        throw std::runtime_error("top_logprobs requires logprobs to be set to true");
     }

     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params{
-        "logprobs", "top_logprobs", "tools", "tool_choice"
-    };
+    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
     for (auto & param : unsupported_params) {
-        if (llama_params.contains(param)) {
+        if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);
         }
     }
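
A small illustrative sketch of how the new "logprobs"/"top_logprobs" handling maps onto llama.cpp's "n_probs". The helper name map_logprobs and the sample bodies are hypothetical, not part of the server code; nlohmann::json's value() again stands in for json_value():

#include <nlohmann/json.hpp>
#include <stdexcept>
#include <iostream>

using json = nlohmann::json;

// hypothetical helper mirroring the logic added in the hunk above
static json map_logprobs(const json & body) {
    json llama_params = json::object();
    if (body.contains("logprobs")) {
        // default of 20 mirrors the diff
        llama_params["n_probs"] = body.value("top_logprobs", 20);
    } else if (body.contains("top_logprobs")) {
        throw std::runtime_error("top_logprobs requires logprobs to be set to true");
    }
    return llama_params;
}

int main() {
    std::cout << map_logprobs({{"logprobs", true}, {"top_logprobs", 5}}).dump() << "\n"; // {"n_probs":5}
    std::cout << map_logprobs(json::object()).dump() << "\n";                            // {}
    return 0;
}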
@@ -404,7 +411,7 @@ static json oaicompat_completion_params_parse(
     // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
     // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
     for (const auto & item : body.items()) {
-        // Exception: if "n_predict" is present, we overwrite the value specified by "max_tokens"
+        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
         if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
             llama_params[item.key()] = item.value();
         }
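
And a rough sketch of the pass-through loop in this last hunk, showing how an explicit "n_predict" overrides the value derived earlier from "max_tokens" while llama.cpp-specific keys such as "mirostat" are copied through. The sample body is invented, and the max_tokens-to-n_predict mapping imitates code outside the hunk shown here:

#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    // hypothetical request mixing OAI and llama.cpp-specific params
    json body = { {"max_tokens", 128}, {"n_predict", 64}, {"mirostat", 2} };

    json llama_params;
    // earlier in the real function, "max_tokens" is mapped to "n_predict"
    llama_params["n_predict"] = body.value("max_tokens", -1);

    for (const auto & item : body.items()) {
        // copy keys not already set; "n_predict" is the one exception allowed to overwrite
        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
            llama_params[item.key()] = item.value();
        }
    }

    std::cout << llama_params.dump() << "\n";
    // prints {"max_tokens":128,"mirostat":2,"n_predict":64}
    return 0;
}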