server: format error to json

This commit is contained in:
ngxson 2024-03-09 15:01:19 +01:00
parent 9674aaf35c
commit ba9c3e3192
2 changed files with 113 additions and 56 deletions

View file

@ -1210,15 +1210,23 @@ struct server_context {
}; };
} }
void send_error(const server_task & task, const std::string & error) { void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_SERVER) {
LOG_TEE("task %i - error: %s\n", task.id, error.c_str()); send_error(task.id, task.id_multi, error, type);
}
// Convenience overload: report an error for the task identified by `slot.id_task`.
// Forwards slot.id_task and slot.id_multi, together with the message and error type,
// to the (id_task, id_multi, error, type) overload. `type` defaults to ERROR_SERVER.
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_SERVER) {
send_error(slot.id_task, slot.id_multi, error, type);
}
void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_SERVER) {
LOG_TEE("task %i - error: %s\n", id_task, error.c_str());
server_task_result res; server_task_result res;
res.id = task.id; res.id = id_task;
res.id_multi = task.id_multi; res.id_multi = id_multi;
res.stop = false; res.stop = false;
res.error = true; res.error = true;
res.data = { { "content", error } }; res.data = format_error_response(error, type);
queue_results.send(res); queue_results.send(res);
} }
@ -1450,7 +1458,7 @@ struct server_context {
if (!launch_slot_with_data(*slot, task.data)) { if (!launch_slot_with_data(*slot, task.data)) {
// send error result // send error result
send_error(task, "internal_error"); send_error(task, "Cannot launch slot");
break; break;
} }
} break; } break;
@ -1567,7 +1575,7 @@ struct server_context {
queue_results.send(result); queue_results.send(result);
} }
bool update_slots() { void run_slots() {
if (system_need_update) { if (system_need_update) {
system_prompt_update(); system_prompt_update();
} }
@ -1610,7 +1618,7 @@ struct server_context {
kv_cache_clear(); kv_cache_clear();
} }
return true; return;
} }
} }
@ -1955,8 +1963,7 @@ struct server_context {
if (batch.n_tokens == 0) { if (batch.n_tokens == 0) {
LOG_VERBOSE("no tokens to decode", {}); LOG_VERBOSE("no tokens to decode", {});
return;
return true;
} }
LOG_VERBOSE("decoding batch", { LOG_VERBOSE("decoding batch", {
@ -2013,7 +2020,11 @@ struct server_context {
if (n_batch == 1 || ret < 0) { if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size // if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return false; for (auto & slot : slots) {
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
slot.release();
}
break; // break loop of n_batch
} }
LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2); LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
@ -2022,12 +2033,12 @@ struct server_context {
n_batch /= 2; n_batch /= 2;
i -= n_batch; i -= n_batch;
continue; continue; // continue loop of n_batch
} }
for (auto & slot : slots) { for (auto & slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
continue; continue; // continue loop of slots
} }
// prompt evaluated for embedding // prompt evaluated for embedding
@ -2035,7 +2046,7 @@ struct server_context {
send_embedding(slot, batch_view); send_embedding(slot, batch_view);
slot.release(); slot.release();
slot.i_batch = -1; slot.i_batch = -1;
continue; continue; // continue loop of slots
} }
completion_token_output result; completion_token_output result;
@ -2077,9 +2088,7 @@ struct server_context {
} }
} }
LOG_VERBOSE("slots updated", {}); LOG_VERBOSE("run slots completed", {});
return true;
} }
json model_meta() const { json model_meta() const {
@ -2716,32 +2725,32 @@ int main(int argc, char ** argv) {
svr->set_logger(log_server_request); svr->set_logger(log_server_request);
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { auto res_error = [](httplib::Response & res, json error_data) {
const char fmt[] = "500 Internal Server Error\n%s"; json final_response {{"error", error_data}};
res.set_content(final_response.dump(), "application/json; charset=utf-8");
res.status = json_value(error_data, "code", 500);
};
char buf[BUFSIZ]; svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
std::string message;
try { try {
std::rethrow_exception(std::move(ep)); std::rethrow_exception(std::move(ep));
} catch (std::exception &e) { } catch (std::exception & e) {
snprintf(buf, sizeof(buf), fmt, e.what()); message = e.what();
} catch (...) { } catch (...) {
snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); message = "Unknown Exception";
} }
res.set_content(buf, "text/plain; charset=utf-8"); json formatted_error = format_error_response(message, ERROR_SERVER);
res.status = 500; LOG_VERBOSE("Got exception", formatted_error);
res_error(res, formatted_error);
}); });
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
if (res.status == 401) {
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
if (res.status == 400) {
res.set_content("Invalid request", "text/plain; charset=utf-8");
}
if (res.status == 404) { if (res.status == 404) {
res.set_content("File Not Found", "text/plain; charset=utf-8"); res_error(res, format_error_response("File Not Found", ERROR_NOT_FOUND));
} }
// for other error codes, we skip processing here because it's already done by res_error()
}); });
// set timeouts and change hostname and port // set timeouts and change hostname and port
@ -2789,7 +2798,7 @@ int main(int argc, char ** argv) {
// Middlewares // Middlewares
// //
auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) { auto middleware_validate_api_key = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
// TODO: should we apply API key to all endpoints, including "/health" and "/models"? // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
static const std::set<std::string> protected_endpoints = { static const std::set<std::string> protected_endpoints = {
"/props", "/props",
@ -2830,8 +2839,7 @@ int main(int argc, char ** argv) {
// API key is invalid or not provided // API key is invalid or not provided
// TODO: make another middleware for CORS related logic // TODO: make another middleware for CORS related logic
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8"); res_error(res, format_error_response("Invalid API Key", ERROR_AUTHENTICATION));
res.status = 401; // Unauthorized
LOG_WARNING("Unauthorized: Invalid API Key", {}); LOG_WARNING("Unauthorized: Invalid API Key", {});
@ -2894,21 +2902,18 @@ int main(int argc, char ** argv) {
} }
case SERVER_STATE_LOADING_MODEL: case SERVER_STATE_LOADING_MODEL:
{ {
res.set_content(R"({"status": "loading model"})", "application/json"); res_error(res, format_error_response("Loading model", ERROR_UNAVAILABLE));
res.status = 503; // HTTP Service Unavailable
} break; } break;
case SERVER_STATE_ERROR: case SERVER_STATE_ERROR:
{ {
res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); res_error(res, format_error_response("Model failed to load", ERROR_SERVER));
res.status = 500; // HTTP Internal Server Error
} break; } break;
} }
}; };
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
if (!sparams.slots_endpoint) { if (!sparams.slots_endpoint) {
res.status = 501; res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_NOT_SUPPORTED));
res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
return; return;
} }
@ -2932,8 +2937,7 @@ int main(int argc, char ** argv) {
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
if (!sparams.metrics_endpoint) { if (!sparams.metrics_endpoint) {
res.status = 501; res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_NOT_SUPPORTED));
res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
return; return;
} }
@ -3044,7 +3048,7 @@ int main(int argc, char ** argv) {
res.set_content(data.dump(), "application/json; charset=utf-8"); res.set_content(data.dump(), "application/json; charset=utf-8");
}; };
const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = json::parse(req.body); json data = json::parse(req.body);
@ -3059,8 +3063,7 @@ int main(int argc, char ** argv) {
if (!result.error && result.stop) { if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
} else { } else {
res.status = 500; res_error(res, result.data);
res.set_content(result.data["content"], "text/plain; charset=utf-8");
} }
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
@ -3140,7 +3143,7 @@ int main(int argc, char ** argv) {
res.set_content(models.dump(), "application/json; charset=utf-8"); res.set_content(models.dump(), "application/json; charset=utf-8");
}; };
const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) { const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
@ -3157,8 +3160,7 @@ int main(int argc, char ** argv) {
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
} else { } else {
res.status = 500; res_error(res, result.data);
res.set_content(result.data["content"], "text/plain; charset=utf-8");
} }
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
} else { } else {
@ -3212,7 +3214,7 @@ int main(int argc, char ** argv) {
} }
}; };
const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = json::parse(req.body); json data = json::parse(req.body);
@ -3227,8 +3229,7 @@ int main(int argc, char ** argv) {
if (!result.error && result.stop) { if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
} else { } else {
res.status = 404; res_error(res, result.data);
res.set_content(result.data["content"], "text/plain; charset=utf-8");
} }
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
@ -3299,7 +3300,7 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8"); return res.set_content(data.dump(), "application/json; charset=utf-8");
}; };
const auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) { if (!params.embedding) {
res.status = 501; res.status = 501;
@ -3345,7 +3346,12 @@ int main(int argc, char ** argv) {
// get the result // get the result
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
responses.push_back(result.data); if (!result.error) {
responses.push_back(result.data);
} else {
res_error(res, result.data);
return;
}
} }
// write JSON response // write JSON response
@ -3440,7 +3446,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.on_finish_multitask(std::bind( ctx_server.queue_tasks.on_finish_multitask(std::bind(
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
ctx_server.queue_tasks.on_run_slots(std::bind( ctx_server.queue_tasks.on_run_slots(std::bind(
&server_context::update_slots, &ctx_server)); &server_context::run_slots, &ctx_server));
ctx_server.queue_results.on_multitask_update(std::bind( ctx_server.queue_results.on_multitask_update(std::bind(
&server_queue::update_multitask, &server_queue::update_multitask,
&ctx_server.queue_tasks, &ctx_server.queue_tasks,

View file

@ -14,6 +14,17 @@
using json = nlohmann::json; using json = nlohmann::json;
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
// Category of an error reported to the client.
// format_error_response() maps each value to a "type" string and a numeric "code"
// (listed per enumerator below).
enum error_type {
ERROR_INVALID_REQUEST, // "invalid_request_error", code 400
ERROR_AUTHENTICATION, // "authentication_error", code 401
ERROR_SERVER, // "server_error", code 500
ERROR_NOT_FOUND, // "not_found_error", code 404
ERROR_PERMISSION, // "permission_error", code 403
ERROR_UNAVAILABLE, // custom error: "unavailable_error", code 503
ERROR_NOT_SUPPORTED, // custom error: "not_supported_error", code 501
};
extern bool server_verbose; extern bool server_verbose;
extern bool server_log_json; extern bool server_log_json;
@ -542,3 +553,43 @@ static json format_detokenized_response(const std::string & content) {
{"content", content} {"content", content}
}; };
} }
// Build the JSON object describing an error.
//
// The result has three keys:
//   "code"    - numeric code chosen from `type` (400, 401, 403, 404, 500, 501 or 503)
//   "message" - the caller-supplied text, copied verbatim
//   "type"    - string name of the error category
//
// If `type` matches none of the known enumerators, "type" is an empty string and
// "code" is 500.
static json format_error_response(const std::string & message, const enum error_type type) {
    struct error_info {
        const char * name;
        int          code;
    };

    // fallback used when no case below matches
    error_info info = { "", 500 };

    switch (type) {
        case ERROR_INVALID_REQUEST: info = { "invalid_request_error", 400 }; break;
        case ERROR_AUTHENTICATION:  info = { "authentication_error",  401 }; break;
        case ERROR_NOT_FOUND:       info = { "not_found_error",       404 }; break;
        case ERROR_SERVER:          info = { "server_error",          500 }; break;
        case ERROR_PERMISSION:      info = { "permission_error",      403 }; break;
        case ERROR_NOT_SUPPORTED:   info = { "not_supported_error",   501 }; break;
        case ERROR_UNAVAILABLE:     info = { "unavailable_error",     503 }; break;
    }

    const std::string type_name = info.name;
    return json {
        {"code",    info.code},
        {"message", message},
        {"type",    type_name},
    };
}