From 1da67a395cd683469e0397d1496618bcb2725cc2 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sun, 29 Sep 2024 01:08:16 +0100
Subject: [PATCH] `server`: support cancelling non-streamed requests

---
 examples/server/server.cpp | 336 +++++++++++++++++++++----------------
 1 file changed, 188 insertions(+), 148 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index f343cc252..1ce4d7e26 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -33,6 +33,7 @@
 
 #include
 #include
+#include <functional>
 #include
 #include
 #include
@@ -104,6 +105,7 @@ struct server_task {
     json data;
 
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    std::function<bool()> is_alive;
 
     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -173,7 +175,7 @@ struct server_slot {
     std::vector<completion_token_output> generated_token_probs;
 
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-
+    std::function<bool()> is_alive;
     bool has_next_token = true;
     bool truncated      = false;
     bool stopped_eos    = false;
@@ -876,6 +878,7 @@ struct server_context {
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
         auto default_sparams = params.sparams;
         const auto & data = task.data;
+        slot.is_alive = task.is_alive;
 
         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;
@@ -1117,6 +1120,13 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
+        if (!slot.is_alive()) {
+            slot.truncated      = false;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped by client disconnection, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
+            return slot.has_next_token;
+        }
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
         slot.sampled = result.tok;
@@ -1461,13 +1471,14 @@ struct server_context {
     // Functions to create new task(s) and receive result(s)
     //
 
-    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type, const std::function<bool()> & is_alive) {
         std::vector<server_task> tasks;
         auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
             server_task task;
             task.id        = queue_tasks.get_new_id();
             task.cmpl_type = cmpl_type;
             task.type      = SERVER_TASK_TYPE_COMPLETION;
+            task.is_alive  = is_alive;
             if (replace_prompt) {
                 task.data = task_data;
                 task.data["prompt"] = std::move(prompt);
@@ -2412,6 +2423,60 @@ inline void signal_handler(int signal) {
     shutdown_handler(signal);
 }
 
+static void handle_tasks(
+    bool stream,
+    httplib::Response & res,
+    server_context & ctx_server,
+    const std::function<std::unordered_set<int>(const std::function<bool()> &)> & create_tasks,
+    const std::function<void(const std::unordered_set<int> &, httplib::DataSink & sink, const std::function<bool()> &)> & payload)
+{
+    struct State {
+        std::unordered_set<int> task_ids;
+    };
+    auto state = std::make_shared<State>();
+    httplib::ContentProviderResourceReleaser resource_releaser = [state, &ctx_server](bool success) {
+        if (!success && state) {
+            ctx_server.cancel_tasks(state->task_ids);
+        }
+    };
+    if (!stream) {
+        res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) {
+            auto is_alive = [&sink]() { return sink.is_writable(); };
+            state->task_ids = create_tasks(is_alive);
+            payload(state->task_ids, sink, is_alive);
+            ctx_server.queue_results.remove_waiting_task_ids(state->task_ids);
+            return false;
+        }, resource_releaser);
+    } else {
+        res.set_chunked_content_provider("text/event-stream", [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) {
+            auto is_alive = [&sink]() { return sink.is_writable(); };
+            state->task_ids = create_tasks(is_alive);
+            payload(state->task_ids, sink, is_alive);
+            ctx_server.queue_results.remove_waiting_task_ids(state->task_ids);
+            return false;
+        }, resource_releaser);
+    }
+}
+
+static void respond(httplib::Response & res, httplib::DataSink * sink, int status, const json & response) {
+    res.status = status;
+    if (sink) {
+        res.set_header("Content-Type", MIMETYPE_JSON);
+        auto out = response.dump(-1, ' ', false, json::error_handler_t::replace);
+        sink->write(out.c_str(), out.size());
+    } else {
+        res.set_content(response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
+    }
+}
+
+static void res_error(httplib::Response & res, httplib::DataSink * sink, const json & error_data) {
+    respond(res, sink, json_value(error_data, "code", 500), {{"error", error_data}});
+}
+
+static void res_ok(httplib::Response & res, httplib::DataSink * sink, const json & data) {
+    respond(res, sink, 200, data);
+}
+
 int main(int argc, char ** argv) {
     // own arguments required by this example
     gpt_params params;
@@ -2479,18 +2544,7 @@ int main(int argc, char ** argv) {
 
     svr->set_logger(log_server_request);
 
-    auto res_error = [](httplib::Response & res, const json & error_data) {
-        json final_response {{"error", error_data}};
-        res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
-        res.status = json_value(error_data, "code", 500);
-    };
-
-    auto res_ok = [](httplib::Response & res, const json & data) {
-        res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
-        res.status = 200;
-    };
-
-    svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
         std::string message;
         try {
             std::rethrow_exception(ep);
@@ -2502,12 +2556,12 @@ int main(int argc, char ** argv) {
 
         json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
         LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
-        res_error(res, formatted_error);
+        res_error(res, /* sink= */ nullptr, formatted_error);
    });
 
-    svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
+    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
         if (res.status == 404) {
-            res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
+            res_error(res, /* sink= */ nullptr, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
         }
         // for other error codes, we skip processing here because it's already done by res_error()
     });
@@ -2535,7 +2589,7 @@ int main(int argc, char ** argv) {
     //
     // Middlewares
     //
 
-    auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
+    auto middleware_validate_api_key = [&params](const httplib::Request & req, httplib::Response & res) {
         // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
         static const std::unordered_set<std::string> protected_endpoints = {
             "/props",
@@ -2574,14 +2628,14 @@ int main(int argc, char ** argv) {
         }
 
         // API key is invalid or not provided
-        res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
+        res_error(res, /* sink= */ nullptr, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
 
         LOG_WRN("Unauthorized: Invalid API Key\n");
 
         return false;
     };
 
-    auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
+    auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) {
         server_state current_state = state.load();
         if (current_state == SERVER_STATE_LOADING_MODEL) {
             auto tmp = string_split(req.path, '.');
@@ -2589,7 +2643,7 @@ int main(int argc, char ** argv) {
                 res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
                 res.status = 503;
             } else {
-                res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+                res_error(res, /* sink= */ nullptr, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
             }
             return false;
         }
@@ -2615,12 +2669,12 @@ int main(int argc, char ** argv) {
     const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
         // error and loading states are handled by middleware
         json health = {{"status", "ok"}};
-        res_ok(res, health);
+        res_ok(res, /* sink= */ nullptr, health);
     };
 
     const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
         if (!params.endpoint_slots) {
-            res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -2640,17 +2694,17 @@ int main(int argc, char ** argv) {
         const int n_idle_slots = result.data.at("idle");
 
         if (req.has_param("fail_on_no_slot")) {
             if (n_idle_slots == 0) {
-                res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
+                res_error(res, /* sink= */ nullptr, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
                 return;
             }
         }
 
-        res_ok(res, result.data.at("slots"));
+        res_ok(res, /* sink= */ nullptr, result.data.at("slots"));
     };
 
     const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
         if (!params.endpoint_metrics) {
-            res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, /* sink= */ nullptr, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -2759,11 +2813,11 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
-    const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_save = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
-            res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+            res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
         std::string filepath = params.slot_save_path + filename;
@@ -2783,17 +2837,17 @@ int main(int argc, char ** argv) {
         ctx_server.queue_results.remove_waiting_task_id(id_task);
 
         if (result.error) {
-            res_error(res, result.data);
+            res_error(res, /* sink= */ nullptr, result.data);
         } else {
-            res_ok(res, result.data);
+            res_ok(res, /* sink= */ nullptr, result.data);
         }
     };
 
-    const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_restore = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
-            res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+            res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
         std::string filepath = params.slot_save_path + filename;
@@ -2813,13 +2867,13 @@ int main(int argc, char ** argv) {
         ctx_server.queue_results.remove_waiting_task_id(id_task);
 
         if (result.error) {
-            res_error(res, result.data);
+            res_error(res, /* sink= */ nullptr, result.data);
         } else {
-            res_ok(res, result.data);
+            res_ok(res, /* sink= */ nullptr, result.data);
         }
     };
 
-    const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
+    const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
         server_task task;
         task.type = SERVER_TASK_TYPE_SLOT_ERASE;
         task.data = {
@@ -2833,15 +2887,15 @@ int main(int argc, char ** argv) {
         ctx_server.queue_results.remove_waiting_task_id(id_task);
 
         if (result.error) {
-            res_error(res, result.data);
+            res_error(res, /* sink= */ nullptr, result.data);
         } else {
-            res_ok(res, result.data);
+            res_ok(res, /* sink= */ nullptr, result.data);
        }
     };
 
-    const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_slots_action = [&params, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
         if (params.slot_save_path.empty()) {
-            res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -2851,7 +2905,7 @@ int main(int argc, char ** argv) {
         try {
             id_slot = std::stoi(id_slot_str);
         } catch (const std::exception &) {
-            res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
+            res_error(res, /* sink= */ nullptr, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
@@ -2864,11 +2918,11 @@ int main(int argc, char ** argv) {
         } else if (action == "erase") {
             handle_slots_erase(req, res, id_slot);
         } else {
-            res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
+            res_error(res, /* sink= */ nullptr, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
         }
     };
 
-    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
         std::string template_key = "tokenizer.chat_template", curr_tmpl;
         int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
         if (tlen > 0) {
@@ -2884,57 +2938,49 @@ int main(int argc, char ** argv) {
             { "chat_template", curr_tmpl.c_str() },
         };
 
-        res_ok(res, data);
+        res_ok(res, /* sink= */ nullptr, data);
     };
 
-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
         if (ctx_server.params.embedding || ctx_server.params.reranking) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(tasks);
-
         bool stream = json_value(data, "stream", false);
-        const auto task_ids = server_task::get_list_id(tasks);
 
-        if (!stream) {
-            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
-                if (results.size() == 1) {
-                    // single result
-                    res_ok(res, results[0].data);
-                } else {
-                    // multiple results (multitask)
-                    json arr = json::array();
-                    for (const auto & res : results) {
-                        arr.push_back(res.data);
+        handle_tasks(stream, res, ctx_server, [data, cmpl_type, &ctx_server](const std::function<bool()> & is_alive) {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, is_alive);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            return server_task::get_list_id(tasks);
+        }, [stream, &res, &ctx_server](const std::unordered_set<int> & task_ids, httplib::DataSink & sink, const std::function<bool()> &) {
+            if (!stream) {
+                ctx_server.receive_cmpl_results(task_ids, [&res, &sink](std::vector<server_task_result> & results) {
+                    if (results.size() == 1) {
+                        // single result
+                        res_ok(res, &sink, results[0].data);
+                    } else {
+                        // multiple results (multitask)
+                        json arr = json::array();
+                        for (const auto & res : results) {
+                            arr.push_back(res.data);
+                        }
+                        res_ok(res, &sink, arr);
                     }
-                    res_ok(res, arr);
-                }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            });
-
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
+                }, [&res, &sink](json error_data) {
+                    res_error(res, &sink, error_data);
+                });
+            } else {
+                ctx_server.receive_cmpl_results_stream(task_ids, [&sink](server_task_result result) -> bool {
                     return server_sent_event(sink, "data", result.data);
-                }, [&](const json & error_data) {
+                }, [&sink](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
                 });
                 sink.done();
-                return false;
-            };
-
-            auto on_complete = [task_ids, &ctx_server] (bool) {
-                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-            };
-
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-        }
+            }
+        });
     };
 
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -2948,35 +2994,37 @@ int main(int argc, char ** argv) {
     };
 
     // TODO: maybe merge this function with "handle_completions_generic"
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, verbose](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params.embedding || ctx_server.params.reranking) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
            return;
         }
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(tasks);
-
         bool stream = json_value(data, "stream", false);
-        const auto task_ids = server_task::get_list_id(tasks);
-        const auto completion_id = gen_chatcmplid();
 
-        if (!stream) {
-            ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
-                // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
-                res_ok(res, result_oai);
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            });
-
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
+        handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function<bool()> & is_alive) {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            return server_task::get_list_id(tasks);
+        }, [data, verbose, stream, &res, &ctx_server](const std::unordered_set<int> & task_ids, httplib::DataSink & sink, const std::function<bool()> & is_alive) {
+            const auto completion_id = gen_chatcmplid();
+            if (!stream) {
+                ctx_server.receive_cmpl_results(task_ids, [completion_id, data, verbose, &sink, &res](std::vector<server_task_result> & results) {
+                    // multitask is never supported in chat completion, there is only one result
+                    json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
+                    res_ok(res, &sink, result_oai);
+                }, [&res, &sink](json error_data) {
+                    res_error(res, &sink, error_data);
+                });
+            } else {
+                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
+                    if (!is_alive()) {
+                        return false; // connection is closed
+                    }
                     std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
                     for (auto & event_data : result_array) {
                         if (event_data.empty()) {
@@ -2993,15 +3041,8 @@ int main(int argc, char ** argv) {
             static const std::string ev_done = "data: [DONE]\n\n";
             sink.write(ev_done.data(), ev_done.size());
             sink.done();
-            return true;
-        };
-
-        auto on_complete = [task_ids, &ctx_server] (bool) {
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        };
-
-        res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-    }
+            }
+        });
     };
 
     const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
@@ -3021,7 +3062,7 @@ int main(int argc, char ** argv) {
         res.set_content(models.dump(), MIMETYPE_JSON);
     };
 
-    const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         json tokens_response = json::array();
@@ -3057,10 +3098,10 @@ int main(int argc, char ** argv) {
         }
 
         const json data = format_tokenizer_response(tokens_response);
-        res_ok(res, data);
+        res_ok(res, /* sink= */ nullptr, data);
     };
 
-    const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         std::string content;
@@ -3070,13 +3111,13 @@ int main(int argc, char ** argv) {
         }
 
         const json data = format_detokenized_response(content);
-        res_ok(res, data);
+        res_ok(res, /* sink= */ nullptr, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         // TODO: somehow clean up this checks in the future
         if (!ctx_server.params.embedding || ctx_server.params.reranking) {
-            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, /* sink= */ nullptr, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
         const json body = json::parse(req.body);
@@ -3091,47 +3132,46 @@ int main(int argc, char ** argv) {
             // with "content", we only support single prompt
             prompt = std::vector<std::string>{body.at("content")};
         } else {
-            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            res_error(res, /* sink= */ nullptr, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
-        // create and queue the task
-        json responses = json::array();
-        bool error = false;
-        {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+        handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function<bool()> & is_alive) {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
-            // get the result
-            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+            return server_task::get_list_id(tasks);
+        }, [is_openai, &ctx_server, &res, body](const std::unordered_set<int> & task_ids, httplib::DataSink & sink, const std::function<bool()> &) {
+            bool error = false;
+            json responses = json::array();
 
-            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+            ctx_server.receive_cmpl_results(task_ids, [&responses](std::vector<server_task_result> & results) {
                 for (const auto & res : results) {
                     responses.push_back(res.data);
                }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
+            }, [&res, &error](json error_data) {
+                res_error(res, /* sink= */ nullptr, error_data);
                 error = true;
             });
 
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        }
+            if (error) {
+                return;
+            }
 
-        if (error) {
-            return;
-        }
-
-        // write JSON response
-        json root = is_openai
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses[0];
-        res_ok(res, root);
+            // write JSON response
+            json root = is_openai
+                ? format_embeddings_response_oaicompat(body, responses)
+                : responses[0];
+
+            res_ok(res, &sink, root);
+        });
     };
 
-    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params.reranking) {
-            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, /* sink= */ nullptr, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
         const json body = json::parse(req.body);
@@ -3149,17 +3189,17 @@ int main(int argc, char ** argv) {
         if (body.count("query") == 1) {
             query = body.at("query");
             if (!query.is_string()) {
-                res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+                res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
                 return;
            }
         } else {
-            res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
         std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
         if (documents.empty()) {
-            res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+            res_error(res, /* sink= */ nullptr, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
            return;
        }
 
@@ -3176,7 +3216,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, []() { return true; });
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
@@ -3188,7 +3228,7 @@ int main(int argc, char ** argv) {
                 responses.push_back(res.data);
             }
         }, [&](const json & error_data) {
-            res_error(res, error_data);
+            res_error(res, /* sink= */ nullptr, error_data);
             error = true;
         });
     }
@@ -3199,7 +3239,7 @@ int main(int argc, char ** argv) {
 
         // write JSON response
         json root = format_response_rerank(body, responses);
-        res_ok(res, root);
+        res_ok(res, /* sink= */ nullptr, root);
     };
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
         for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
             auto & lora = ctx_server.loras[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
                 {"scale", lora.scale},
             });
         }
-        res_ok(res, result);
+        res_ok(res, /* sink= */ nullptr, result);
         res.status = 200; // HTTP OK
     };
 
@@ -3244,7 +3284,7 @@ int main(int argc, char ** argv) {
         server_task_result result = ctx_server.queue_results.recv(id_task);
         ctx_server.queue_results.remove_waiting_task_id(id_task);
 
-        res_ok(res, result.data);
+        res_ok(res, /* sink= */ nullptr, result.data);
         res.status = 200; // HTTP OK
     };
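
Usage illustration (not part of the diff): a minimal client-side sketch of the
scenario this patch addresses, written against cpp-httplib, which the server
already embeds. The host, port, endpoint and JSON body below are illustrative
assumptions; the httplib calls used (Client, set_read_timeout, Post) are
existing API. Before this change the server kept generating for a dead
connection; with it, sink.is_writable() turns false, process_token() stops the
slot, and the resource_releaser cancels any still-queued tasks.

    // Abort a non-streamed completion by letting the client time out and
    // close its socket while the server is still decoding.
    #include "httplib.h"

    #include <cstdio>

    int main() {
        httplib::Client cli("localhost", 8080); // assumed server address

        // Give up after 2 seconds; on timeout httplib closes the connection.
        cli.set_read_timeout(2, 0);

        auto res = cli.Post("/completion",
                            R"({"prompt": "Hello", "n_predict": 4096, "stream": false})",
                            "application/json");
        if (!res) {
            // The request was aborted client-side; the server should now
            // free the slot instead of generating all 4096 tokens.
            std::printf("request aborted: error %d\n", static_cast<int>(res.error()));
        }
        return 0;
    }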