diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d907338a5..0c219c93f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -5,13 +5,6 @@
 #include "llama.h"
 #include "grammar-parser.h"

-#ifndef NDEBUG
-// crash the server in debug mode, otherwise send an http 500 error
-#define CPPHTTPLIB_NO_EXCEPTIONS 1
-#endif
-// increase max payload length to allow use of larger context size
-#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
-#include "httplib.h"
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
@@ -39,12 +32,12 @@
 #include
 #include
 #include
-#include <set>
 #include
 #include
 #include
 #include
 #include
+#include <unordered_set>

 using json = nlohmann::ordered_json;
@@ -97,6 +90,15 @@ struct server_task {
     json data;

     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
+    // utility function
+    static std::vector<int> get_list_id(std::vector<server_task> tasks) {
+        std::vector<int> ids(tasks.size());
+        for (size_t i = 0; i < tasks.size(); i++) {
+            ids[i] = tasks[i].id;
+        }
+        return ids;
+    }
 };

 struct server_task_result {
@@ -108,13 +110,6 @@ struct server_task_result {
     bool error;
 };

-struct server_task_multi {
-    int id = -1;
-
-    std::set<int>                   subtasks_remaining;
-    std::vector<server_task_result> results;
-};
-
 struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
@@ -133,6 +128,9 @@ struct server_slot {
     int id;
     int id_task = -1;

+    // the index relative to completion multi-task request
+    size_t index = 0;
+
     struct slot_params params;

     slot_state state = SLOT_STATE_IDLE;
@@ -388,15 +386,12 @@ struct server_queue {
    std::vector<server_task> queue_tasks;
    std::vector<server_task> queue_tasks_deferred;

-   std::vector<server_task_multi> queue_multitasks;
-
    std::mutex mutex_tasks;
    std::condition_variable condition_tasks;

    // callback functions
-   std::function<void(server_task &)>       callback_new_task;
-   std::function<void(server_task_multi &)> callback_finish_multitask;
-   std::function<void(void)>                callback_update_slots;
+   std::function<void(server_task &)> callback_new_task;
+   std::function<void(void)>          callback_update_slots;

    // Add a new task to the end of the queue
    int post(server_task task) {
@@ -437,11 +432,6 @@
        callback_new_task = std::move(callback);
    }

-   // Register function to process a multitask when it is finished
-   void on_finish_multitask(std::function<void(server_task_multi &)> callback) {
-       callback_finish_multitask = std::move(callback);
-   }
-
    // Register the function to be called when all slots data is ready to be processed
    void on_update_slots(std::function<void(void)> callback) {
        callback_update_slots = std::move(callback);
    }
@@ -490,22 +480,6 @@
                callback_new_task(task);
            }

-           LOG_VERBOSE("update_multitasks", {});
-
-           // check if we have any finished multitasks
-           auto queue_iterator = queue_multitasks.begin();
-           while (queue_iterator != queue_multitasks.end()) {
-               if (queue_iterator->subtasks_remaining.empty()) {
-                   // all subtasks done == multitask is done
-                   server_task_multi current_multitask = *queue_iterator;
-                   callback_finish_multitask(current_multitask);
-                   // remove this multitask
-                   queue_iterator = queue_multitasks.erase(queue_iterator);
-               } else {
-                   ++queue_iterator;
-               }
-           }
-
            // all tasks in the current loop is processed, slots data is now ready
            LOG_VERBOSE("callback_update_slots", {});
@@ -530,7 +504,7 @@

 struct server_response {
    // for keeping track of all tasks waiting for the result
-   std::set<int> waiting_task_ids;
+   std::unordered_set<int> waiting_task_ids;

    // the main result queue
    std::vector<server_task_result> queue_results;
@@ -1387,7 +1361,8 @@ struct server_context {
            {"content",    tkn.text_to_send},
tkn.text_to_send}, {"stop", false}, {"id_slot", slot.id}, - {"multimodal", false} + {"multimodal", false}, + {"index", slot.index}, }; if (slot.sparams.n_probs > 0) { @@ -1434,7 +1409,8 @@ struct server_context { {"stopped_limit", slot.stopped_limit}, {"stopping_word", slot.stopping_word}, {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()} + {"timings", slot.get_formated_timings()}, + {"index", slot.index}, }; if (slot.sparams.n_probs > 0) { @@ -1500,6 +1476,7 @@ struct server_context { res.data = json { {"embedding", embd_res}, + {"index", slot.index}, }; } @@ -1507,10 +1484,10 @@ struct server_context { } // - // functions to create new task(s) + // Functions to create new task(s) and receive result(s) // - std::vector request_completion(json data, server_task_cmpl_type cmpl_type) { + std::vector create_tasks_completion(json data, server_task_cmpl_type cmpl_type) { std::vector tasks; auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { server_task task; @@ -1535,12 +1512,16 @@ struct server_context { // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task if (prompt.is_string() || json_is_array_of_numbers(prompt)) { + data["index"] = 0; create_task(data, false, nullptr); } // otherwise, it's a multiple-prompt task, we break it into smaller tasks else if (prompt.is_array()) { - for (auto const & e : prompt) { + std::vector prompts = prompt; + for (size_t i = 0; i < prompts.size(); i++) { + const auto & e = prompts[i]; if (e.is_string() || json_is_array_of_numbers(e)) { + data["index"] = i; create_task(data, true, e); } else { throw std::runtime_error(error_msg); @@ -1555,17 +1536,61 @@ struct server_context { return tasks; } - void request_cancel(std::vector & tasks) { - for (const auto & t : tasks) { + void cancel_tasks(std::unordered_set & id_tasks) { + for (const auto & id_task : id_tasks) { + LOG_VERBOSE("cancel task", {{"id_task", id_task}}); server_task task; task.type = SERVER_TASK_TYPE_CANCEL; - task.id_target = t.id; + task.id_target = id_task; queue_tasks.post(task); + queue_results.remove_waiting_task_id(id_task); + } + } + + void receive_cmpl_results(std::vector & id_tasks, std::function&)> result_handler, std::function error_handler) { + std::unordered_set id_set(id_tasks.begin(), id_tasks.end()); + std::vector results(id_set.size()); + for (size_t i = 0; i < id_set.size(); i++) { + server_task_result result = queue_results.recv(id_set); + + if (result.error) { + error_handler(result.data); + cancel_tasks(id_set); + break; + } + + size_t idx = result.data["index"]; + results[idx] = result; + } + result_handler(results); + } + + void receive_cmpl_results_stream(std::vector & id_tasks, std::function result_handler, std::function error_handler) { + std::unordered_set id_set(id_tasks.begin(), id_tasks.end()); + size_t n_finished = 0; + while (true) { + server_task_result result = queue_results.recv(id_set); + if (!result_handler(result)) { + cancel_tasks(id_set); + break; + } + + if (result.error) { + error_handler(result.data); + cancel_tasks(id_set); + break; + } + + if (result.stop) { + if (++n_finished == id_set.size()) { + break; + } + } } } // - // functions to process the task + // Functions to process the task // void process_single_task(const server_task & task) { @@ -1614,6 +1639,7 @@ struct server_context { slot->id_task = task.id; slot->cmpl_type = task.cmpl_type; + slot->index = json_value(task.data, "index", 0); if (!launch_slot_with_task(*slot, task)) { LOG_ERROR("error 
while launching slot", task.data); @@ -1841,26 +1867,6 @@ struct server_context { } } - void on_finish_multitask(const server_task_multi & multitask) { - // all subtasks done == multitask is done - server_task_result result; - result.id = multitask.id; - result.stop = true; - result.error = false; - - // collect json results into one json result - std::vector result_jsons; - for (const auto & subres : multitask.results) { - result_jsons.push_back(subres.data); - result.error = result.error && subres.error; - } - result.data = json { - { "results", result_jsons } - }; - - queue_results.send(result); - } - void update_slots() { if (system_need_update) { system_prompt_update(); @@ -2556,6 +2562,11 @@ int main(int argc, char ** argv) { res.status = json_value(error_data, "code", 500); }; + auto res_ok = [](httplib::Response & res, json data) { + res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.status = 200; + }; + svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { std::string message; try { @@ -2603,7 +2614,7 @@ int main(int argc, char ** argv) { auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { // TODO: should we apply API key to all endpoints, including "/health" and "/models"? - static const std::set protected_endpoints = { + static const std::unordered_set protected_endpoints = { "/props", "/completion", "/completions", @@ -2932,81 +2943,106 @@ int main(int argc, char ** argv) { res.set_content(data.dump(), MIMETYPE_JSON); }; - const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { if (ctx_server.params.embedding) { res_error(res, format_error_response("This server does not support completions. 
            return;
        }

-       json data = json::parse(req.body);
-
-       std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+       std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, cmpl_type);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);

-       if (!json_value(data, "stream", false)) {
-           server_task_result result = ctx_server.queue_results.recv(tasks);
-           if (!result.error && result.stop) {
-               res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
-           } else {
-               res_error(res, result.data);
-           }
+       bool stream = json_value(data, "stream", false);
+       std::vector<int> task_ids = server_task::get_list_id(tasks);

-           ctx_server.queue_results.remove_waiting_tasks(tasks);
-       } else {
-           const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
-               while (true) {
-                   server_task_result result = ctx_server.queue_results.recv(tasks);
-                   if (!result.error) {
-                       const std::string str =
-                           "data: " +
-                           result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                           "\n\n";
-
-                       LOG_VERBOSE("data stream", {
-                           { "to_send", str }
-                       });
-
-                       if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_tasks(tasks);
-                           return false;
-                       }
-
-                       if (result.stop) {
-                           break;
-                       }
-                   } else {
-                       const std::string str =
-                           "error: " +
-                           result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                           "\n\n";
-
-                       LOG_VERBOSE("data stream", {
-                           { "to_send", str }
-                       });
-
-                       if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_tasks(tasks);
-                           return false;
-                       }
-
-                       break;
+       if (!stream) {
+           ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+               if (results.size() == 1) {
+                   // single result
+                   res_ok(res, results[0].data);
+               } else {
+                   // multiple results (multitask)
+                   json arr = json::array();
+                   for (const auto & res : results) {
+                       arr.push_back(res.data);
                    }
+                   res_ok(res, arr);
                }
-
-               ctx_server.queue_results.remove_waiting_tasks(tasks);
+           }, [&](json error_data) {
+               res_error(res, error_data);
+           });
+       } else {
+           const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) mutable {
+               ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
+                   return server_sent_event(sink, "data", result.data);
+               }, [&](json error_data) {
+                   server_sent_event(sink, "error", error_data);
+               });
                sink.done();
-               return true;
            };
+           res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+       }
+   };

-           auto on_complete = [tasks, &ctx_server](bool) mutable {
-               // cancel
-               ctx_server.request_cancel(tasks);
-               ctx_server.queue_results.remove_waiting_tasks(tasks);
+   const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+       json data = json::parse(req.body);
+       return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+   };
+
+   const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+       json data = json::parse(req.body);
+       return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+   };
+
+   // TODO: maybe merge this function with "handle_completions_generic"
+   const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+       if (ctx_server.params.embedding) {
+           res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); + + std::vector tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + + bool stream = json_value(data, "stream", false); + std::vector task_ids = server_task::get_list_id(tasks); + const auto completion_id = gen_chatcmplid(); + + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + // multitask is never support in chat completion, there is only one result + json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id); + res_ok(res, result_oai); + }, [&](json error_data) { + res_error(res, error_data); + }); + } else { + const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable { + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { + std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); + for (auto & event_data : result_array) { + if (event_data.empty()) { + continue; // skip the stop token + } + if (!server_sent_event(sink, "data", event_data)) { + return false; // connection is closed + } + } + return true; // ok + }, [&](json error_data) { + server_sent_event(sink, "error", error_data); + }); + std::string done_event = "[DONE]"; // OAI-compat behavior + sink.write(done_event.c_str(), done_event.size()); + sink.done(); + return true; }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + res.set_chunked_content_provider("text/event-stream", chunked_content_provider); } }; @@ -3027,142 +3063,6 @@ int main(int argc, char ** argv) { res.set_content(models.dump(), MIMETYPE_JSON); }; - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support chat completions. 
-           return;
-       }
-       json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-
-       std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
-       ctx_server.queue_results.add_waiting_tasks(tasks);
-       ctx_server.queue_tasks.post(tasks);
-
-       const auto completion_id = gen_chatcmplid();
-       if (!json_value(data, "stream", false)) {
-           server_task_result result = ctx_server.queue_results.recv(tasks);
-
-           if (!result.error && result.stop) {
-               json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
-
-               res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
-           } else {
-               res_error(res, result.data);
-           }
-           ctx_server.queue_results.remove_waiting_tasks(tasks);
-       } else {
-           const auto chunked_content_provider = [tasks, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable {
-               while (true) {
-                   server_task_result result = ctx_server.queue_results.recv(tasks);
-                   if (!result.error) {
-                       std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
-
-                       for (auto it = result_array.begin(); it != result_array.end(); ++it) {
-                           if (!it->empty()) {
-                               const std::string str =
-                                   "data: " +
-                                   it->dump(-1, ' ', false, json::error_handler_t::replace) +
-                                   "\n\n";
-                               LOG_VERBOSE("data stream", {{"to_send", str}});
-                               if (!sink.write(str.c_str(), str.size())) {
-                                   ctx_server.queue_results.remove_waiting_tasks(tasks);
-                                   return false;
-                               }
-                           }
-                       }
-                       if (result.stop) {
-                           break;
-                       }
-                   } else {
-                       const std::string str =
-                           "error: " +
-                           result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                           "\n\n";
-                       LOG_VERBOSE("data stream", {{"to_send", str}});
-                       if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_tasks(tasks);
-                           return false;
-                       }
-                       break;
-                   }
-               }
-               sink.done();
-               ctx_server.queue_results.remove_waiting_tasks(tasks);
-               return true;
-           };
-
-           auto on_complete = [tasks, &ctx_server](bool) mutable {
-               // cancel request
-               ctx_server.request_cancel(tasks);
-               ctx_server.queue_results.remove_waiting_tasks(tasks);
-           };
-
-           res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-       }
-   };
-
-   const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
-       if (ctx_server.params.embedding) {
-           res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-           return;
-       }
-
-       json data = json::parse(req.body);
-
-       std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_INFILL);
-       ctx_server.queue_results.add_waiting_tasks(tasks);
-       ctx_server.queue_tasks.post(tasks);
-
-       if (!json_value(data, "stream", false)) {
-           server_task_result result = ctx_server.queue_results.recv(tasks);
-           if (!result.error && result.stop) {
-               res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
-           } else {
-               res_error(res, result.data);
-           }
-
-           ctx_server.queue_results.remove_waiting_tasks(tasks);
-       } else {
-           const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
-               while (true) {
-                   server_task_result result = ctx_server.queue_results.recv(tasks);
-                   if (!result.error) {
-                       const std::string str =
-                           "data: " +
-                           result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                           "\n\n";
-
-                       LOG_VERBOSE("data stream", {
-                           { "to_send", str }
-                       });
-
-                       if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_tasks(tasks);
-                           return false;
-                       }
-
-                       if (result.stop) {
-                           break;
-                       }
-                   } else {
-                       break;
-                   }
-               }
-
-               ctx_server.queue_results.remove_waiting_tasks(tasks);
-               sink.done();
-
-               return true;
-           };
-
-           auto on_complete = [tasks, &ctx_server](bool) mutable {
-               ctx_server.request_cancel(tasks);
-           };
-
-           res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-       }
-   };
-
    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
        const json body = json::parse(req.body);
@@ -3208,7 +3108,7 @@
        // create and queue the task
        json responses;
        {
-           std::vector<server_task> tasks = ctx_server.request_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_INFILL);
+           std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(tasks);
@@ -3407,8 +3307,6 @@ int main(int argc, char ** argv) {

    ctx_server.queue_tasks.on_new_task(std::bind(
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
-   ctx_server.queue_tasks.on_finish_multitask(std::bind(
-       &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
    ctx_server.queue_tasks.on_update_slots(std::bind(
        &server_context::update_slots, &ctx_server));
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index a8414d3b8..edfce65b6 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -3,6 +3,14 @@
 #include "llama.h"
 #include "common.h"

+#ifndef NDEBUG
+// crash the server in debug mode, otherwise send an http 500 error
+#define CPPHTTPLIB_NO_EXCEPTIONS 1
+#endif
+// increase max payload length to allow use of larger context size
+#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+#include "httplib.h"
+
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
@@ -355,6 +363,19 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector
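
The utils.hpp hunk above is cut off, so the lines it adds are not shown. Judging from the calls introduced in the new handlers (server_sent_event(sink, "data", result.data) and server_sent_event(sink, "error", error_data), the first of which is used as a bool), it presumably defines a small server-sent-events helper. The sketch below is for orientation only: the name server_sent_event and the httplib::DataSink usage are taken from the calls in this diff, but the body is an assumption, not the PR's verbatim code.

// sketch only (assumed implementation): write one SSE frame "<event>: <json>\n\n" to the chunked sink
static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
    const std::string str =
        std::string(event) + ": " +
        data.dump(-1, ' ', false, json::error_handler_t::replace) +
        "\n\n"; // the trailing blank line terminates the SSE frame

    LOG_VERBOSE("data stream", {{"to_send", str}});

    // sink.write() returns false once the client has disconnected
    return sink.write(str.c_str(), str.size());
}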