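Fix response_format: pass the request's "schema" through as llama_params["json_schema"] instead of forcing the built-in GRAMMAR_JSON grammar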

This commit is contained in:
ngxson 2024-03-24 19:44:52 +01:00
parent faaec65fdb
commit 4aeef9bc31

View file

@ -49,34 +49,6 @@ extern bool server_log_json;
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
// GRAMMAR_JSON is used by OAI "response_format" field
//
// Grammar text describing JSON documents (presumably in llama.cpp's GBNF
// grammar notation -- TODO confirm against the grammar parser).
//
// Notes on the rules below:
//   - `root` is `object`, so the top-level value must be a JSON object;
//     a bare array, string, number or literal is not accepted at the root.
//   - `value` covers object, array, string, number and the literals
//     "true" / "false" / "null".
//   - `string` rejects raw control characters (0x00-0x1F, 0x7F), unescaped
//     quotes and unescaped backslashes; it allows the escapes listed in the
//     rule, including "\u" followed by four hex digits.
//   - `ws` is optional whitespace (space, tab, newline) and is appended after
//     tokens rather than placed before them.
//
// The whole literal is runtime data: the `#` lines inside it are part of the
// grammar text itself, not C++ comments, so edit them with care.
static const std::string GRAMMAR_JSON = R"(root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
)";
template <typename T> template <typename T>
static T json_value(const json &body, const std::string &key, const T &default_value) { static T json_value(const json &body, const std::string &key, const T &default_value) {
// Fallback null to default value // Fallback null to default value
@ -404,10 +376,9 @@ static json oaicompat_completion_params_parse(
// Handle "response_format" field // Handle "response_format" field
if (body.contains("response_format")) { if (body.contains("response_format")) {
json response_format = json_value(body, "response_format", json::object()); json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string()); std::string response_type = json_value(response_format, "type", std::string("unknown"));
if (response_type == "json_object") { if (response_type == "json_object") {
// "json_object" guarantees the message the model generates is valid JSON. llama_params["json_schema"] = json_value(response_format, "schema", json::object());
llama_params["grammar"] = GRAMMAR_JSON;
} else { } else {
throw std::runtime_error("response_format type not supported: " + response_type); throw std::runtime_error("response_format type not supported: " + response_type);
} }