refactor completions handler

Xuan Son Nguyen 2024-09-02 00:09:05 +02:00
parent 012d8d8cc0
commit 4a5dbd85b5
2 changed files with 189 additions and 270 deletions
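
In short, this change removes the server_task_multi / on_finish_multitask bookkeeping: each subtask of a multi-prompt request now carries its own index, every result echoes that index in its JSON payload, and the new receive_cmpl_results / receive_cmpl_results_stream helpers collect the results for a set of task ids directly in the HTTP handler. Below is a minimal, self-contained sketch of that gather-by-index pattern; the mock_result and receive_all names are invented for illustration and are not part of the server code.

#include <cstdio>
#include <functional>
#include <queue>
#include <vector>

// Stand-in for server_task_result: each result remembers which subtask
// (by request-order index) produced it.
struct mock_result {
    size_t index;
    bool   error;
    int    value; // stand-in for the JSON payload
};

// Results may arrive in any order; scatter them back into request order,
// mirroring what receive_cmpl_results does with result.data["index"].
static void receive_all(std::queue<mock_result> & incoming, size_t n_tasks,
                        const std::function<void(std::vector<mock_result> &)> & on_done,
                        const std::function<void(const mock_result &)> & on_error) {
    std::vector<mock_result> results(n_tasks);
    for (size_t i = 0; i < n_tasks; i++) {
        mock_result res = incoming.front();
        incoming.pop();
        if (res.error) {
            on_error(res); // the real server also cancels the remaining tasks here
            return;
        }
        results[res.index] = res; // keyed by index, not by arrival order
    }
    on_done(results);
}

int main() {
    std::queue<mock_result> incoming;
    incoming.push({2, false, 30}); // results arrive out of order
    incoming.push({0, false, 10});
    incoming.push({1, false, 20});

    receive_all(incoming, 3,
        [](std::vector<mock_result> & results) {
            for (const auto & r : results) {
                std::printf("index %zu -> %d\n", r.index, r.value);
            }
        },
        [](const mock_result & r) {
            std::printf("error in subtask %zu\n", r.index);
        });
    return 0;
}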

View file

@@ -5,13 +5,6 @@
#include "llama.h"
#include "grammar-parser.h"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#include "httplib.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
@@ -39,12 +32,12 @@
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <set>
#include <mutex>
#include <thread>
#include <signal.h>
#include <memory>
#include <unordered_set>
#include <unordered_map>
using json = nlohmann::ordered_json;
@@ -97,6 +90,15 @@ struct server_task {
json data;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
// utility function
static std::vector<int> get_list_id(std::vector<server_task> tasks) {
std::vector<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids[i] = tasks[i].id;
}
return ids;
}
};
struct server_task_result {
@@ -108,13 +110,6 @@ struct server_task_result {
bool error;
};
struct server_task_multi {
int id = -1;
std::set<int> subtasks_remaining;
std::vector<server_task_result> results;
};
struct slot_params {
bool stream = true;
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
@@ -133,6 +128,9 @@ struct server_slot {
int id;
int id_task = -1;
// the index relative to completion multi-task request
size_t index = 0;
struct slot_params params;
slot_state state = SLOT_STATE_IDLE;
@@ -388,15 +386,12 @@ struct server_queue {
std::vector<server_task> queue_tasks;
std::vector<server_task> queue_tasks_deferred;
std::vector<server_task_multi> queue_multitasks;
std::mutex mutex_tasks;
std::condition_variable condition_tasks;
// callback functions
std::function<void(server_task &)> callback_new_task;
std::function<void(server_task_multi &)> callback_finish_multitask;
std::function<void(void)> callback_update_slots;
std::function<void(server_task&)> callback_new_task;
std::function<void(void)> callback_update_slots;
// Add a new task to the end of the queue
int post(server_task task) {
@@ -437,11 +432,6 @@ struct server_queue {
callback_new_task = std::move(callback);
}
// Register function to process a multitask when it is finished
void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
callback_finish_multitask = std::move(callback);
}
// Register the function to be called when all slots data is ready to be processed
void on_update_slots(std::function<void(void)> callback) {
callback_update_slots = std::move(callback);
@@ -490,22 +480,6 @@ struct server_queue {
callback_new_task(task);
}
LOG_VERBOSE("update_multitasks", {});
// check if we have any finished multitasks
auto queue_iterator = queue_multitasks.begin();
while (queue_iterator != queue_multitasks.end()) {
if (queue_iterator->subtasks_remaining.empty()) {
// all subtasks done == multitask is done
server_task_multi current_multitask = *queue_iterator;
callback_finish_multitask(current_multitask);
// remove this multitask
queue_iterator = queue_multitasks.erase(queue_iterator);
} else {
++queue_iterator;
}
}
// all tasks in the current loop are processed, slots data is now ready
LOG_VERBOSE("callback_update_slots", {});
@@ -530,7 +504,7 @@ struct server_queue {
struct server_response {
// for keeping track of all tasks waiting for the result
std::set<int> waiting_task_ids;
std::unordered_set<int> waiting_task_ids;
// the main result queue
std::vector<server_task_result> queue_results;
@@ -1387,7 +1361,8 @@ struct server_context {
{"content", tkn.text_to_send},
{"stop", false},
{"id_slot", slot.id},
{"multimodal", false}
{"multimodal", false},
{"index", slot.index},
};
if (slot.sparams.n_probs > 0) {
@@ -1434,7 +1409,8 @@ struct server_context {
{"stopped_limit", slot.stopped_limit},
{"stopping_word", slot.stopping_word},
{"tokens_cached", slot.n_past},
{"timings", slot.get_formated_timings()}
{"timings", slot.get_formated_timings()},
{"index", slot.index},
};
if (slot.sparams.n_probs > 0) {
@@ -1500,6 +1476,7 @@ struct server_context {
res.data = json {
{"embedding", embd_res},
{"index", slot.index},
};
}
@@ -1507,10 +1484,10 @@ struct server_context {
}
//
// functions to create new task(s)
// Functions to create new task(s) and receive result(s)
//
std::vector<server_task> request_completion(json data, server_task_cmpl_type cmpl_type) {
std::vector<server_task> create_tasks_completion(json data, server_task_cmpl_type cmpl_type) {
std::vector<server_task> tasks;
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
server_task task;
@@ -1535,12 +1512,16 @@ struct server_context {
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create a single task
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
data["index"] = 0;
create_task(data, false, nullptr);
}
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
else if (prompt.is_array()) {
for (auto const & e : prompt) {
std::vector<json> prompts = prompt;
for (size_t i = 0; i < prompts.size(); i++) {
const auto & e = prompts[i];
if (e.is_string() || json_is_array_of_numbers(e)) {
data["index"] = i;
create_task(data, true, e);
} else {
throw std::runtime_error(error_msg);
@@ -1555,17 +1536,61 @@ struct server_context {
return tasks;
}
void request_cancel(std::vector<server_task> & tasks) {
for (const auto & t : tasks) {
void cancel_tasks(std::unordered_set<int> & id_tasks) {
for (const auto & id_task : id_tasks) {
LOG_VERBOSE("cancel task", {{"id_task", id_task}});
server_task task;
task.type = SERVER_TASK_TYPE_CANCEL;
task.id_target = t.id;
task.id_target = id_task;
queue_tasks.post(task);
queue_results.remove_waiting_task_id(id_task);
}
}
void receive_cmpl_results(std::vector<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
std::vector<server_task_result> results(id_set.size());
for (size_t i = 0; i < id_set.size(); i++) {
server_task_result result = queue_results.recv(id_set);
if (result.error) {
error_handler(result.data);
cancel_tasks(id_set);
break;
}
size_t idx = result.data["index"];
results[idx] = result;
}
result_handler(results);
}
void receive_cmpl_results_stream(std::vector<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
size_t n_finished = 0;
while (true) {
server_task_result result = queue_results.recv(id_set);
if (!result_handler(result)) {
cancel_tasks(id_set);
break;
}
if (result.error) {
error_handler(result.data);
cancel_tasks(id_set);
break;
}
if (result.stop) {
if (++n_finished == id_set.size()) {
break;
}
}
}
}
//
// functions to process the task
// Functions to process the task
//
void process_single_task(const server_task & task) {
@@ -1614,6 +1639,7 @@ struct server_context {
slot->id_task = task.id;
slot->cmpl_type = task.cmpl_type;
slot->index = json_value(task.data, "index", 0);
if (!launch_slot_with_task(*slot, task)) {
LOG_ERROR("error while launching slot", task.data);
@@ -1841,26 +1867,6 @@ struct server_context {
}
}
void on_finish_multitask(const server_task_multi & multitask) {
// all subtasks done == multitask is done
server_task_result result;
result.id = multitask.id;
result.stop = true;
result.error = false;
// collect json results into one json result
std::vector<json> result_jsons;
for (const auto & subres : multitask.results) {
result_jsons.push_back(subres.data);
result.error = result.error && subres.error;
}
result.data = json {
{ "results", result_jsons }
};
queue_results.send(result);
}
void update_slots() {
if (system_need_update) {
system_prompt_update();
@@ -2556,6 +2562,11 @@ int main(int argc, char ** argv) {
res.status = json_value(error_data, "code", 500);
};
auto res_ok = [](httplib::Response & res, json data) {
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
res.status = 200;
};
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
std::string message;
try {
@@ -2603,7 +2614,7 @@ int main(int argc, char ** argv) {
auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
static const std::set<std::string> protected_endpoints = {
static const std::unordered_set<std::string> protected_endpoints = {
"/props",
"/completion",
"/completions",
@@ -2932,81 +2943,106 @@ int main(int argc, char ** argv) {
res.set_content(data.dump(), MIMETYPE_JSON);
};
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = json::parse(req.body);
std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, cmpl_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else {
res_error(res, result.data);
}
bool stream = json_value(data, "stream", false);
std::vector<int> task_ids = server_task::get_list_id(tasks);
ctx_server.queue_results.remove_waiting_tasks(tasks);
} else {
const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
while (true) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error) {
const std::string str =
"data: " +
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
if (result.stop) {
break;
}
} else {
const std::string str =
"error: " +
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
break;
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
if (results.size() == 1) {
// single result
res_ok(res, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
}
res_ok(res, arr);
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
}, [&](json error_data) {
res_error(res, error_data);
});
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) mutable {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
return server_sent_event(sink, "data", result.data);
}, [&](json error_data) {
server_sent_event(sink, "error", error_data);
});
sink.done();
return true;
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
}
};
auto on_complete = [tasks, &ctx_server](bool) mutable {
// cancel
ctx_server.request_cancel(tasks);
ctx_server.queue_results.remove_waiting_tasks(tasks);
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
};
const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
};
// TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
std::vector<int> task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
// multitask is never supported in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id);
res_ok(res, result_oai);
}, [&](json error_data) {
res_error(res, error_data);
});
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto & event_data : result_array) {
if (event_data.empty()) {
continue; // skip the stop token
}
if (!server_sent_event(sink, "data", event_data)) {
return false; // connection is closed
}
}
return true; // ok
}, [&](json error_data) {
server_sent_event(sink, "error", error_data);
});
std::string done_event = "[DONE]"; // OAI-compat behavior
sink.write(done_event.c_str(), done_event.size());
sink.done();
return true;
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
}
};
@@ -3027,142 +3063,6 @@ int main(int argc, char ** argv) {
res.set_content(models.dump(), MIMETYPE_JSON);
};
const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
const auto completion_id = gen_chatcmplid();
if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error && result.stop) {
json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else {
res_error(res, result.data);
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
} else {
const auto chunked_content_provider = [tasks, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable {
while (true) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error) {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
if (!it->empty()) {
const std::string str =
"data: " +
it->dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
}
}
if (result.stop) {
break;
}
} else {
const std::string str =
"error: " +
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
break;
}
}
sink.done();
ctx_server.queue_results.remove_waiting_tasks(tasks);
return true;
};
auto on_complete = [tasks, &ctx_server](bool) mutable {
// cancel request
ctx_server.request_cancel(tasks);
ctx_server.queue_results.remove_waiting_tasks(tasks);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = json::parse(req.body);
std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_INFILL);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else {
res_error(res, result.data);
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
} else {
const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
while (true) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error) {
const std::string str =
"data: " +
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
if (result.stop) {
break;
}
} else {
break;
}
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
sink.done();
return true;
};
auto on_complete = [tasks, &ctx_server](bool) mutable {
ctx_server.request_cancel(tasks);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
@@ -3208,7 +3108,7 @@ int main(int argc, char ** argv) {
// create and queue the task
json responses;
{
std::vector<server_task> tasks = ctx_server.request_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_INFILL);
std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -3407,8 +3307,6 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
ctx_server.queue_tasks.on_finish_multitask(std::bind(
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
ctx_server.queue_tasks.on_update_slots(std::bind(
&server_context::update_slots, &ctx_server));

View file

@@ -3,6 +3,14 @@
#include "llama.h"
#include "common.h"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#include "httplib.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
@@ -355,6 +363,19 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
return out;
}
static bool server_sent_event(httplib::DataSink & sink, const char * event, json & data) {
const std::string str =
std::string(event) + ": " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
return sink.write(str.c_str(), str.size());
}
//
// OAI utils
//
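
As a side note on the streaming path: the new server_sent_event() helper above frames every chunk as "<event>: <json>" followed by a blank line, and the refactored handlers emit the event names "data" and "error" (the OAI-compatible chat endpoint additionally writes a trailing "[DONE]"). A standalone sketch of that framing follows; the payloads are invented for illustration and are not taken from the server.

#include <cstdio>
#include <string>

// Illustrative only: mirrors the wire framing produced by server_sent_event();
// it is not the server implementation itself.
static std::string sse_frame(const std::string & event, const std::string & json_payload) {
    return event + ": " + json_payload + "\n\n";
}

int main() {
    // a streamed token chunk and an error chunk, with made-up payloads
    std::printf("%s", sse_frame("data",  "{\"content\":\"Hello\",\"stop\":false,\"index\":0}").c_str());
    std::printf("%s", sse_frame("error", "{\"code\":500,\"message\":\"example error\"}").c_str());
    std::printf("[DONE]"); // OAI-compat terminator, as written by the chat handler
    return 0;
}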