From 8a8aaee714dd4f200e4534995ebd28934c6ab08c Mon Sep 17 00:00:00 2001
From: ZXED
Date: Fri, 8 Mar 2024 20:26:27 +0300
Subject: [PATCH] server: error handling: fixes after merge

---
 examples/server/server.cpp | 52 ++++++++++++++++++++++++++++++++++----
 1 file changed, 47 insertions(+), 5 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index de247f12f..06707bb51 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1211,6 +1211,31 @@ struct server_context {
         queue_results.send(res);
     }
 
+    // Serialize an error as {"error": {"id": ..., "description": ...}}.
+    static json error_to_json(const llama_error& error)
+    {
+        return {
+            { "error", {
+                { "id", error.id() },
+                { "description", error.description() }
+            } }
+        };
+    }
+
+    // Log the error and queue an error result for `task`; the "content"
+    // field carries the error_to_json() payload dumped to a string.
+    void send_error(const server_task & task, const llama_error& error)
+    {
+        LOG_TEE("task %i - error: %s - %s\n", task.id, error.id().c_str(), error.description().c_str());
+        server_task_result res;
+        res.id = task.id;
+        res.id_multi = task.id_multi;
+        res.stop = false;
+        res.error = true;
+        res.data = { { "content", error_to_json(error).dump() } };
+        queue_results.send(res);
+    }
+
     void send_partial_response(server_slot & slot, completion_token_output tkn) {
         server_task_result res;
         res.id = slot.id_task;
@@ -1436,10 +1461,16 @@ struct server_context {
                     slot->infill = task.infill;
                     slot->embedding = task.embedding;
 
-                    if (!launch_slot_with_data(*slot, task.data)) {
-                        // send error result
-                        send_error(task, "internal_error");
-                        break;
+                    try {
+                        if (!launch_slot_with_data(*slot, task.data))
+                        {
+                            // send error result
+                            send_error(task, "internal_error");
+                            break;
+                        }
+                    } catch (const llama_error & err) {
+                        // report the structured error; control then reaches the case's own break
+                        send_error(task, err);
                     }
                 } break;
             case SERVER_TASK_TYPE_CANCEL:
@@ -2953,7 +2984,18 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        json data = json::parse(req.body);
+        json data;
+        try {
+            data = json::parse(req.body);
+        } catch (const json::exception & json_err) {
+            // Malformed body is a client error: report it as 400 Bad Request.
+            const auto err = llama_error("request.invalid_json", std::string("Invalid JSON: ") + json_err.what());
+            // what() can echo bytes from the request body; substitute invalid UTF-8 so dump() cannot throw inside this handler.
+            const auto err_json = server_context::error_to_json(err).dump(-1, ' ', false, json::error_handler_t::replace);
+            res.status = 400;
+            res.set_content(err_json, "text/plain; charset=utf-8");
+            return;
+        }
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
 