correct coding style

This commit is contained in:
ngxson 2024-03-09 21:03:18 +01:00
parent 0f641ebd63
commit 4840c4e678
2 changed files with 34 additions and 33 deletions

View file

@ -798,7 +798,7 @@ struct server_context {
return last_used; return last_used;
} }
bool launch_slot_with_data(server_slot & slot, json data, std::string & error_message) const { bool launch_slot_with_data(server_slot & slot, json data) {
slot_params default_params; slot_params default_params;
llama_sampling_params default_sparams; llama_sampling_params default_sparams;
@ -857,7 +857,8 @@ struct server_context {
{ {
const auto & prompt = data.find("prompt"); const auto & prompt = data.find("prompt");
if (prompt == data.end()) { if (prompt == data.end()) {
slot.prompt = ""; send_error(slot, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false;
} else { } else {
slot.prompt = *prompt; slot.prompt = *prompt;
} }
@ -981,7 +982,7 @@ struct server_context {
slot.ctx_sampling = llama_sampling_init(slot.sparams); slot.ctx_sampling = llama_sampling_init(slot.sparams);
if (slot.ctx_sampling == nullptr) { if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid grammar // for now, the only error that may happen here is invalid grammar
error_message = "Failed to parse grammar"; send_error(slot, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} }
llama_set_rng_seed(ctx, slot.params.seed); llama_set_rng_seed(ctx, slot.params.seed);
@ -1225,15 +1226,15 @@ struct server_context {
}; };
} }
void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_SERVER) { void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(task.id, task.id_multi, error, type); send_error(task.id, task.id_multi, error, type);
} }
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_SERVER) { void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(slot.id_task, slot.id_multi, error, type); send_error(slot.id_task, slot.id_multi, error, type);
} }
void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_SERVER) { void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
LOG_TEE("task %i - error: %s\n", id_task, error.c_str()); LOG_TEE("task %i - error: %s\n", id_task, error.c_str());
server_task_result res; server_task_result res;
@ -1475,10 +1476,8 @@ struct server_context {
slot->infill = task.infill; slot->infill = task.infill;
slot->embedding = task.embedding; slot->embedding = task.embedding;
std::string err_launch_slot; if (!launch_slot_with_data(*slot, task.data)) {
if (!launch_slot_with_data(*slot, task.data, err_launch_slot)) { LOG_ERROR("error while launching slot", task.data);
// send error result
send_error(task, err_launch_slot.empty() ? "Cannot launch slot" : err_launch_slot);
break; break;
} }
} break; } break;
@ -2041,8 +2040,10 @@ struct server_context {
// if you get here, it means the KV cache is full - try increasing it via the context size // if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
for (auto & slot : slots) { for (auto & slot : slots) {
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
slot.release(); slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
} }
break; // break loop of n_batch break; // break loop of n_batch
} }
@ -2761,14 +2762,14 @@ int main(int argc, char ** argv) {
message = "Unknown Exception"; message = "Unknown Exception";
} }
json formatted_error = format_error_response(message, ERROR_SERVER); json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
LOG_VERBOSE("Got exception", formatted_error); LOG_VERBOSE("Got exception", formatted_error);
res_error(res, formatted_error); res_error(res, formatted_error);
}); });
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) { svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) { if (res.status == 404) {
res_error(res, format_error_response("File Not Found", ERROR_NOT_FOUND)); res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
} }
// for other error codes, we skip processing here because it's already done by res_error() // for other error codes, we skip processing here because it's already done by res_error()
}); });
@ -2859,7 +2860,7 @@ int main(int argc, char ** argv) {
// API key is invalid or not provided // API key is invalid or not provided
// TODO: make another middleware for CORS related logic // TODO: make another middleware for CORS related logic
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res_error(res, format_error_response("Invalid API Key", ERROR_AUTHENTICATION)); res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
LOG_WARNING("Unauthorized: Invalid API Key", {}); LOG_WARNING("Unauthorized: Invalid API Key", {});
@ -2922,18 +2923,18 @@ int main(int argc, char ** argv) {
} }
case SERVER_STATE_LOADING_MODEL: case SERVER_STATE_LOADING_MODEL:
{ {
res_error(res, format_error_response("Loading model", ERROR_UNAVAILABLE)); res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
} break; } break;
case SERVER_STATE_ERROR: case SERVER_STATE_ERROR:
{ {
res_error(res, format_error_response("Model failed to load", ERROR_SERVER)); res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
} break; } break;
} }
}; };
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
if (!sparams.slots_endpoint) { if (!sparams.slots_endpoint) {
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -2957,7 +2958,7 @@ int main(int argc, char ** argv) {
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
if (!sparams.metrics_endpoint) { if (!sparams.metrics_endpoint) {
res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -3349,7 +3350,7 @@ int main(int argc, char ** argv) {
std::string content = body["content"]; std::string content = body["content"];
prompts.push_back(content); prompts.push_back(content);
} else { } else {
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_INVALID_REQUEST)); res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }

View file

@ -16,13 +16,13 @@ using json = nlohmann::json;
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type { enum error_type {
ERROR_INVALID_REQUEST, ERROR_TYPE_INVALID_REQUEST,
ERROR_AUTHENTICATION, ERROR_TYPE_AUTHENTICATION,
ERROR_SERVER, ERROR_TYPE_SERVER,
ERROR_NOT_FOUND, ERROR_TYPE_NOT_FOUND,
ERROR_PERMISSION, ERROR_TYPE_PERMISSION,
ERROR_UNAVAILABLE, // custom error ERROR_TYPE_UNAVAILABLE, // custom error
ERROR_NOT_SUPPORTED, // custom error ERROR_TYPE_NOT_SUPPORTED, // custom error
}; };
extern bool server_verbose; extern bool server_verbose;
@ -558,31 +558,31 @@ static json format_error_response(const std::string & message, const enum error_
std::string type_str; std::string type_str;
int code = 500; int code = 500;
switch (type) { switch (type) {
case ERROR_INVALID_REQUEST: case ERROR_TYPE_INVALID_REQUEST:
type_str = "invalid_request_error"; type_str = "invalid_request_error";
code = 400; code = 400;
break; break;
case ERROR_AUTHENTICATION: case ERROR_TYPE_AUTHENTICATION:
type_str = "authentication_error"; type_str = "authentication_error";
code = 401; code = 401;
break; break;
case ERROR_NOT_FOUND: case ERROR_TYPE_NOT_FOUND:
type_str = "not_found_error"; type_str = "not_found_error";
code = 404; code = 404;
break; break;
case ERROR_SERVER: case ERROR_TYPE_SERVER:
type_str = "server_error"; type_str = "server_error";
code = 500; code = 500;
break; break;
case ERROR_PERMISSION: case ERROR_TYPE_PERMISSION:
type_str = "permission_error"; type_str = "permission_error";
code = 403; code = 403;
break; break;
case ERROR_NOT_SUPPORTED: case ERROR_TYPE_NOT_SUPPORTED:
type_str = "not_supported_error"; type_str = "not_supported_error";
code = 501; code = 501;
break; break;
case ERROR_UNAVAILABLE: case ERROR_TYPE_UNAVAILABLE:
type_str = "unavailable_error"; type_str = "unavailable_error";
code = 503; code = 503;
break; break;