From 4840c4e67887c22e269b176d5be88506235cd339 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sat, 9 Mar 2024 21:03:18 +0100
Subject: [PATCH] correct coding style

---
 examples/server/server.cpp | 39 +++++++++++++++++++-------------------
 examples/server/utils.hpp  | 28 +++++++++++++--------------
 2 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e272fca15..38e60f9a1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -798,7 +798,7 @@ struct server_context {
         return last_used;
     }

-    bool launch_slot_with_data(server_slot & slot, json data, std::string & error_message) const {
+    bool launch_slot_with_data(server_slot & slot, json data) {
         slot_params default_params;
         llama_sampling_params default_sparams;

@@ -857,7 +857,8 @@ struct server_context {
         {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
-                slot.prompt = "";
+                send_error(slot, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
+                return false;
             } else {
                 slot.prompt = *prompt;
             }
@@ -981,7 +982,7 @@ struct server_context {
             slot.ctx_sampling = llama_sampling_init(slot.sparams);
             if (slot.ctx_sampling == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
-                error_message = "Failed to parse grammar";
+                send_error(slot, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
             llama_set_rng_seed(ctx, slot.params.seed);
@@ -1225,15 +1226,15 @@ struct server_context {
         };
     }

-    void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_SERVER) {
+    void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(task.id, task.id_multi, error, type);
     }

-    void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_SERVER) {
+    void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(slot.id_task, slot.id_multi, error, type);
     }

-    void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_SERVER) {
+    void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         LOG_TEE("task %i - error: %s\n", id_task, error.c_str());

         server_task_result res;
@@ -1475,10 +1476,8 @@ struct server_context {
                     slot->infill    = task.infill;
                     slot->embedding = task.embedding;

-                    std::string err_launch_slot;
-                    if (!launch_slot_with_data(*slot, task.data, err_launch_slot)) {
-                        // send error result
-                        send_error(task, err_launch_slot.empty() ? "Cannot launch slot" : err_launch_slot);
+                    if (!launch_slot_with_data(*slot, task.data)) {
+                        LOG_ERROR("error while launching slot", task.data);
                         break;
                     }
                 } break;
@@ -2041,8 +2040,10 @@ struct server_context {
                 // if you get here, it means the KV cache is full - try increasing it via the context size
                 LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
                 for (auto & slot : slots) {
-                    send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                    slot.state = SLOT_STATE_PROCESSING;
+                    slot.command = SLOT_COMMAND_NONE;
                     slot.release();
+                    send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
                 }
                 break; // break loop of n_batch
             }
@@ -2761,14 +2762,14 @@ int main(int argc, char ** argv) {
            message = "Unknown Exception";
        }

-        json formatted_error = format_error_response(message, ERROR_SERVER);
+        json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
        LOG_VERBOSE("Got exception", formatted_error);
        res_error(res, formatted_error);
    });

    svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
        if (res.status == 404) {
-            res_error(res, format_error_response("File Not Found", ERROR_NOT_FOUND));
+            res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
        }
        // for other error codes, we skip processing here because it's already done by res_error()
    });
@@ -2859,7 +2860,7 @@ int main(int argc, char ** argv) {
        // API key is invalid or not provided
        // TODO: make another middleware for CORS related logic
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        res_error(res, format_error_response("Invalid API Key", ERROR_AUTHENTICATION));
+        res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));

        LOG_WARNING("Unauthorized: Invalid API Key", {});

@@ -2922,18 +2923,18 @@ int main(int argc, char ** argv) {
                }
            case SERVER_STATE_LOADING_MODEL:
                {
-                    res_error(res, format_error_response("Loading model", ERROR_UNAVAILABLE));
+                    res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
                } break;
            case SERVER_STATE_ERROR:
                {
-                    res_error(res, format_error_response("Model failed to load", ERROR_SERVER));
+                    res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
                } break;
        }
    };

    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
        if (!sparams.slots_endpoint) {
-            res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_NOT_SUPPORTED));
+            res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

@@ -2957,7 +2958,7 @@ int main(int argc, char ** argv) {

    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
        if (!sparams.metrics_endpoint) {
-            res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_NOT_SUPPORTED));
+            res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

@@ -3349,7 +3350,7 @@ int main(int argc, char ** argv) {
            std::string content = body["content"];
            prompts.push_back(content);
        } else {
-            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_INVALID_REQUEST));
+            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index b53f6751e..0f7cf3d1b 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -16,13 +16,13 @@ using json = nlohmann::json;

 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
-    ERROR_INVALID_REQUEST,
-    ERROR_AUTHENTICATION,
-    ERROR_SERVER,
-    ERROR_NOT_FOUND,
-    ERROR_PERMISSION,
-    ERROR_UNAVAILABLE,   // custom error
-    ERROR_NOT_SUPPORTED, // custom error
+    ERROR_TYPE_INVALID_REQUEST,
+    ERROR_TYPE_AUTHENTICATION,
+    ERROR_TYPE_SERVER,
+    ERROR_TYPE_NOT_FOUND,
+    ERROR_TYPE_PERMISSION,
+    ERROR_TYPE_UNAVAILABLE,   // custom error
+    ERROR_TYPE_NOT_SUPPORTED, // custom error
 };

 extern bool server_verbose;
@@ -558,31 +558,31 @@ static json format_error_response(const std::string & message, const enum error_
     std::string type_str;
     int code = 500;
     switch (type) {
-        case ERROR_INVALID_REQUEST:
+        case ERROR_TYPE_INVALID_REQUEST:
             type_str = "invalid_request_error";
             code = 400;
             break;
-        case ERROR_AUTHENTICATION:
+        case ERROR_TYPE_AUTHENTICATION:
             type_str = "authentication_error";
             code = 401;
             break;
-        case ERROR_NOT_FOUND:
+        case ERROR_TYPE_NOT_FOUND:
             type_str = "not_found_error";
             code = 404;
             break;
-        case ERROR_SERVER:
+        case ERROR_TYPE_SERVER:
             type_str = "server_error";
             code = 500;
             break;
-        case ERROR_PERMISSION:
+        case ERROR_TYPE_PERMISSION:
             type_str = "permission_error";
             code = 403;
             break;
-        case ERROR_NOT_SUPPORTED:
+        case ERROR_TYPE_NOT_SUPPORTED:
             type_str = "not_supported_error";
             code = 501;
             break;
-        case ERROR_UNAVAILABLE:
+        case ERROR_TYPE_UNAVAILABLE:
             type_str = "unavailable_error";
             code = 503;
             break;
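
Note (not part of the patch): the renamed constants feed the type-string/status
mapping in the final hunk, which follows the OpenAI error schema linked in the
utils.hpp comment. Below is a minimal standalone sketch of that mapping for two
of the values, assuming format_error_response assembles its JSON from type_str
and code as the switch suggests; the function body and exact field names lie
outside these hunks, so they are an assumption here, not the server's verbatim
implementation.

    // Standalone sketch only: redeclares a trimmed error_type enum so it
    // compiles on its own; do not combine with utils.hpp.
    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    enum error_type {
        ERROR_TYPE_INVALID_REQUEST,
        ERROR_TYPE_SERVER,
        // ... remaining values as in utils.hpp
    };

    // Mirrors the switch in the patch: each error_type yields an
    // OpenAI-style type string plus the matching HTTP status code.
    static json sketch_error_response(const std::string & message, enum error_type type) {
        std::string type_str = "server_error"; // ERROR_TYPE_SERVER default
        int code = 500;
        if (type == ERROR_TYPE_INVALID_REQUEST) {
            type_str = "invalid_request_error";
            code = 400;
        }
        return json {
            {"message", message},
            {"type",    type_str},
            {"code",    code},
        };
    }

    int main() {
        // prints: {"code":400,"message":"\"prompt\" missing","type":"invalid_request_error"}
        std::cout << sketch_error_response("\"prompt\" missing", ERROR_TYPE_INVALID_REQUEST).dump() << "\n";
    }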