diff --git a/examples/server/httplib.h b/examples/server/httplib.h index 025946180..05ee81a08 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -457,7 +457,6 @@ public: std::function write; std::function is_writable; - std::function is_alive; std::function done; std::function done_with_trailer; std::ostream os; @@ -591,6 +590,7 @@ struct Response { Headers headers; std::string body; std::string location; // Redirect location + std::function is_alive; bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, size_t id = 0) const; @@ -4093,7 +4093,6 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; - data_sink.is_alive = [&]() -> bool { return strm.is_alive(); }; while (offset < end_offset && !is_shutting_down()) { if (!strm.is_writable()) { @@ -4140,7 +4139,6 @@ write_content_without_length(Stream &strm, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; - data_sink.is_alive = [&]() -> bool { return strm.is_alive(); }; data_sink.done = [&](void) { data_available = false; }; @@ -4193,7 +4191,6 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; - data_sink.is_alive = [&]() -> bool { return strm.is_alive(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -4287,6 +4284,7 @@ inline bool redirect(T &cli, Request &req, Response &res, } Response new_res; + new_res.is_alive = res.is_alive; auto ret = cli.send(new_req, new_res, error); if (ret) { @@ -6648,6 +6646,7 @@ Server::process_request(Stream &strm, bool close_connection, Request req; Response res; + res.is_alive = [&strm]() { return strm.is_alive(); }; res.version = "HTTP/1.1"; res.headers = default_headers_; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 93c6c43a8..880862aac 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -12,7 +12,6 @@ #include "json.hpp" // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" -#define MIMETYPE_EVENT_STREAM "text/event-stream" // auto generated files (update with ./deps.sh) #include "colorthemes.css.hpp" @@ -34,7 +33,6 @@ #include #include -#include #include #include #include @@ -2435,66 +2433,6 @@ inline void signal_handler(int signal) { shutdown_handler(signal); } -static void handle_tasks( - bool stream, - httplib::Response & res, - server_context & ctx_server, - const std::function(const std::function &)> & create_tasks, - const std::function &, httplib::DataSink & sink)> & payload) -{ - struct State { - std::unordered_set task_ids; - bool is_sink_valid = true; - }; - auto state = std::make_shared(); - httplib::ContentProviderResourceReleaser resource_releaser = [state, &ctx_server](bool success) { - state->is_sink_valid = false; - if (!success && state) { - ctx_server.cancel_tasks(state->task_ids); - } - }; - if (!stream) { - res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { - auto is_alive = [state, &sink]() { - return state->is_sink_valid && sink.is_alive(); - }; - state->task_ids = create_tasks(is_alive); - payload(state->task_ids, sink); - ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); - return false; - }, resource_releaser); - } else { - res.set_chunked_content_provider(MIMETYPE_EVENT_STREAM, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { - auto is_alive = [state, &sink]() { - return state->is_sink_valid && sink.is_alive(); - }; - state->task_ids = create_tasks(is_alive); - payload(state->task_ids, sink); - ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); - return false; - }, resource_releaser); - } -} - -static void respond(httplib::Response & res, httplib::DataSink * sink, int status, const json & response) { - res.status = status; - if (sink) { - res.set_header("Content-Type", MIMETYPE_JSON); - auto out = response.dump(-1, ' ', false, json::error_handler_t::replace); - sink->write(out.c_str(), out.size()); - } else { - res.set_content(response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); - } -} - -static void res_error(httplib::Response & res, httplib::DataSink * sink, const json & error_data) { - respond(res, sink, 500, {{"error", error_data}}); -} - -static void res_ok(httplib::Response & res, httplib::DataSink * sink, const json & data) { - respond(res, sink, 200, data); -} - int main(int argc, char ** argv) { // own arguments required by this example gpt_params params; @@ -2562,7 +2500,18 @@ int main(int argc, char ** argv) { svr->set_logger(log_server_request); - svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + auto res_error = [](httplib::Response & res, const json & error_data) { + json final_response {{"error", error_data}}; + res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.status = json_value(error_data, "code", 500); + }; + + auto res_ok = [](httplib::Response & res, const json & data) { + res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.status = 200; + }; + + svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { std::string message; try { std::rethrow_exception(ep); @@ -2574,12 +2523,12 @@ int main(int argc, char ** argv) { json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); - res_error(res, /* sink= */ nullptr, formatted_error); + res_error(res, formatted_error); }); - svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { + svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) { if (res.status == 404) { - res_error(res, /* sink= */ nullptr, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); + res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); } // for other error codes, we skip processing here because it's already done by res_error() }); @@ -2607,7 +2556,7 @@ int main(int argc, char ** argv) { // Middlewares // - auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { + auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { // TODO: should we apply API key to all endpoints, including "/health" and "/models"? static const std::unordered_set protected_endpoints = { "/props", @@ -2646,14 +2595,14 @@ int main(int argc, char ** argv) { } // API key is invalid or not provided - res_error(res, /* sink= */ nullptr, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); + res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); LOG_WRN("Unauthorized: Invalid API Key\n"); return false; }; - auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) { + auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); if (current_state == SERVER_STATE_LOADING_MODEL) { auto tmp = string_split(req.path, '.'); @@ -2661,7 +2610,7 @@ int main(int argc, char ** argv) { res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); res.status = 503; } else { - res_error(res, /* sink= */ nullptr, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); } return false; } @@ -2687,12 +2636,12 @@ int main(int argc, char ** argv) { const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { // error and loading states are handled by middleware json health = {{"status", "ok"}}; - res_ok(res, /* sink= */ nullptr, health); + res_ok(res, health); }; const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { if (!params.endpoint_slots) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2712,17 +2661,17 @@ int main(int argc, char ** argv) { const int n_idle_slots = result.data.at("idle"); if (req.has_param("fail_on_no_slot")) { if (n_idle_slots == 0) { - res_error(res, /* sink= */ nullptr, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, /* sink= */ nullptr, result.data.at("slots")); + res_ok(res, result.data.at("slots")); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { if (!params.endpoint_metrics) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2831,11 +2780,11 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { - res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } std::string filepath = params.slot_save_path + filename; @@ -2855,17 +2804,17 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, /* sink= */ nullptr, result.data); + res_error(res, result.data); } else { - res_ok(res, /* sink= */ nullptr, result.data); + res_ok(res, result.data); } }; - const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { - res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } std::string filepath = params.slot_save_path + filename; @@ -2885,13 +2834,13 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, /* sink= */ nullptr, result.data); + res_error(res, result.data); } else { - res_ok(res, /* sink= */ nullptr, result.data); + res_ok(res, result.data); } }; - const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { server_task task; task.type = SERVER_TASK_TYPE_SLOT_ERASE; task.data = { @@ -2905,15 +2854,15 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, /* sink= */ nullptr, result.data); + res_error(res, result.data); } else { - res_ok(res, /* sink= */ nullptr, result.data); + res_ok(res, result.data); } }; - const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { if (params.slot_save_path.empty()) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2923,7 +2872,7 @@ int main(int argc, char ** argv) { try { id_slot = std::stoi(id_slot_str); } catch (const std::exception &) { - res_error(res, /* sink= */ nullptr, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -2936,11 +2885,11 @@ int main(int argc, char ** argv) { } else if (action == "erase") { handle_slots_erase(req, res, id_slot); } else { - res_error(res, /* sink= */ nullptr, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); } }; - const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { + const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { std::string template_key = "tokenizer.chat_template", curr_tmpl; int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); if (tlen > 0) { @@ -2956,49 +2905,57 @@ int main(int argc, char ** argv) { { "chat_template", curr_tmpl.c_str() }, }; - res_ok(res, /* sink= */ nullptr, data); + res_ok(res, data); }; - const auto handle_completions_generic = [&ctx_server](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { + const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { if (ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } + std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, res.is_alive); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + bool stream = json_value(data, "stream", false); + const auto task_ids = server_task::get_list_id(tasks); - handle_tasks(stream, res, ctx_server, [data, cmpl_type, &ctx_server](const std::function & is_alive) { - std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, is_alive); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); - - return server_task::get_list_id(tasks); - }, [stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink) { - if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&res, &sink](std::vector & results) { - if (results.size() == 1) { - // single result - res_ok(res, &sink, results[0].data); - } else { - // multiple results (multitask) - json arr = json::array(); - for (const auto & res : results) { - arr.push_back(res.data); - } - res_ok(res, &sink, arr); + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + if (results.size() == 1) { + // single result + res_ok(res, results[0].data); + } else { + // multiple results (multitask) + json arr = json::array(); + for (const auto & res : results) { + arr.push_back(res.data); } - }, [&res, &sink](json error_data) { - res_error(res, &sink, error_data); - }); - } else { - ctx_server.receive_cmpl_results_stream(task_ids, [&sink](server_task_result result) -> bool { + res_ok(res, arr); + } + }, [&](const json & error_data) { + res_error(res, error_data); + }); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } else { + const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { return server_sent_event(sink, "data", result.data); - }, [&sink](const json & error_data) { + }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); }); sink.done(); - } - }); + return false; + }; + + auto on_complete = [task_ids, &ctx_server] (bool) { + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } }; const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { @@ -3012,34 +2969,35 @@ int main(int argc, char ** argv) { }; // TODO: maybe merge this function with "handle_completions_generic" - const auto handle_chat_completions = [&ctx_server, ¶ms, verbose](const httplib::Request & req, httplib::Response & res) { + const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { if (ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); + std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, res.is_alive); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + bool stream = json_value(data, "stream", false); + const auto task_ids = server_task::get_list_id(tasks); + const auto completion_id = gen_chatcmplid(); - handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function & is_alive) { - std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [&](const std::vector & results) { + // multitask is never support in chat completion, there is only one result + json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); + res_ok(res, result_oai); + }, [&](const json & error_data) { + res_error(res, error_data); + }); - return server_task::get_list_id(tasks); - }, [data, verbose, stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink) { - const auto completion_id = gen_chatcmplid(); - if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [completion_id, data, verbose, &sink, &res](std::vector & results) { - // multitask is never support in chat completion, there is only one result - json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); - res_ok(res, &sink, result_oai); - }, [&res, &sink](json error_data) { - res_error(res, &sink, error_data); - }); - } else { - ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } else { + const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); for (auto & event_data : result_array) { if (event_data.empty()) { @@ -3056,8 +3014,15 @@ int main(int argc, char ** argv) { static const std::string ev_done = "data: [DONE]\n\n"; sink.write(ev_done.data(), ev_done.size()); sink.done(); - } - }); + return true; + }; + + auto on_complete = [task_ids, &ctx_server] (bool) { + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } }; const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { @@ -3077,7 +3042,7 @@ int main(int argc, char ** argv) { res.set_content(models.dump(), MIMETYPE_JSON); }; - const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); json tokens_response = json::array(); @@ -3113,10 +3078,10 @@ int main(int argc, char ** argv) { } const json data = format_tokenizer_response(tokens_response); - res_ok(res, /* sink= */ nullptr, data); + res_ok(res, data); }; - const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); std::string content; @@ -3126,13 +3091,13 @@ int main(int argc, char ** argv) { } const json data = format_detokenized_response(content); - res_ok(res, /* sink= */ nullptr, data); + res_ok(res, data); }; - const auto handle_embeddings = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { // TODO: somehow clean up this checks in the future if (!ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } const json body = json::parse(req.body); @@ -3147,46 +3112,47 @@ int main(int argc, char ** argv) { // with "content", we only support single prompt prompt = std::vector{body.at("content")}; } else { - res_error(res, /* sink= */ nullptr, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } - - handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function & is_alive) { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive); + // create and queue the task + json responses = json::array(); + bool error = false; + { + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, res.is_alive); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); - return server_task::get_list_id(tasks); - }, [is_openai, &ctx_server, &res, body](const std::unordered_set & task_ids, httplib::DataSink & sink) { - bool error = false; - json responses = json::array(); + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_cmpl_results(task_ids, [&responses](std::vector & results) { + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { for (const auto & res : results) { responses.push_back(res.data); } - }, [&res, &error](json error_data) { - res_error(res, /* sink= */ nullptr, error_data); + }, [&](const json & error_data) { + res_error(res, error_data); error = true; }); - if (error) { - return; - } + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } - // write JSON response - json root = is_openai - ? format_embeddings_response_oaicompat(body, responses) - : responses[0]; + if (error) { + return; + } - res_ok(res, &sink, root); - }); + // write JSON response + json root = is_openai + ? format_embeddings_response_oaicompat(body, responses) + : responses[0]; + res_ok(res, root); }; - const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { if (!ctx_server.params.reranking) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } const json body = json::parse(req.body); @@ -3204,17 +3170,17 @@ int main(int argc, char ** argv) { if (body.count("query") == 1) { query = body.at("query"); if (!query.is_string()) { - res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); return; } } else { - res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } std::vector documents = json_value(body, "documents", std::vector()); if (documents.empty()) { - res_error(res, /* sink= */ nullptr, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -3231,7 +3197,7 @@ int main(int argc, char ** argv) { json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, []() { return true; }); + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, res.is_alive); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3243,7 +3209,7 @@ int main(int argc, char ** argv) { responses.push_back(res.data); } }, [&](const json & error_data) { - res_error(res, /* sink= */ nullptr, error_data); + res_error(res, error_data); error = true; }); } @@ -3254,7 +3220,7 @@ int main(int argc, char ** argv) { // write JSON response json root = format_response_rerank(body, responses); - res_ok(res, /* sink= */ nullptr, root); + res_ok(res, root); }; const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { @@ -3267,7 +3233,7 @@ int main(int argc, char ** argv) { {"scale", lora.scale}, }); } - res_ok(res, /* sink= */ nullptr, result); + res_ok(res, result); res.status = 200; // HTTP OK }; @@ -3299,7 +3265,7 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res_ok(res, /* sink= */ nullptr, result.data); + res_ok(res, result.data); res.status = 200; // HTTP OK }; @@ -3454,4 +3420,4 @@ int main(int argc, char ** argv) { t.join(); return 0; -} +} \ No newline at end of file