server: use (new) Request::is_alive as set_content_provider called after status / headers sent

This commit is contained in:
ochafik 2024-10-04 23:51:58 +01:00
parent 2dc708c72a
commit 6f693f14b0
2 changed files with 148 additions and 183 deletions

View file

@ -457,7 +457,6 @@ public:
std::function<bool(const char *data, size_t data_len)> write; std::function<bool(const char *data, size_t data_len)> write;
std::function<bool()> is_writable; std::function<bool()> is_writable;
std::function<bool()> is_alive;
std::function<void()> done; std::function<void()> done;
std::function<void(const Headers &trailer)> done_with_trailer; std::function<void(const Headers &trailer)> done_with_trailer;
std::ostream os; std::ostream os;
@ -591,6 +590,7 @@ struct Response {
Headers headers; Headers headers;
std::string body; std::string body;
std::string location; // Redirect location std::string location; // Redirect location
std::function<bool()> is_alive;
bool has_header(const std::string &key) const; bool has_header(const std::string &key) const;
std::string get_header_value(const std::string &key, size_t id = 0) const; std::string get_header_value(const std::string &key, size_t id = 0) const;
@ -4093,7 +4093,6 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,
}; };
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
while (offset < end_offset && !is_shutting_down()) { while (offset < end_offset && !is_shutting_down()) {
if (!strm.is_writable()) { if (!strm.is_writable()) {
@ -4140,7 +4139,6 @@ write_content_without_length(Stream &strm,
}; };
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
data_sink.done = [&](void) { data_available = false; }; data_sink.done = [&](void) { data_available = false; };
@ -4193,7 +4191,6 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
}; };
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
auto done_with_trailer = [&](const Headers *trailer) { auto done_with_trailer = [&](const Headers *trailer) {
if (!ok) { return; } if (!ok) { return; }
@ -4287,6 +4284,7 @@ inline bool redirect(T &cli, Request &req, Response &res,
} }
Response new_res; Response new_res;
new_res.is_alive = res.is_alive;
auto ret = cli.send(new_req, new_res, error); auto ret = cli.send(new_req, new_res, error);
if (ret) { if (ret) {
@ -6648,6 +6646,7 @@ Server::process_request(Stream &strm, bool close_connection,
Request req; Request req;
Response res; Response res;
res.is_alive = [&strm]() { return strm.is_alive(); };
res.version = "HTTP/1.1"; res.version = "HTTP/1.1";
res.headers = default_headers_; res.headers = default_headers_;

View file

@ -12,7 +12,6 @@
#include "json.hpp" #include "json.hpp"
// mime type for sending response // mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8" #define MIMETYPE_JSON "application/json; charset=utf-8"
#define MIMETYPE_EVENT_STREAM "text/event-stream"
// auto generated files (update with ./deps.sh) // auto generated files (update with ./deps.sh)
#include "colorthemes.css.hpp" #include "colorthemes.css.hpp"
@ -34,7 +33,6 @@
#include <atomic> #include <atomic>
#include <condition_variable> #include <condition_variable>
#include <functional>
#include <cstddef> #include <cstddef>
#include <cinttypes> #include <cinttypes>
#include <deque> #include <deque>
@ -2435,66 +2433,6 @@ inline void signal_handler(int signal) {
shutdown_handler(signal); shutdown_handler(signal);
} }
static void handle_tasks(
bool stream,
httplib::Response & res,
server_context & ctx_server,
const std::function<std::unordered_set<int>(const std::function<bool()> &)> & create_tasks,
const std::function<void(const std::unordered_set<int> &, httplib::DataSink & sink)> & payload)
{
struct State {
std::unordered_set<int> task_ids;
bool is_sink_valid = true;
};
auto state = std::make_shared<State>();
httplib::ContentProviderResourceReleaser resource_releaser = [state, &ctx_server](bool success) {
state->is_sink_valid = false;
if (!success && state) {
ctx_server.cancel_tasks(state->task_ids);
}
};
if (!stream) {
res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) {
auto is_alive = [state, &sink]() {
return state->is_sink_valid && sink.is_alive();
};
state->task_ids = create_tasks(is_alive);
payload(state->task_ids, sink);
ctx_server.queue_results.remove_waiting_task_ids(state->task_ids);
return false;
}, resource_releaser);
} else {
res.set_chunked_content_provider(MIMETYPE_EVENT_STREAM, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) {
auto is_alive = [state, &sink]() {
return state->is_sink_valid && sink.is_alive();
};
state->task_ids = create_tasks(is_alive);
payload(state->task_ids, sink);
ctx_server.queue_results.remove_waiting_task_ids(state->task_ids);
return false;
}, resource_releaser);
}
}
static void respond(httplib::Response & res, httplib::DataSink * sink, int status, const json & response) {
res.status = status;
if (sink) {
res.set_header("Content-Type", MIMETYPE_JSON);
auto out = response.dump(-1, ' ', false, json::error_handler_t::replace);
sink->write(out.c_str(), out.size());
} else {
res.set_content(response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
}
}
static void res_error(httplib::Response & res, httplib::DataSink * sink, const json & error_data) {
respond(res, sink, 500, {{"error", error_data}});
}
static void res_ok(httplib::Response & res, httplib::DataSink * sink, const json & data) {
respond(res, sink, 200, data);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
// own arguments required by this example // own arguments required by this example
gpt_params params; gpt_params params;
@ -2562,7 +2500,18 @@ int main(int argc, char ** argv) {
svr->set_logger(log_server_request); svr->set_logger(log_server_request);
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { auto res_error = [](httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
};
auto res_ok = [](httplib::Response & res, const json & data) {
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
res.status = 200;
};
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
std::string message; std::string message;
try { try {
std::rethrow_exception(ep); std::rethrow_exception(ep);
@ -2574,12 +2523,12 @@ int main(int argc, char ** argv) {
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
res_error(res, /* sink= */ nullptr, formatted_error); res_error(res, formatted_error);
}); });
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) { if (res.status == 404) {
res_error(res, /* sink= */ nullptr, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
} }
// for other error codes, we skip processing here because it's already done by res_error() // for other error codes, we skip processing here because it's already done by res_error()
}); });
@ -2607,7 +2556,7 @@ int main(int argc, char ** argv) {
// Middlewares // Middlewares
// //
auto middleware_validate_api_key = [&params](const httplib::Request & req, httplib::Response & res) { auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
// TODO: should we apply API key to all endpoints, including "/health" and "/models"? // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
static const std::unordered_set<std::string> protected_endpoints = { static const std::unordered_set<std::string> protected_endpoints = {
"/props", "/props",
@ -2646,14 +2595,14 @@ int main(int argc, char ** argv) {
} }
// API key is invalid or not provided // API key is invalid or not provided
res_error(res, /* sink= */ nullptr, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
LOG_WRN("Unauthorized: Invalid API Key\n"); LOG_WRN("Unauthorized: Invalid API Key\n");
return false; return false;
}; };
auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) { auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load(); server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) { if (current_state == SERVER_STATE_LOADING_MODEL) {
auto tmp = string_split(req.path, '.'); auto tmp = string_split(req.path, '.');
@ -2661,7 +2610,7 @@ int main(int argc, char ** argv) {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8"); res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503; res.status = 503;
} else { } else {
res_error(res, /* sink= */ nullptr, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
} }
return false; return false;
} }
@ -2687,12 +2636,12 @@ int main(int argc, char ** argv) {
const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
// error and loading states are handled by middleware // error and loading states are handled by middleware
json health = {{"status", "ok"}}; json health = {{"status", "ok"}};
res_ok(res, /* sink= */ nullptr, health); res_ok(res, health);
}; };
const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
if (!params.endpoint_slots) { if (!params.endpoint_slots) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -2712,17 +2661,17 @@ int main(int argc, char ** argv) {
const int n_idle_slots = result.data.at("idle"); const int n_idle_slots = result.data.at("idle");
if (req.has_param("fail_on_no_slot")) { if (req.has_param("fail_on_no_slot")) {
if (n_idle_slots == 0) { if (n_idle_slots == 0) {
res_error(res, /* sink= */ nullptr, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
return; return;
} }
} }
res_ok(res, /* sink= */ nullptr, result.data.at("slots")); res_ok(res, result.data.at("slots"));
}; };
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
if (!params.endpoint_metrics) { if (!params.endpoint_metrics) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -2831,11 +2780,11 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
const auto handle_slots_save = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) { const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body); json request_data = json::parse(req.body);
std::string filename = request_data.at("filename"); std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) { if (!fs_validate_filename(filename)) {
res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
std::string filepath = params.slot_save_path + filename; std::string filepath = params.slot_save_path + filename;
@ -2855,17 +2804,17 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) { if (result.error) {
res_error(res, /* sink= */ nullptr, result.data); res_error(res, result.data);
} else { } else {
res_ok(res, /* sink= */ nullptr, result.data); res_ok(res, result.data);
} }
}; };
const auto handle_slots_restore = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) { const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body); json request_data = json::parse(req.body);
std::string filename = request_data.at("filename"); std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) { if (!fs_validate_filename(filename)) {
res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
std::string filepath = params.slot_save_path + filename; std::string filepath = params.slot_save_path + filename;
@ -2885,13 +2834,13 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) { if (result.error) {
res_error(res, /* sink= */ nullptr, result.data); res_error(res, result.data);
} else { } else {
res_ok(res, /* sink= */ nullptr, result.data); res_ok(res, result.data);
} }
}; };
const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
server_task task; server_task task;
task.type = SERVER_TASK_TYPE_SLOT_ERASE; task.type = SERVER_TASK_TYPE_SLOT_ERASE;
task.data = { task.data = {
@ -2905,15 +2854,15 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) { if (result.error) {
res_error(res, /* sink= */ nullptr, result.data); res_error(res, result.data);
} else { } else {
res_ok(res, /* sink= */ nullptr, result.data); res_ok(res, result.data);
} }
}; };
const auto handle_slots_action = [&params, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
if (params.slot_save_path.empty()) { if (params.slot_save_path.empty()) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -2923,7 +2872,7 @@ int main(int argc, char ** argv) {
try { try {
id_slot = std::stoi(id_slot_str); id_slot = std::stoi(id_slot_str);
} catch (const std::exception &) { } catch (const std::exception &) {
res_error(res, /* sink= */ nullptr, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
@ -2936,11 +2885,11 @@ int main(int argc, char ** argv) {
} else if (action == "erase") { } else if (action == "erase") {
handle_slots_erase(req, res, id_slot); handle_slots_erase(req, res, id_slot);
} else { } else {
res_error(res, /* sink= */ nullptr, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
} }
}; };
const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
std::string template_key = "tokenizer.chat_template", curr_tmpl; std::string template_key = "tokenizer.chat_template", curr_tmpl;
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
if (tlen > 0) { if (tlen > 0) {
@ -2956,49 +2905,57 @@ int main(int argc, char ** argv) {
{ "chat_template", curr_tmpl.c_str() }, { "chat_template", curr_tmpl.c_str() },
}; };
res_ok(res, /* sink= */ nullptr, data); res_ok(res, data);
}; };
const auto handle_completions_generic = [&ctx_server](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) { if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, res.is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false); bool stream = json_value(data, "stream", false);
const auto task_ids = server_task::get_list_id(tasks);
handle_tasks(stream, res, ctx_server, [data, cmpl_type, &ctx_server](const std::function<bool()> & is_alive) { if (!stream) {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, is_alive); ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
ctx_server.queue_results.add_waiting_tasks(tasks); if (results.size() == 1) {
ctx_server.queue_tasks.post(tasks); // single result
res_ok(res, results[0].data);
return server_task::get_list_id(tasks); } else {
}, [stream, &res, &ctx_server](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) { // multiple results (multitask)
if (!stream) { json arr = json::array();
ctx_server.receive_cmpl_results(task_ids, [&res, &sink](std::vector<server_task_result> & results) { for (const auto & res : results) {
if (results.size() == 1) { arr.push_back(res.data);
// single result
res_ok(res, &sink, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
}
res_ok(res, &sink, arr);
} }
}, [&res, &sink](json error_data) { res_ok(res, arr);
res_error(res, &sink, error_data); }
}); }, [&](const json & error_data) {
} else { res_error(res, error_data);
ctx_server.receive_cmpl_results_stream(task_ids, [&sink](server_task_result result) -> bool { });
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
return server_sent_event(sink, "data", result.data); return server_sent_event(sink, "data", result.data);
}, [&sink](const json & error_data) { }, [&](const json & error_data) {
server_sent_event(sink, "error", error_data); server_sent_event(sink, "error", error_data);
}); });
sink.done(); sink.done();
} return false;
}); };
auto on_complete = [task_ids, &ctx_server] (bool) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
}; };
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@ -3012,34 +2969,35 @@ int main(int argc, char ** argv) {
}; };
// TODO: maybe merge this function with "handle_completions_generic" // TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, verbose](const httplib::Request & req, httplib::Response & res) { const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) { if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, res.is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false); bool stream = json_value(data, "stream", false);
const auto task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();
handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function<bool()> & is_alive) { if (!stream) {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive); ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
ctx_server.queue_results.add_waiting_tasks(tasks); // multitask is never support in chat completion, there is only one result
ctx_server.queue_tasks.post(tasks); json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
res_ok(res, result_oai);
}, [&](const json & error_data) {
res_error(res, error_data);
});
return server_task::get_list_id(tasks); ctx_server.queue_results.remove_waiting_task_ids(task_ids);
}, [data, verbose, stream, &res, &ctx_server](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) { } else {
const auto completion_id = gen_chatcmplid(); const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
if (!stream) { ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
ctx_server.receive_cmpl_results(task_ids, [completion_id, data, verbose, &sink, &res](std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
res_ok(res, &sink, result_oai);
}, [&res, &sink](json error_data) {
res_error(res, &sink, error_data);
});
} else {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id); std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto & event_data : result_array) { for (auto & event_data : result_array) {
if (event_data.empty()) { if (event_data.empty()) {
@ -3056,8 +3014,15 @@ int main(int argc, char ** argv) {
static const std::string ev_done = "data: [DONE]\n\n"; static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size()); sink.write(ev_done.data(), ev_done.size());
sink.done(); sink.done();
} return true;
}); };
auto on_complete = [task_ids, &ctx_server] (bool) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
}; };
const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) { const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
@ -3077,7 +3042,7 @@ int main(int argc, char ** argv) {
res.set_content(models.dump(), MIMETYPE_JSON); res.set_content(models.dump(), MIMETYPE_JSON);
}; };
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body); const json body = json::parse(req.body);
json tokens_response = json::array(); json tokens_response = json::array();
@ -3113,10 +3078,10 @@ int main(int argc, char ** argv) {
} }
const json data = format_tokenizer_response(tokens_response); const json data = format_tokenizer_response(tokens_response);
res_ok(res, /* sink= */ nullptr, data); res_ok(res, data);
}; };
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::string content; std::string content;
@ -3126,13 +3091,13 @@ int main(int argc, char ** argv) {
} }
const json data = format_detokenized_response(content); const json data = format_detokenized_response(content);
res_ok(res, /* sink= */ nullptr, data); res_ok(res, data);
}; };
const auto handle_embeddings = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
// TODO: somehow clean up this checks in the future // TODO: somehow clean up this checks in the future
if (!ctx_server.params.embedding || ctx_server.params.reranking) { if (!ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
const json body = json::parse(req.body); const json body = json::parse(req.body);
@ -3147,46 +3112,47 @@ int main(int argc, char ** argv) {
// with "content", we only support single prompt // with "content", we only support single prompt
prompt = std::vector<std::string>{body.at("content")}; prompt = std::vector<std::string>{body.at("content")};
} else { } else {
res_error(res, /* sink= */ nullptr, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
// create and queue the task
handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function<bool()> & is_alive) { json responses = json::array();
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive); bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, res.is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
return server_task::get_list_id(tasks); // get the result
}, [is_openai, &ctx_server, &res, body](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) { std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
bool error = false;
json responses = json::array();
ctx_server.receive_cmpl_results(task_ids, [&responses](std::vector<server_task_result> & results) { ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
for (const auto & res : results) { for (const auto & res : results) {
responses.push_back(res.data); responses.push_back(res.data);
} }
}, [&res, &error](json error_data) { }, [&](const json & error_data) {
res_error(res, /* sink= */ nullptr, error_data); res_error(res, error_data);
error = true; error = true;
}); });
if (error) { ctx_server.queue_results.remove_waiting_task_ids(task_ids);
return; }
}
// write JSON response if (error) {
json root = is_openai return;
? format_embeddings_response_oaicompat(body, responses) }
: responses[0];
res_ok(res, &sink, root); // write JSON response
}); json root = is_openai
? format_embeddings_response_oaicompat(body, responses)
: responses[0];
res_ok(res, root);
}; };
const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params.reranking) { if (!ctx_server.params.reranking) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
const json body = json::parse(req.body); const json body = json::parse(req.body);
@ -3204,17 +3170,17 @@ int main(int argc, char ** argv) {
if (body.count("query") == 1) { if (body.count("query") == 1) {
query = body.at("query"); query = body.at("query");
if (!query.is_string()) { if (!query.is_string()) {
res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
} else { } else {
res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>()); std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
if (documents.empty()) { if (documents.empty()) {
res_error(res, /* sink= */ nullptr, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
@ -3231,7 +3197,7 @@ int main(int argc, char ** argv) {
json responses = json::array(); json responses = json::array();
bool error = false; bool error = false;
{ {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, []() { return true; }); std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, res.is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
@ -3243,7 +3209,7 @@ int main(int argc, char ** argv) {
responses.push_back(res.data); responses.push_back(res.data);
} }
}, [&](const json & error_data) { }, [&](const json & error_data) {
res_error(res, /* sink= */ nullptr, error_data); res_error(res, error_data);
error = true; error = true;
}); });
} }
@ -3254,7 +3220,7 @@ int main(int argc, char ** argv) {
// write JSON response // write JSON response
json root = format_response_rerank(body, responses); json root = format_response_rerank(body, responses);
res_ok(res, /* sink= */ nullptr, root); res_ok(res, root);
}; };
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
@ -3267,7 +3233,7 @@ int main(int argc, char ** argv) {
{"scale", lora.scale}, {"scale", lora.scale},
}); });
} }
res_ok(res, /* sink= */ nullptr, result); res_ok(res, result);
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
@ -3299,7 +3265,7 @@ int main(int argc, char ** argv) {
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
res_ok(res, /* sink= */ nullptr, result.data); res_ok(res, result.data);
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
@ -3454,4 +3420,4 @@ int main(int argc, char ** argv) {
t.join(); t.join();
return 0; return 0;
} }