server
: use (new) Request::is_alive as set_content_provider called after status / headers sent
This commit is contained in:
parent
2dc708c72a
commit
6f693f14b0
2 changed files with 148 additions and 183 deletions
|
@ -457,7 +457,6 @@ public:
|
|||
|
||||
std::function<bool(const char *data, size_t data_len)> write;
|
||||
std::function<bool()> is_writable;
|
||||
std::function<bool()> is_alive;
|
||||
std::function<void()> done;
|
||||
std::function<void(const Headers &trailer)> done_with_trailer;
|
||||
std::ostream os;
|
||||
|
@ -591,6 +590,7 @@ struct Response {
|
|||
Headers headers;
|
||||
std::string body;
|
||||
std::string location; // Redirect location
|
||||
std::function<bool()> is_alive;
|
||||
|
||||
bool has_header(const std::string &key) const;
|
||||
std::string get_header_value(const std::string &key, size_t id = 0) const;
|
||||
|
@ -4093,7 +4093,6 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,
|
|||
};
|
||||
|
||||
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
|
||||
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
|
||||
|
||||
while (offset < end_offset && !is_shutting_down()) {
|
||||
if (!strm.is_writable()) {
|
||||
|
@ -4140,7 +4139,6 @@ write_content_without_length(Stream &strm,
|
|||
};
|
||||
|
||||
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
|
||||
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
|
||||
|
||||
data_sink.done = [&](void) { data_available = false; };
|
||||
|
||||
|
@ -4193,7 +4191,6 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
|
|||
};
|
||||
|
||||
data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
|
||||
data_sink.is_alive = [&]() -> bool { return strm.is_alive(); };
|
||||
|
||||
auto done_with_trailer = [&](const Headers *trailer) {
|
||||
if (!ok) { return; }
|
||||
|
@ -4287,6 +4284,7 @@ inline bool redirect(T &cli, Request &req, Response &res,
|
|||
}
|
||||
|
||||
Response new_res;
|
||||
new_res.is_alive = res.is_alive;
|
||||
|
||||
auto ret = cli.send(new_req, new_res, error);
|
||||
if (ret) {
|
||||
|
@ -6648,6 +6646,7 @@ Server::process_request(Stream &strm, bool close_connection,
|
|||
Request req;
|
||||
|
||||
Response res;
|
||||
res.is_alive = [&strm]() { return strm.is_alive(); };
|
||||
res.version = "HTTP/1.1";
|
||||
res.headers = default_headers_;
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
#include "json.hpp"
|
||||
// mime type for sending response
|
||||
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
||||
#define MIMETYPE_EVENT_STREAM "text/event-stream"
|
||||
|
||||
// auto generated files (update with ./deps.sh)
|
||||
#include "colorthemes.css.hpp"
|
||||
|
@ -34,7 +33,6 @@
|
|||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <cstddef>
|
||||
#include <cinttypes>
|
||||
#include <deque>
|
||||
|
@ -2435,66 +2433,6 @@ inline void signal_handler(int signal) {
|
|||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
static void handle_tasks(
|
||||
bool stream,
|
||||
httplib::Response & res,
|
||||
server_context & ctx_server,
|
||||
const std::function<std::unordered_set<int>(const std::function<bool()> &)> & create_tasks,
|
||||
const std::function<void(const std::unordered_set<int> &, httplib::DataSink & sink)> & payload)
|
||||
{
|
||||
struct State {
|
||||
std::unordered_set<int> task_ids;
|
||||
bool is_sink_valid = true;
|
||||
};
|
||||
auto state = std::make_shared<State>();
|
||||
httplib::ContentProviderResourceReleaser resource_releaser = [state, &ctx_server](bool success) {
|
||||
state->is_sink_valid = false;
|
||||
if (!success && state) {
|
||||
ctx_server.cancel_tasks(state->task_ids);
|
||||
}
|
||||
};
|
||||
if (!stream) {
|
||||
res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) {
|
||||
auto is_alive = [state, &sink]() {
|
||||
return state->is_sink_valid && sink.is_alive();
|
||||
};
|
||||
state->task_ids = create_tasks(is_alive);
|
||||
payload(state->task_ids, sink);
|
||||
ctx_server.queue_results.remove_waiting_task_ids(state->task_ids);
|
||||
return false;
|
||||
}, resource_releaser);
|
||||
} else {
|
||||
res.set_chunked_content_provider(MIMETYPE_EVENT_STREAM, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) {
|
||||
auto is_alive = [state, &sink]() {
|
||||
return state->is_sink_valid && sink.is_alive();
|
||||
};
|
||||
state->task_ids = create_tasks(is_alive);
|
||||
payload(state->task_ids, sink);
|
||||
ctx_server.queue_results.remove_waiting_task_ids(state->task_ids);
|
||||
return false;
|
||||
}, resource_releaser);
|
||||
}
|
||||
}
|
||||
|
||||
static void respond(httplib::Response & res, httplib::DataSink * sink, int status, const json & response) {
|
||||
res.status = status;
|
||||
if (sink) {
|
||||
res.set_header("Content-Type", MIMETYPE_JSON);
|
||||
auto out = response.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
sink->write(out.c_str(), out.size());
|
||||
} else {
|
||||
res.set_content(response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
}
|
||||
}
|
||||
|
||||
static void res_error(httplib::Response & res, httplib::DataSink * sink, const json & error_data) {
|
||||
respond(res, sink, 500, {{"error", error_data}});
|
||||
}
|
||||
|
||||
static void res_ok(httplib::Response & res, httplib::DataSink * sink, const json & data) {
|
||||
respond(res, sink, 200, data);
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
// own arguments required by this example
|
||||
gpt_params params;
|
||||
|
@ -2562,7 +2500,18 @@ int main(int argc, char ** argv) {
|
|||
|
||||
svr->set_logger(log_server_request);
|
||||
|
||||
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
|
||||
auto res_error = [](httplib::Response & res, const json & error_data) {
|
||||
json final_response {{"error", error_data}};
|
||||
res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
};
|
||||
|
||||
auto res_ok = [](httplib::Response & res, const json & data) {
|
||||
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
res.status = 200;
|
||||
};
|
||||
|
||||
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
|
||||
std::string message;
|
||||
try {
|
||||
std::rethrow_exception(ep);
|
||||
|
@ -2574,12 +2523,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
|
||||
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
|
||||
res_error(res, /* sink= */ nullptr, formatted_error);
|
||||
res_error(res, formatted_error);
|
||||
});
|
||||
|
||||
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
|
||||
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
|
||||
if (res.status == 404) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
}
|
||||
// for other error codes, we skip processing here because it's already done by res_error()
|
||||
});
|
||||
|
@ -2607,7 +2556,7 @@ int main(int argc, char ** argv) {
|
|||
// Middlewares
|
||||
//
|
||||
|
||||
auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) {
|
||||
auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
|
||||
static const std::unordered_set<std::string> protected_endpoints = {
|
||||
"/props",
|
||||
|
@ -2646,14 +2595,14 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
||||
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
||||
|
||||
LOG_WRN("Unauthorized: Invalid API Key\n");
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) {
|
||||
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
|
||||
server_state current_state = state.load();
|
||||
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
||||
auto tmp = string_split(req.path, '.');
|
||||
|
@ -2661,7 +2610,7 @@ int main(int argc, char ** argv) {
|
|||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||
res.status = 503;
|
||||
} else {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -2687,12 +2636,12 @@ int main(int argc, char ** argv) {
|
|||
const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
|
||||
// error and loading states are handled by middleware
|
||||
json health = {{"status", "ok"}};
|
||||
res_ok(res, /* sink= */ nullptr, health);
|
||||
res_ok(res, health);
|
||||
};
|
||||
|
||||
const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!params.endpoint_slots) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2712,17 +2661,17 @@ int main(int argc, char ** argv) {
|
|||
const int n_idle_slots = result.data.at("idle");
|
||||
if (req.has_param("fail_on_no_slot")) {
|
||||
if (n_idle_slots == 0) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
|
||||
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
res_ok(res, /* sink= */ nullptr, result.data.at("slots"));
|
||||
res_ok(res, result.data.at("slots"));
|
||||
};
|
||||
|
||||
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_metrics) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2831,11 +2780,11 @@ int main(int argc, char ** argv) {
|
|||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
|
@ -2855,17 +2804,17 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, /* sink= */ nullptr, result.data);
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res_ok(res, /* sink= */ nullptr, result.data);
|
||||
res_ok(res, result.data);
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
|
@ -2885,13 +2834,13 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, /* sink= */ nullptr, result.data);
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res_ok(res, /* sink= */ nullptr, result.data);
|
||||
res_ok(res, result.data);
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||
server_task task;
|
||||
task.type = SERVER_TASK_TYPE_SLOT_ERASE;
|
||||
task.data = {
|
||||
|
@ -2905,15 +2854,15 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, /* sink= */ nullptr, result.data);
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res_ok(res, /* sink= */ nullptr, result.data);
|
||||
res_ok(res, result.data);
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
if (params.slot_save_path.empty()) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2923,7 +2872,7 @@ int main(int argc, char ** argv) {
|
|||
try {
|
||||
id_slot = std::stoi(id_slot_str);
|
||||
} catch (const std::exception &) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2936,11 +2885,11 @@ int main(int argc, char ** argv) {
|
|||
} else if (action == "erase") {
|
||||
handle_slots_erase(req, res, id_slot);
|
||||
} else {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
|
||||
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
std::string template_key = "tokenizer.chat_template", curr_tmpl;
|
||||
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
|
||||
if (tlen > 0) {
|
||||
|
@ -2956,49 +2905,57 @@ int main(int argc, char ** argv) {
|
|||
{ "chat_template", curr_tmpl.c_str() },
|
||||
};
|
||||
|
||||
res_ok(res, /* sink= */ nullptr, data);
|
||||
res_ok(res, data);
|
||||
};
|
||||
|
||||
const auto handle_completions_generic = [&ctx_server](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
|
||||
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, res.is_alive);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
const auto task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
handle_tasks(stream, res, ctx_server, [data, cmpl_type, &ctx_server](const std::function<bool()> & is_alive) {
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, is_alive);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
return server_task::get_list_id(tasks);
|
||||
}, [stream, &res, &ctx_server](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) {
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&res, &sink](std::vector<server_task_result> & results) {
|
||||
if (results.size() == 1) {
|
||||
// single result
|
||||
res_ok(res, &sink, results[0].data);
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
for (const auto & res : results) {
|
||||
arr.push_back(res.data);
|
||||
}
|
||||
res_ok(res, &sink, arr);
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||
if (results.size() == 1) {
|
||||
// single result
|
||||
res_ok(res, results[0].data);
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
for (const auto & res : results) {
|
||||
arr.push_back(res.data);
|
||||
}
|
||||
}, [&res, &sink](json error_data) {
|
||||
res_error(res, &sink, error_data);
|
||||
});
|
||||
} else {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&sink](server_task_result result) -> bool {
|
||||
res_ok(res, arr);
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
});
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
|
||||
return server_sent_event(sink, "data", result.data);
|
||||
}, [&sink](const json & error_data) {
|
||||
}, [&](const json & error_data) {
|
||||
server_sent_event(sink, "error", error_data);
|
||||
});
|
||||
sink.done();
|
||||
}
|
||||
});
|
||||
return false;
|
||||
};
|
||||
|
||||
auto on_complete = [task_ids, &ctx_server] (bool) {
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
};
|
||||
|
||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
|
@ -3012,34 +2969,35 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
// TODO: maybe merge this function with "handle_completions_generic"
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, verbose](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
|
||||
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, res.is_alive);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
const auto task_ids = server_task::get_list_id(tasks);
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
|
||||
handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function<bool()> & is_alive) {
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
|
||||
// multitask is never support in chat completion, there is only one result
|
||||
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
|
||||
res_ok(res, result_oai);
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
});
|
||||
|
||||
return server_task::get_list_id(tasks);
|
||||
}, [data, verbose, stream, &res, &ctx_server](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) {
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [completion_id, data, verbose, &sink, &res](std::vector<server_task_result> & results) {
|
||||
// multitask is never support in chat completion, there is only one result
|
||||
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
|
||||
res_ok(res, &sink, result_oai);
|
||||
}, [&res, &sink](json error_data) {
|
||||
res_error(res, &sink, error_data);
|
||||
});
|
||||
} else {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
|
||||
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
|
||||
for (auto & event_data : result_array) {
|
||||
if (event_data.empty()) {
|
||||
|
@ -3056,8 +3014,15 @@ int main(int argc, char ** argv) {
|
|||
static const std::string ev_done = "data: [DONE]\n\n";
|
||||
sink.write(ev_done.data(), ev_done.size());
|
||||
sink.done();
|
||||
}
|
||||
});
|
||||
return true;
|
||||
};
|
||||
|
||||
auto on_complete = [task_ids, &ctx_server] (bool) {
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
};
|
||||
|
||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) {
|
||||
|
@ -3077,7 +3042,7 @@ int main(int argc, char ** argv) {
|
|||
res.set_content(models.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
json tokens_response = json::array();
|
||||
|
@ -3113,10 +3078,10 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
const json data = format_tokenizer_response(tokens_response);
|
||||
res_ok(res, /* sink= */ nullptr, data);
|
||||
res_ok(res, data);
|
||||
};
|
||||
|
||||
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
std::string content;
|
||||
|
@ -3126,13 +3091,13 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
const json data = format_detokenized_response(content);
|
||||
res_ok(res, /* sink= */ nullptr, data);
|
||||
res_ok(res, data);
|
||||
};
|
||||
|
||||
const auto handle_embeddings = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
// TODO: somehow clean up this checks in the future
|
||||
if (!ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
const json body = json::parse(req.body);
|
||||
|
@ -3147,46 +3112,47 @@ int main(int argc, char ** argv) {
|
|||
// with "content", we only support single prompt
|
||||
prompt = std::vector<std::string>{body.at("content")};
|
||||
} else {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function<bool()> & is_alive) {
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive);
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
{
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, res.is_alive);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
return server_task::get_list_id(tasks);
|
||||
}, [is_openai, &ctx_server, &res, body](const std::unordered_set<int> & task_ids, httplib::DataSink & sink) {
|
||||
bool error = false;
|
||||
json responses = json::array();
|
||||
// get the result
|
||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
ctx_server.receive_cmpl_results(task_ids, [&responses](std::vector<server_task_result> & results) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||
for (const auto & res : results) {
|
||||
responses.push_back(res.data);
|
||||
}
|
||||
}, [&res, &error](json error_data) {
|
||||
res_error(res, /* sink= */ nullptr, error_data);
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
error = true;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return;
|
||||
}
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
}
|
||||
|
||||
// write JSON response
|
||||
json root = is_openai
|
||||
? format_embeddings_response_oaicompat(body, responses)
|
||||
: responses[0];
|
||||
if (error) {
|
||||
return;
|
||||
}
|
||||
|
||||
res_ok(res, &sink, root);
|
||||
});
|
||||
// write JSON response
|
||||
json root = is_openai
|
||||
? format_embeddings_response_oaicompat(body, responses)
|
||||
: responses[0];
|
||||
res_ok(res, root);
|
||||
};
|
||||
|
||||
const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!ctx_server.params.reranking) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
const json body = json::parse(req.body);
|
||||
|
@ -3204,17 +3170,17 @@ int main(int argc, char ** argv) {
|
|||
if (body.count("query") == 1) {
|
||||
query = body.at("query");
|
||||
if (!query.is_string()) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
|
||||
if (documents.empty()) {
|
||||
res_error(res, /* sink= */ nullptr, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -3231,7 +3197,7 @@ int main(int argc, char ** argv) {
|
|||
json responses = json::array();
|
||||
bool error = false;
|
||||
{
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, []() { return true; });
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, res.is_alive);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
@ -3243,7 +3209,7 @@ int main(int argc, char ** argv) {
|
|||
responses.push_back(res.data);
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, /* sink= */ nullptr, error_data);
|
||||
res_error(res, error_data);
|
||||
error = true;
|
||||
});
|
||||
}
|
||||
|
@ -3254,7 +3220,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// write JSON response
|
||||
json root = format_response_rerank(body, responses);
|
||||
res_ok(res, /* sink= */ nullptr, root);
|
||||
res_ok(res, root);
|
||||
};
|
||||
|
||||
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
||||
|
@ -3267,7 +3233,7 @@ int main(int argc, char ** argv) {
|
|||
{"scale", lora.scale},
|
||||
});
|
||||
}
|
||||
res_ok(res, /* sink= */ nullptr, result);
|
||||
res_ok(res, result);
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
|
@ -3299,7 +3265,7 @@ int main(int argc, char ** argv) {
|
|||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
res_ok(res, /* sink= */ nullptr, result.data);
|
||||
res_ok(res, result.data);
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
|
@ -3454,4 +3420,4 @@ int main(int argc, char ** argv) {
|
|||
t.join();
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue