refactor completions handler

Xuan Son Nguyen 2024-09-02 00:09:05 +02:00
parent 012d8d8cc0
commit 4a5dbd85b5
2 changed files with 189 additions and 270 deletions

View file

@@ -5,13 +5,6 @@
#include "llama.h"
#include "grammar-parser.h"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#include "httplib.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
@@ -39,12 +32,12 @@
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <set>
#include <mutex>
#include <thread>
#include <signal.h>
#include <memory>
#include <unordered_set>
#include <unordered_map>
using json = nlohmann::ordered_json;
@@ -97,6 +90,15 @@ struct server_task {
json data;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
// utility function
static std::vector<int> get_list_id(std::vector<server_task> tasks) {
std::vector<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids[i] = tasks[i].id;
}
return ids;
}
};
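The new server_task::get_list_id helper is what the refactored handlers below use to turn a batch of posted tasks into the list of ids they wait on (see handle_completions_generic further down). For reference, a standalone sketch of the same pattern, not part of the patch, using only the standard library and a hypothetical toy_task type:

// Illustration only, not part of the commit: mirrors server_task::get_list_id.
#include <cstdio>
#include <vector>

struct toy_task { int id; };

static std::vector<int> get_list_id(const std::vector<toy_task> & tasks) {
    std::vector<int> ids(tasks.size());
    for (size_t i = 0; i < tasks.size(); i++) {
        ids[i] = tasks[i].id;
    }
    return ids;
}

int main() {
    // e.g. one multi-prompt request fanned out into three tasks
    std::vector<toy_task> tasks = { {7}, {8}, {9} };
    for (int id : get_list_id(tasks)) {
        printf("waiting on task %d\n", id);
    }
    return 0;
}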
struct server_task_result {
@@ -108,13 +110,6 @@ struct server_task_result {
bool error;
};
struct server_task_multi {
int id = -1;
std::set<int> subtasks_remaining;
std::vector<server_task_result> results;
};
struct slot_params {
bool stream = true;
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
@@ -133,6 +128,9 @@ struct server_slot {
int id;
int id_task = -1;
// the index relative to completion multi-task request
size_t index = 0;
struct slot_params params;
slot_state state = SLOT_STATE_IDLE;
@@ -388,15 +386,12 @@ struct server_queue {
std::vector<server_task> queue_tasks;
std::vector<server_task> queue_tasks_deferred;
std::vector<server_task_multi> queue_multitasks;
std::mutex mutex_tasks;
std::condition_variable condition_tasks;
// callback functions
std::function<void(server_task &)> callback_new_task;
std::function<void(server_task_multi &)> callback_finish_multitask;
std::function<void(void)> callback_update_slots;
// Add a new task to the end of the queue
int post(server_task task) {
@@ -437,11 +432,6 @@ struct server_queue {
callback_new_task = std::move(callback);
}
// Register function to process a multitask when it is finished
void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
callback_finish_multitask = std::move(callback);
}
// Register the function to be called when all slots data is ready to be processed
void on_update_slots(std::function<void(void)> callback) {
callback_update_slots = std::move(callback);
@@ -490,22 +480,6 @@ struct server_queue {
callback_new_task(task);
}
LOG_VERBOSE("update_multitasks", {});
// check if we have any finished multitasks
auto queue_iterator = queue_multitasks.begin();
while (queue_iterator != queue_multitasks.end()) {
if (queue_iterator->subtasks_remaining.empty()) {
// all subtasks done == multitask is done
server_task_multi current_multitask = *queue_iterator;
callback_finish_multitask(current_multitask);
// remove this multitask
queue_iterator = queue_multitasks.erase(queue_iterator);
} else {
++queue_iterator;
}
}
// all tasks in the current loop is processed, slots data is now ready
LOG_VERBOSE("callback_update_slots", {});
@@ -530,7 +504,7 @@ struct server_queue {
struct server_response {
// for keeping track of all tasks waiting for the result
std::set<int> waiting_task_ids;
std::unordered_set<int> waiting_task_ids;
// the main result queue
std::vector<server_task_result> queue_results;
@@ -1387,7 +1361,8 @@ struct server_context {
{"content", tkn.text_to_send},
{"stop", false},
{"id_slot", slot.id},
{"multimodal", false},
{"index", slot.index},
};
if (slot.sparams.n_probs > 0) {
@@ -1434,7 +1409,8 @@ struct server_context {
{"stopped_limit", slot.stopped_limit},
{"stopping_word", slot.stopping_word},
{"tokens_cached", slot.n_past},
{"timings", slot.get_formated_timings()},
{"index", slot.index},
};
if (slot.sparams.n_probs > 0) {
@@ -1500,6 +1476,7 @@ struct server_context {
res.data = json {
{"embedding", embd_res},
{"index", slot.index},
};
}
@@ -1507,10 +1484,10 @@ struct server_context {
}
//
// functions to create new task(s)
// Functions to create new task(s) and receive result(s)
//
std::vector<server_task> request_completion(json data, server_task_cmpl_type cmpl_type) {
std::vector<server_task> create_tasks_completion(json data, server_task_cmpl_type cmpl_type) {
std::vector<server_task> tasks;
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
server_task task;
@@ -1535,12 +1512,16 @@ struct server_context {
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
data["index"] = 0;
create_task(data, false, nullptr);
}
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
else if (prompt.is_array()) {
for (auto const & e : prompt) {
std::vector<json> prompts = prompt;
for (size_t i = 0; i < prompts.size(); i++) {
const auto & e = prompts[i];
if (e.is_string() || json_is_array_of_numbers(e)) {
data["index"] = i;
create_task(data, true, e);
} else {
throw std::runtime_error(error_msg);
@@ -1555,17 +1536,61 @@ struct server_context {
return tasks;
}
void request_cancel(std::vector<server_task> & tasks) {
void cancel_tasks(std::unordered_set<int> & id_tasks) {
for (const auto & t : tasks) {
for (const auto & id_task : id_tasks) {
LOG_VERBOSE("cancel task", {{"id_task", id_task}});
server_task task;
task.type = SERVER_TASK_TYPE_CANCEL;
task.id_target = t.id;
task.id_target = id_task;
queue_tasks.post(task);
queue_results.remove_waiting_task_id(id_task);
}
}
void receive_cmpl_results(std::vector<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
std::vector<server_task_result> results(id_set.size());
for (size_t i = 0; i < id_set.size(); i++) {
server_task_result result = queue_results.recv(id_set);
if (result.error) {
error_handler(result.data);
cancel_tasks(id_set);
break;
}
size_t idx = result.data["index"];
results[idx] = result;
}
result_handler(results);
}
void receive_cmpl_results_stream(std::vector<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
size_t n_finished = 0;
while (true) {
server_task_result result = queue_results.recv(id_set);
if (!result_handler(result)) {
cancel_tasks(id_set);
break;
}
if (result.error) {
error_handler(result.data);
cancel_tasks(id_set);
break;
}
if (result.stop) {
if (++n_finished == id_set.size()) {
break;
}
}
}
}
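receive_cmpl_results collects one result per waiting task id and stores each at results[result.data["index"]], so a multi-prompt response comes back in prompt order even though slots finish in arbitrary order. A minimal standalone illustration of that reordering, using toy types rather than the server's own, follows:

// Illustration only: results arrive out of order; the "index" carried by each task puts them back.
#include <cstdio>
#include <string>
#include <vector>

struct toy_result {
    size_t index;        // index stored in the task's JSON data
    std::string content; // completion for that prompt
};

int main() {
    std::vector<toy_result> arrived = {
        {2, "answer to prompt 2"},
        {0, "answer to prompt 0"},
        {1, "answer to prompt 1"},
    };
    std::vector<std::string> ordered(arrived.size());
    for (const auto & r : arrived) {
        ordered[r.index] = r.content; // same idea as results[idx] = result;
    }
    for (size_t i = 0; i < ordered.size(); i++) {
        printf("result[%zu] = %s\n", i, ordered[i].c_str());
    }
    return 0;
}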
//
// Functions to process the task
//
void process_single_task(const server_task & task) {
@@ -1614,6 +1639,7 @@ struct server_context {
slot->id_task = task.id;
slot->cmpl_type = task.cmpl_type;
slot->index = json_value(task.data, "index", 0);
if (!launch_slot_with_task(*slot, task)) {
LOG_ERROR("error while launching slot", task.data);
@@ -1841,26 +1867,6 @@ struct server_context {
}
}
void on_finish_multitask(const server_task_multi & multitask) {
// all subtasks done == multitask is done
server_task_result result;
result.id = multitask.id;
result.stop = true;
result.error = false;
// collect json results into one json result
std::vector<json> result_jsons;
for (const auto & subres : multitask.results) {
result_jsons.push_back(subres.data);
result.error = result.error && subres.error;
}
result.data = json {
{ "results", result_jsons }
};
queue_results.send(result);
}
void update_slots() {
if (system_need_update) {
system_prompt_update();
@@ -2556,6 +2562,11 @@ int main(int argc, char ** argv) {
res.status = json_value(error_data, "code", 500);
};
auto res_ok = [](httplib::Response & res, json data) {
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
res.status = 200;
};
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
std::string message;
try {
@@ -2603,7 +2614,7 @@ int main(int argc, char ** argv) {
auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
static const std::set<std::string> protected_endpoints = {
static const std::unordered_set<std::string> protected_endpoints = {
"/props",
"/completion",
"/completions",
@@ -2932,81 +2943,106 @@ int main(int argc, char ** argv) {
res.set_content(data.dump(), MIMETYPE_JSON);
};
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = json::parse(req.body);
std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else {
res_error(res, result.data);
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
} else {
const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
while (true) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error) {
const std::string str =
"data: " +
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
if (result.stop) {
break;
}
} else {
const std::string str =
"error: " +
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
break;
}
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
sink.done();
return true;
};
auto on_complete = [tasks, &ctx_server](bool) mutable {
// cancel
ctx_server.request_cancel(tasks);
ctx_server.queue_results.remove_waiting_tasks(tasks);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, cmpl_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
std::vector<int> task_ids = server_task::get_list_id(tasks);
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
if (results.size() == 1) {
// single result
res_ok(res, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
}
res_ok(res, arr);
}
}, [&](json error_data) {
res_error(res, error_data);
});
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) mutable {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
return server_sent_event(sink, "data", result.data);
}, [&](json error_data) {
server_sent_event(sink, "error", error_data);
});
sink.done();
return true;
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
}
};
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
};
const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
};
// TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
std::vector<int> task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id);
res_ok(res, result_oai);
}, [&](json error_data) {
res_error(res, error_data);
});
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto & event_data : result_array) {
if (event_data.empty()) {
continue; // skip the stop token
}
if (!server_sent_event(sink, "data", event_data)) {
return false; // connection is closed
}
}
return true; // ok
}, [&](json error_data) {
server_sent_event(sink, "error", error_data);
});
std::string done_event = "[DONE]"; // OAI-compat behavior
sink.write(done_event.c_str(), done_event.size());
sink.done();
return true;
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
}
};
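handle_completions_generic now funnels both /completion and /infill through the same two callback-driven paths: the non-streaming branch waits for every result and replies once, while the streaming branch forwards each result as a server-sent event and stops as soon as the sink rejects a write. A rough standalone sketch of that control flow, using toy types and strings rather than the server's API, follows:

// Illustration only: the blocking vs. streaming result-delivery pattern.
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

static void receive_all(const std::vector<std::string> & results,
                        const std::function<void(const std::vector<std::string> &)> & on_done) {
    on_done(results); // non-streaming: a single callback with everything
}

static void receive_stream(const std::vector<std::string> & results,
                           const std::function<bool(const std::string &)> & on_result) {
    for (const auto & r : results) {
        if (!on_result(r)) {
            return; // callback returned false, e.g. the client disconnected
        }
    }
}

int main() {
    std::vector<std::string> results = {"Hel", "lo", "!"};
    receive_all(results, [](const std::vector<std::string> & all) {
        printf("non-streaming: reply built from %zu results\n", all.size());
    });
    receive_stream(results, [](const std::string & r) {
        printf("data: %s\n\n", r.c_str()); // one SSE chunk per result
        return true;
    });
    return 0;
}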
@@ -3027,142 +3063,6 @@ int main(int argc, char ** argv) {
res.set_content(models.dump(), MIMETYPE_JSON);
};
const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
const auto completion_id = gen_chatcmplid();
if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error && result.stop) {
json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else {
res_error(res, result.data);
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
} else {
const auto chunked_content_provider = [tasks, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable {
while (true) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error) {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
if (!it->empty()) {
const std::string str =
"data: " +
it->dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
}
}
if (result.stop) {
break;
}
} else {
const std::string str =
"error: " +
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
break;
}
}
sink.done();
ctx_server.queue_results.remove_waiting_tasks(tasks);
return true;
};
auto on_complete = [tasks, &ctx_server](bool) mutable {
// cancel request
ctx_server.request_cancel(tasks);
ctx_server.queue_results.remove_waiting_tasks(tasks);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = json::parse(req.body);
std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_INFILL);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else {
res_error(res, result.data);
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
} else {
const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
while (true) {
server_task_result result = ctx_server.queue_results.recv(tasks);
if (!result.error) {
const std::string str =
"data: " +
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.c_str(), str.size())) {
ctx_server.queue_results.remove_waiting_tasks(tasks);
return false;
}
if (result.stop) {
break;
}
} else {
break;
}
}
ctx_server.queue_results.remove_waiting_tasks(tasks);
sink.done();
return true;
};
auto on_complete = [tasks, &ctx_server](bool) mutable {
ctx_server.request_cancel(tasks);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
@@ -3208,7 +3108,7 @@ int main(int argc, char ** argv) {
// create and queue the task
json responses;
{
std::vector<server_task> tasks = ctx_server.request_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_INFILL);
std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -3407,8 +3307,6 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
ctx_server.queue_tasks.on_finish_multitask(std::bind(
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
ctx_server.queue_tasks.on_update_slots(std::bind(
&server_context::update_slots, &ctx_server));

View file

@@ -3,6 +3,14 @@
#include "llama.h"
#include "common.h"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#include "httplib.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
@@ -355,6 +363,19 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
return out;
}
static bool server_sent_event(httplib::DataSink & sink, const char * event, json & data) {
const std::string str =
std::string(event) + ": " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
return sink.write(str.c_str(), str.size());
}
//
// OAI utils
//