server: use (new) Response::is_alive, as set_content_provider is called after the status / headers have been sent

This commit is contained in:
ochafik 2024-10-04 23:51:58 +01:00
parent 2dc708c72a
commit 6f693f14b0
2 changed files with 148 additions and 183 deletions

View file

@ -457,7 +457,6 @@ public:
std::function<bool(const char *data, size_t data_len)> write;
std::function<bool()> is_writable;
std::function<bool()> is_alive;
std::function<void()> done;
std::function<void(const Headers &trailer)> done_with_trailer;
std::ostream os;
@ -591,6 +590,7 @@ struct Response {
Headers headers;
std::string body;
std::string location; // Redirect location
std::function<bool()> is_alive;
bool has_header(const std::string &key) const;
std::string get_header_value(const std::string &key, size_t id = 0) const;
@ -4093,7 +4093,6 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,
};
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
while (offset < end_offset && !is_shutting_down()) {
if (!strm.is_writable()) {
@ -4140,7 +4139,6 @@ write_content_without_length(Stream &strm,
};
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
data_sink.done = [&](void) { data_available = false; };
@ -4193,7 +4191,6 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
};
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
auto done_with_trailer = [&](const Headers *trailer) {
if (!ok) { return; }
@ -4287,6 +4284,7 @@ inline bool redirect(T &cli, Request &req, Response &res,
}
Response new_res;
new_res.is_alive = res.is_alive;
auto ret = cli.send(new_req, new_res, error);
if (ret) {
@ -6648,6 +6646,7 @@ Server::process_request(Stream &strm, bool close_connection,
Request req;
Response res;
res.is_alive = [&strm]() { return strm.is_alive(); };
res.version = "HTTP/1.1";
res.headers = default_headers_;

View file

@ -12,7 +12,6 @@
#include "json.hpp"
// mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8"
#define MIMETYPE_EVENT_STREAM "text/event-stream"
// auto generated files (update with ./deps.sh)
#include "colorthemes.css.hpp"
@ -34,7 +33,6 @@
#include <atomic>
#include <condition_variable>
#include <functional>
#include <cstddef>
#include <cinttypes>
#include <deque>
@ -2435,66 +2433,6 @@ inline void signal_handler(int signal) {
shutdown_handler(signal);
}
// Runs a batch of server tasks and emits their results through an HTTP
// content provider, so that a client disconnect can cancel the tasks
// mid-flight.
//
// `create_tasks` receives an is-alive predicate (sink still valid AND the
// underlying connection alive) and returns the ids of the tasks it queued.
// `payload` is invoked with those ids and the DataSink to produce the body.
// When the transfer fails (client gone), the queued tasks are cancelled.
static void handle_tasks(
    bool stream,
    httplib::Response & res,
    server_context & ctx_server,
    const std::function<std::unordered_set<int>(const std::function<bool()> &)> & create_tasks,
    const std::function<void(const std::unordered_set<int> &, httplib::DataSink & sink)> & payload)
{
    // Shared between the provider lambda and the resource releaser: the sink
    // becomes invalid once the releaser fires, and the task ids are needed to
    // cancel outstanding work if the connection drops before completion.
    struct State {
        std::unordered_set<int> task_ids;
        bool is_sink_valid = true;
    };
    auto state = std::make_shared<State>();

    httplib::ContentProviderResourceReleaser resource_releaser = [state, &ctx_server](bool success) {
        state->is_sink_valid = false;
        // `state` is captured by value and is never null; only `success`
        // decides whether the tasks must be cancelled.
        if (!success) {
            ctx_server.cancel_tasks(state->task_ids);
        }
    };

    // The provider body is identical for the plain and the chunked case;
    // only the registration call (and MIME type) differs, so build it once.
    auto provider = [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) {
        auto is_alive = [state, &sink]() {
            return state->is_sink_valid && sink.is_alive();
        };
        state->task_ids = create_tasks(is_alive);
        payload(state->task_ids, sink);
        ctx_server.queue_results.remove_waiting_task_ids(state->task_ids);
        return false; // single-shot: everything is written in one invocation
    };

    if (stream) {
        res.set_chunked_content_provider(MIMETYPE_EVENT_STREAM, provider, resource_releaser);
    } else {
        res.set_content_provider(MIMETYPE_JSON, provider, resource_releaser);
    }
}
// Sends a JSON payload either through an in-flight DataSink (when the body is
// being produced by a content provider) or as a buffered response body.
// NOTE(review): per the commit message, when `sink` is non-null the status
// line and headers have already gone out on the wire, so the `res.status` and
// `set_header` assignments below are presumably ineffective in that case and
// only the `sink->write` matters — confirm against httplib's send ordering.
static void respond(httplib::Response & res, httplib::DataSink * sink, int status, const json & response) {
res.status = status;
if (sink) {
res.set_header("Content-Type", MIMETYPE_JSON);
// error_handler_t::replace: substitute invalid UTF-8 instead of throwing
auto out = response.dump(-1, ' ', false, json::error_handler_t::replace);
sink->write(out.c_str(), out.size());
} else {
res.set_content(response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
}
}
// Convenience wrapper: report an internal error (HTTP 500) whose JSON payload
// wraps `error_data` under an "error" key.
static void res_error(httplib::Response & res, httplib::DataSink * sink, const json & error_data) {
    json wrapped;
    wrapped["error"] = error_data;
    respond(res, sink, 500, wrapped);
}
// Convenience wrapper: send `data` as the JSON body of a successful (HTTP 200)
// response, via the shared respond() helper.
static void res_ok(httplib::Response & res, httplib::DataSink * sink, const json & data) {
    respond(res, sink, /*status=*/200, data);
}
int main(int argc, char ** argv) {
// own arguments required by this example
gpt_params params;
@ -2562,7 +2500,18 @@ int main(int argc, char ** argv) {
svr->set_logger(log_server_request);
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
auto res_error = [](httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
};
auto res_ok = [](httplib::Response & res, const json & data) {
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
res.status = 200;
};
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
std::string message;
try {
std::rethrow_exception(ep);
@ -2574,12 +2523,12 @@ int main(int argc, char ** argv) {
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
res_error(res, /* sink= */ nullptr, formatted_error);
res_error(res, formatted_error);
});
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) {
res_error(res, /* sink= */ nullptr, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
}
// for other error codes, we skip processing here because it's already done by res_error()
});
@ -2607,7 +2556,7 @@ int main(int argc, char ** argv) {
// Middlewares
//
auto middleware_validate_api_key = [&params](const httplib::Request & req, httplib::Response & res) {
auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
static const std::unordered_set<std::string> protected_endpoints = {
"/props",
@ -2646,14 +2595,14 @@ int main(int argc, char ** argv) {
}
// API key is invalid or not provided
res_error(res, /* sink= */ nullptr, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
LOG_WRN("Unauthorized: Invalid API Key\n");
return false;
};
auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) {
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
auto tmp = string_split(req.path, '.');
@ -2661,7 +2610,7 @@ int main(int argc, char ** argv) {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503;
} else {
res_error(res, /* sink= */ nullptr, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
}
return false;
}
@ -2687,12 +2636,12 @@ int main(int argc, char ** argv) {
const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
// error and loading states are handled by middleware
json health = {{"status", "ok"}};
res_ok(res, /* sink= */ nullptr, health);
res_ok(res, health);
};
const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
if (!params.endpoint_slots) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@ -2712,17 +2661,17 @@ int main(int argc, char ** argv) {
const int n_idle_slots = result.data.at("idle");
if (req.has_param("fail_on_no_slot")) {
if (n_idle_slots == 0) {
res_error(res, /* sink= */ nullptr, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
return;
}
}
res_ok(res, /* sink= */ nullptr, result.data.at("slots"));
res_ok(res, result.data.at("slots"));
};
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
if (!params.endpoint_metrics) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@ -2831,11 +2780,11 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
const auto handle_slots_save = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::string filepath = params.slot_save_path + filename;
@ -2855,17 +2804,17 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, /* sink= */ nullptr, result.data);
res_error(res, result.data);
} else {
res_ok(res, /* sink= */ nullptr, result.data);
res_ok(res, result.data);
}
};
const auto handle_slots_restore = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::string filepath = params.slot_save_path + filename;
@ -2885,13 +2834,13 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, /* sink= */ nullptr, result.data);
res_error(res, result.data);
} else {
res_ok(res, /* sink= */ nullptr, result.data);
res_ok(res, result.data);
}
};
const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
server_task task;
task.type = SERVER_TASK_TYPE_SLOT_ERASE;
task.data = {
@ -2905,15 +2854,15 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, /* sink= */ nullptr, result.data);
res_error(res, result.data);
} else {
res_ok(res, /* sink= */ nullptr, result.data);
res_ok(res, result.data);
}
};
const auto handle_slots_action = [&params, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
if (params.slot_save_path.empty()) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@ -2923,7 +2872,7 @@ int main(int argc, char ** argv) {
try {
id_slot = std::stoi(id_slot_str);
} catch (const std::exception &) {
res_error(res, /* sink= */ nullptr, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
return;
}
@ -2936,11 +2885,11 @@ int main(int argc, char ** argv) {
} else if (action == "erase") {
handle_slots_erase(req, res, id_slot);
} else {
res_error(res, /* sink= */ nullptr, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
}
};
const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
std::string template_key = "tokenizer.chat_template", curr_tmpl;
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
if (tlen > 0) {
@ -2956,49 +2905,57 @@ int main(int argc, char ** argv) {
{ "chat_template", curr_tmpl.c_str() },
};
res_ok(res, /* sink= */ nullptr, data);
res_ok(res, data);
};
const auto handle_completions_generic = [&ctx_server](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, res.is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
const auto task_ids = server_task::get_list_id(tasks);
handle_tasks(stream, res, ctx_server, [data, cmpl_type, &ctx_server](const std::function<bool()> & is_alive) {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
return server_task::get_list_id(tasks);
}, [stream, &res, &ctx_server](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) {
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&res, &sink](std::vector<server_task_result> & results) {
if (results.size() == 1) {
// single result
res_ok(res, &sink, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
}
res_ok(res, &sink, arr);
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
if (results.size() == 1) {
// single result
res_ok(res, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
}
}, [&res, &sink](json error_data) {
res_error(res, &sink, error_data);
});
} else {
ctx_server.receive_cmpl_results_stream(task_ids, [&sink](server_task_result result) -> bool {
res_ok(res, arr);
}
}, [&](const json & error_data) {
res_error(res, error_data);
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
return server_sent_event(sink, "data", result.data);
}, [&sink](const json & error_data) {
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
sink.done();
}
});
return false;
};
auto on_complete = [task_ids, &ctx_server] (bool) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@ -3012,34 +2969,35 @@ int main(int argc, char ** argv) {
};
// TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, verbose](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, res.is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
const auto task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();
handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function<bool()> & is_alive) {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
res_ok(res, result_oai);
}, [&](const json & error_data) {
res_error(res, error_data);
});
return server_task::get_list_id(tasks);
}, [data, verbose, stream, &res, &ctx_server](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) {
const auto completion_id = gen_chatcmplid();
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [completion_id, data, verbose, &sink, &res](std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
res_ok(res, &sink, result_oai);
}, [&res, &sink](json error_data) {
res_error(res, &sink, error_data);
});
} else {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto & event_data : result_array) {
if (event_data.empty()) {
@ -3056,8 +3014,15 @@ int main(int argc, char ** argv) {
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
sink.done();
}
});
return true;
};
auto on_complete = [task_ids, &ctx_server] (bool) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
@ -3077,7 +3042,7 @@ int main(int argc, char ** argv) {
res.set_content(models.dump(), MIMETYPE_JSON);
};
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
json tokens_response = json::array();
@ -3113,10 +3078,10 @@ int main(int argc, char ** argv) {
}
const json data = format_tokenizer_response(tokens_response);
res_ok(res, /* sink= */ nullptr, data);
res_ok(res, data);
};
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
std::string content;
@ -3126,13 +3091,13 @@ int main(int argc, char ** argv) {
}
const json data = format_detokenized_response(content);
res_ok(res, /* sink= */ nullptr, data);
res_ok(res, data);
};
const auto handle_embeddings = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
// TODO: somehow clean up this checks in the future
if (!ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
const json body = json::parse(req.body);
@ -3147,46 +3112,47 @@ int main(int argc, char ** argv) {
// with "content", we only support single prompt
prompt = std::vector<std::string>{body.at("content")};
} else {
res_error(res, /* sink= */ nullptr, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function<bool()> & is_alive) {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive);
// create and queue the task
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, res.is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
return server_task::get_list_id(tasks);
}, [is_openai, &ctx_server, &res, body](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) {
bool error = false;
json responses = json::array();
// get the result
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
ctx_server.receive_cmpl_results(task_ids, [&responses](std::vector<server_task_result> & results) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
for (const auto & res : results) {
responses.push_back(res.data);
}
}, [&res, &error](json error_data) {
res_error(res, /* sink= */ nullptr, error_data);
}, [&](const json & error_data) {
res_error(res, error_data);
error = true;
});
if (error) {
return;
}
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
}
// write JSON response
json root = is_openai
? format_embeddings_response_oaicompat(body, responses)
: responses[0];
if (error) {
return;
}
res_ok(res, &sink, root);
});
// write JSON response
json root = is_openai
? format_embeddings_response_oaicompat(body, responses)
: responses[0];
res_ok(res, root);
};
const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params.reranking) {
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
const json body = json::parse(req.body);
@ -3204,17 +3170,17 @@ int main(int argc, char ** argv) {
if (body.count("query") == 1) {
query = body.at("query");
if (!query.is_string()) {
res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
return;
}
} else {
res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
if (documents.empty()) {
res_error(res, /* sink= */ nullptr, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
return;
}
@ -3231,7 +3197,7 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, []() { return true; });
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, res.is_alive);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@ -3243,7 +3209,7 @@ int main(int argc, char ** argv) {
responses.push_back(res.data);
}
}, [&](const json & error_data) {
res_error(res, /* sink= */ nullptr, error_data);
res_error(res, error_data);
error = true;
});
}
@ -3254,7 +3220,7 @@ int main(int argc, char ** argv) {
// write JSON response
json root = format_response_rerank(body, responses);
res_ok(res, /* sink= */ nullptr, root);
res_ok(res, root);
};
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
@ -3267,7 +3233,7 @@ int main(int argc, char ** argv) {
{"scale", lora.scale},
});
}
res_ok(res, /* sink= */ nullptr, result);
res_ok(res, result);
res.status = 200; // HTTP OK
};
@ -3299,7 +3265,7 @@ int main(int argc, char ** argv) {
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
res_ok(res, /* sink= */ nullptr, result.data);
res_ok(res, result.data);
res.status = 200; // HTTP OK
};
@ -3454,4 +3420,4 @@ int main(int argc, char ** argv) {
t.join();
return 0;
}
}