refactor completions handler
parent 012d8d8cc0
commit 4a5dbd85b5
2 changed files with 189 additions and 270 deletions
@@ -5,13 +5,6 @@
#include "llama.h"
#include "grammar-parser.h"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#include "httplib.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
@@ -39,12 +32,12 @@
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <set>
#include <mutex>
#include <thread>
#include <signal.h>
#include <memory>
#include <unordered_set>
#include <unordered_map>

using json = nlohmann::ordered_json;
@@ -97,6 +90,15 @@ struct server_task {
    json data;

    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;

    // utility function
    static std::vector<int> get_list_id(std::vector<server_task> tasks) {
        std::vector<int> ids(tasks.size());
        for (size_t i = 0; i < tasks.size(); i++) {
            ids[i] = tasks[i].id;
        }
        return ids;
    }
};

struct server_task_result {
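For context, a minimal standalone sketch of the id-collection pattern that the new server_task::get_list_id() helper provides; demo_task and the main() driver below are illustrative stand-ins, not part of the server code.

#include <cstdio>
#include <vector>

struct demo_task { int id; };

static std::vector<int> get_list_id(const std::vector<demo_task> & tasks) {
    std::vector<int> ids(tasks.size());
    for (size_t i = 0; i < tasks.size(); i++) {
        ids[i] = tasks[i].id;   // one id per task, in request order
    }
    return ids;
}

int main() {
    const std::vector<demo_task> tasks = { {3}, {7}, {11} };
    for (int id : get_list_id(tasks)) {
        printf("%d\n", id);     // prints 3, 7, 11
    }
    return 0;
}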
@@ -108,13 +110,6 @@ struct server_task_result {
    bool error;
};

struct server_task_multi {
    int id = -1;

    std::set<int> subtasks_remaining;
    std::vector<server_task_result> results;
};

struct slot_params {
    bool stream = true;
    bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
@@ -133,6 +128,9 @@ struct server_slot {
    int id;
    int id_task = -1;

    // the index relative to completion multi-task request
    size_t index = 0;

    struct slot_params params;

    slot_state state = SLOT_STATE_IDLE;
@@ -388,15 +386,12 @@ struct server_queue {
    std::vector<server_task> queue_tasks;
    std::vector<server_task> queue_tasks_deferred;

    std::vector<server_task_multi> queue_multitasks;

    std::mutex mutex_tasks;
    std::condition_variable condition_tasks;

    // callback functions
    std::function<void(server_task &)> callback_new_task;
    std::function<void(server_task_multi &)> callback_finish_multitask;
    std::function<void(void)> callback_update_slots;
    std::function<void(server_task&)> callback_new_task;
    std::function<void(void)> callback_update_slots;

    // Add a new task to the end of the queue
    int post(server_task task) {
@@ -437,11 +432,6 @@ struct server_queue {
        callback_new_task = std::move(callback);
    }

    // Register function to process a multitask when it is finished
    void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
        callback_finish_multitask = std::move(callback);
    }

    // Register the function to be called when all slots data is ready to be processed
    void on_update_slots(std::function<void(void)> callback) {
        callback_update_slots = std::move(callback);
@@ -490,22 +480,6 @@ struct server_queue {
                callback_new_task(task);
            }

            LOG_VERBOSE("update_multitasks", {});

            // check if we have any finished multitasks
            auto queue_iterator = queue_multitasks.begin();
            while (queue_iterator != queue_multitasks.end()) {
                if (queue_iterator->subtasks_remaining.empty()) {
                    // all subtasks done == multitask is done
                    server_task_multi current_multitask = *queue_iterator;
                    callback_finish_multitask(current_multitask);
                    // remove this multitask
                    queue_iterator = queue_multitasks.erase(queue_iterator);
                } else {
                    ++queue_iterator;
                }
            }

            // all tasks in the current loop is processed, slots data is now ready
            LOG_VERBOSE("callback_update_slots", {});
@@ -530,7 +504,7 @@ struct server_queue {

struct server_response {
    // for keeping track of all tasks waiting for the result
    std::set<int> waiting_task_ids;
    std::unordered_set<int> waiting_task_ids;

    // the main result queue
    std::vector<server_task_result> queue_results;
@@ -1387,7 +1361,8 @@ struct server_context {
            {"content", tkn.text_to_send},
            {"stop", false},
            {"id_slot", slot.id},
            {"multimodal", false}
            {"multimodal", false},
            {"index", slot.index},
        };

        if (slot.sparams.n_probs > 0) {
@@ -1434,7 +1409,8 @@ struct server_context {
            {"stopped_limit", slot.stopped_limit},
            {"stopping_word", slot.stopping_word},
            {"tokens_cached", slot.n_past},
            {"timings", slot.get_formated_timings()}
            {"timings", slot.get_formated_timings()},
            {"index", slot.index},
        };

        if (slot.sparams.n_probs > 0) {
@@ -1500,6 +1476,7 @@ struct server_context {

        res.data = json {
            {"embedding", embd_res},
            {"index", slot.index},
        };
    }
@@ -1507,10 +1484,10 @@ struct server_context {
    }

    //
    // functions to create new task(s)
    // Functions to create new task(s) and receive result(s)
    //

    std::vector<server_task> request_completion(json data, server_task_cmpl_type cmpl_type) {
    std::vector<server_task> create_tasks_completion(json data, server_task_cmpl_type cmpl_type) {
        std::vector<server_task> tasks;
        auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
            server_task task;
@@ -1535,12 +1512,16 @@ struct server_context {

        // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
        if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
            data["index"] = 0;
            create_task(data, false, nullptr);
        }
        // otherwise, it's a multiple-prompt task, we break it into smaller tasks
        else if (prompt.is_array()) {
            for (auto const & e : prompt) {
            std::vector<json> prompts = prompt;
            for (size_t i = 0; i < prompts.size(); i++) {
                const auto & e = prompts[i];
                if (e.is_string() || json_is_array_of_numbers(e)) {
                    data["index"] = i;
                    create_task(data, true, e);
                } else {
                    throw std::runtime_error(error_msg);
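A rough standalone sketch of the prompt fan-out rule used by create_tasks_completion(), assuming json.hpp is on the include path; the demo_task struct and helper names below are illustrative simplifications, not server APIs. A single string or a flat array of token ids stays one task with index 0, while an array of prompts becomes one indexed task per element.

#include <cstdio>
#include <vector>
#include "json.hpp"

using json = nlohmann::ordered_json;

// illustrative stand-in for a queued completion task
struct demo_task { size_t index; json prompt; };

static bool is_array_of_numbers(const json & j) {
    if (!j.is_array()) return false;
    for (const auto & e : j) { if (!e.is_number_integer()) return false; }
    return true;
}

static std::vector<demo_task> fan_out(const json & prompt) {
    std::vector<demo_task> tasks;
    if (prompt.is_string() || is_array_of_numbers(prompt)) {
        tasks.push_back({0, prompt});            // singleton -> one task, index 0
    } else if (prompt.is_array()) {
        for (size_t i = 0; i < prompt.size(); i++) {
            tasks.push_back({i, prompt[i]});     // one task per prompt, indexed
        }
    }
    return tasks;
}

int main() {
    const json multi = json::array({"hello", "world", json::array({1, 2, 3})});
    for (const auto & t : fan_out(multi)) {
        printf("task index=%zu prompt=%s\n", t.index, t.prompt.dump().c_str());
    }
    return 0;
}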
@@ -1555,17 +1536,61 @@ struct server_context {
        return tasks;
    }

    void request_cancel(std::vector<server_task> & tasks) {
        for (const auto & t : tasks) {
    void cancel_tasks(std::unordered_set<int> & id_tasks) {
        for (const auto & id_task : id_tasks) {
            LOG_VERBOSE("cancel task", {{"id_task", id_task}});
            server_task task;
            task.type = SERVER_TASK_TYPE_CANCEL;
            task.id_target = t.id;
            task.id_target = id_task;
            queue_tasks.post(task);
            queue_results.remove_waiting_task_id(id_task);
        }
    }

    void receive_cmpl_results(std::vector<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
        std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
        std::vector<server_task_result> results(id_set.size());
        for (size_t i = 0; i < id_set.size(); i++) {
            server_task_result result = queue_results.recv(id_set);

            if (result.error) {
                error_handler(result.data);
                cancel_tasks(id_set);
                break;
            }

            size_t idx = result.data["index"];
            results[idx] = result;
        }
        result_handler(results);
    }

    void receive_cmpl_results_stream(std::vector<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
        std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
        size_t n_finished = 0;
        while (true) {
            server_task_result result = queue_results.recv(id_set);
            if (!result_handler(result)) {
                cancel_tasks(id_set);
                break;
            }

            if (result.error) {
                error_handler(result.data);
                cancel_tasks(id_set);
                break;
            }

            if (result.stop) {
                if (++n_finished == id_set.size()) {
                    break;
                }
            }
        }
    }

    //
    // functions to process the task
    // Functions to process the task
    //

    void process_single_task(const server_task & task) {
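A simplified standalone sketch of what the non-streaming receive_cmpl_results() path does with the per-task "index": results can come back in any order, so the index carried in each result is used to restore the original prompt order. The deque below is a toy stand-in for queue_results.recv(), and demo_result is an assumed simplification of server_task_result.

#include <cstdio>
#include <deque>
#include <vector>

// toy stand-in: carries the per-prompt index and some payload
struct demo_result { size_t index; int n_tokens; };

int main() {
    // results arriving out of order, e.g. because slots finish at different times
    std::deque<demo_result> incoming = { {2, 8}, {0, 5}, {1, 13} };

    std::vector<demo_result> results(incoming.size());
    for (size_t i = 0; i < results.size(); i++) {
        demo_result r = incoming.front();   // stand-in for queue_results.recv(...)
        incoming.pop_front();
        results[r.index] = r;               // reorder by the task's "index" field
    }

    for (const auto & r : results) {
        printf("index=%zu n_tokens=%d\n", r.index, r.n_tokens);
    }
    return 0;
}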
@@ -1614,6 +1639,7 @@ struct server_context {

                slot->id_task = task.id;
                slot->cmpl_type = task.cmpl_type;
                slot->index = json_value(task.data, "index", 0);

                if (!launch_slot_with_task(*slot, task)) {
                    LOG_ERROR("error while launching slot", task.data);
@@ -1841,26 +1867,6 @@ struct server_context {
        }
    }

    void on_finish_multitask(const server_task_multi & multitask) {
        // all subtasks done == multitask is done
        server_task_result result;
        result.id = multitask.id;
        result.stop = true;
        result.error = false;

        // collect json results into one json result
        std::vector<json> result_jsons;
        for (const auto & subres : multitask.results) {
            result_jsons.push_back(subres.data);
            result.error = result.error && subres.error;
        }
        result.data = json {
            { "results", result_jsons }
        };

        queue_results.send(result);
    }

    void update_slots() {
        if (system_need_update) {
            system_prompt_update();
@@ -2556,6 +2562,11 @@ int main(int argc, char ** argv) {
        res.status = json_value(error_data, "code", 500);
    };

    auto res_ok = [](httplib::Response & res, json data) {
        res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
        res.status = 200;
    };

    svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
        std::string message;
        try {
@@ -2603,7 +2614,7 @@ int main(int argc, char ** argv) {

    auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
        static const std::set<std::string> protected_endpoints = {
        static const std::unordered_set<std::string> protected_endpoints = {
            "/props",
            "/completion",
            "/completions",
@@ -2932,81 +2943,106 @@ int main(int argc, char ** argv) {
        res.set_content(data.dump(), MIMETYPE_JSON);
    };

    const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
        if (ctx_server.params.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        json data = json::parse(req.body);

        std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
        std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, cmpl_type);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);

        if (!json_value(data, "stream", false)) {
            server_task_result result = ctx_server.queue_results.recv(tasks);
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
            } else {
                res_error(res, result.data);
            }
        bool stream = json_value(data, "stream", false);
        std::vector<int> task_ids = server_task::get_list_id(tasks);

            ctx_server.queue_results.remove_waiting_tasks(tasks);
        } else {
            const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
                while (true) {
                    server_task_result result = ctx_server.queue_results.recv(tasks);
                    if (!result.error) {
                        const std::string str =
                            "data: " +
                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
                            "\n\n";

                        LOG_VERBOSE("data stream", {
                            { "to_send", str }
                        });

                        if (!sink.write(str.c_str(), str.size())) {
                            ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }

                        if (result.stop) {
                            break;
                        }
                    } else {
                        const std::string str =
                            "error: " +
                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
                            "\n\n";

                        LOG_VERBOSE("data stream", {
                            { "to_send", str }
                        });

                        if (!sink.write(str.c_str(), str.size())) {
                            ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }

                        break;
        if (!stream) {
            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
                if (results.size() == 1) {
                    // single result
                    res_ok(res, results[0].data);
                } else {
                    // multiple results (multitask)
                    json arr = json::array();
                    for (const auto & res : results) {
                        arr.push_back(res.data);
                    }
                    res_ok(res, arr);
                }

                ctx_server.queue_results.remove_waiting_tasks(tasks);
            }, [&](json error_data) {
                res_error(res, error_data);
            });
        } else {
            const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) mutable {
                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
                    return server_sent_event(sink, "data", result.data);
                }, [&](json error_data) {
                    server_sent_event(sink, "error", error_data);
                });
                sink.done();

                return true;
            };
            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
        }
    };

    auto on_complete = [tasks, &ctx_server](bool) mutable {
        // cancel
        ctx_server.request_cancel(tasks);
        ctx_server.queue_results.remove_waiting_tasks(tasks);
    const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
        json data = json::parse(req.body);
        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
    };

    const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
        json data = json::parse(req.body);
        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
    };

    // TODO: maybe merge this function with "handle_completions_generic"
    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
        if (ctx_server.params.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

        std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);

        bool stream = json_value(data, "stream", false);
        std::vector<int> task_ids = server_task::get_list_id(tasks);
        const auto completion_id = gen_chatcmplid();

        if (!stream) {
            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
                // multitask is never support in chat completion, there is only one result
                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id);
                res_ok(res, result_oai);
            }, [&](json error_data) {
                res_error(res, error_data);
            });
        } else {
            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable {
                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
                    for (auto & event_data : result_array) {
                        if (event_data.empty()) {
                            continue; // skip the stop token
                        }
                        if (!server_sent_event(sink, "data", event_data)) {
                            return false; // connection is closed
                        }
                    }
                    return true; // ok
                }, [&](json error_data) {
                    server_sent_event(sink, "error", error_data);
                });
                std::string done_event = "[DONE]"; // OAI-compat behavior
                sink.write(done_event.c_str(), done_event.size());
                sink.done();
                return true;
            };

            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
        }
    };
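For reference, a rough standalone sketch of what the chat-completions stream looks like on the wire with this handler: each partial result is framed as an SSE "data" event, and the literal [DONE] marker is written once all results have been streamed. The chunk payloads below are made up for illustration; the real ones come from format_partial_response_oaicompat().

#include <cstdio>
#include <string>
#include <vector>

int main() {
    // illustrative chunk payloads
    const std::vector<std::string> chunks = {
        "{\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}",
        "{\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}",
    };

    std::string wire;
    for (const auto & c : chunks) {
        wire += "data: " + c + "\n\n"; // one SSE "data" event per chunk
    }
    wire += "[DONE]"; // end-of-stream marker, written after the events as in the handler above

    fputs(wire.c_str(), stdout);
    return 0;
}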
@@ -3027,142 +3063,6 @@ int main(int argc, char ** argv) {
        res.set_content(models.dump(), MIMETYPE_JSON);
    };

    const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
        if (ctx_server.params.embedding) {
            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

        std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);

        const auto completion_id = gen_chatcmplid();
        if (!json_value(data, "stream", false)) {
            server_task_result result = ctx_server.queue_results.recv(tasks);

            if (!result.error && result.stop) {
                json result_oai = format_final_response_oaicompat(data, result.data, completion_id);

                res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
            } else {
                res_error(res, result.data);
            }
            ctx_server.queue_results.remove_waiting_tasks(tasks);
        } else {
            const auto chunked_content_provider = [tasks, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable {
                while (true) {
                    server_task_result result = ctx_server.queue_results.recv(tasks);
                    if (!result.error) {
                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);

                        for (auto it = result_array.begin(); it != result_array.end(); ++it) {
                            if (!it->empty()) {
                                const std::string str =
                                    "data: " +
                                    it->dump(-1, ' ', false, json::error_handler_t::replace) +
                                    "\n\n";
                                LOG_VERBOSE("data stream", {{"to_send", str}});
                                if (!sink.write(str.c_str(), str.size())) {
                                    ctx_server.queue_results.remove_waiting_tasks(tasks);
                                    return false;
                                }
                            }
                        }
                        if (result.stop) {
                            break;
                        }
                    } else {
                        const std::string str =
                            "error: " +
                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
                            "\n\n";
                        LOG_VERBOSE("data stream", {{"to_send", str}});
                        if (!sink.write(str.c_str(), str.size())) {
                            ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }
                        break;
                    }
                }
                sink.done();
                ctx_server.queue_results.remove_waiting_tasks(tasks);
                return true;
            };

            auto on_complete = [tasks, &ctx_server](bool) mutable {
                // cancel request
                ctx_server.request_cancel(tasks);
                ctx_server.queue_results.remove_waiting_tasks(tasks);
            };

            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
        }
    };

    const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
        if (ctx_server.params.embedding) {
            res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        json data = json::parse(req.body);

        std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_INFILL);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);

        if (!json_value(data, "stream", false)) {
            server_task_result result = ctx_server.queue_results.recv(tasks);
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
            } else {
                res_error(res, result.data);
            }

            ctx_server.queue_results.remove_waiting_tasks(tasks);
        } else {
            const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
                while (true) {
                    server_task_result result = ctx_server.queue_results.recv(tasks);
                    if (!result.error) {
                        const std::string str =
                            "data: " +
                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
                            "\n\n";

                        LOG_VERBOSE("data stream", {
                            { "to_send", str }
                        });

                        if (!sink.write(str.c_str(), str.size())) {
                            ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }

                        if (result.stop) {
                            break;
                        }
                    } else {
                        break;
                    }
                }

                ctx_server.queue_results.remove_waiting_tasks(tasks);
                sink.done();

                return true;
            };

            auto on_complete = [tasks, &ctx_server](bool) mutable {
                ctx_server.request_cancel(tasks);
            };

            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
        }
    };

    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
        const json body = json::parse(req.body);
@@ -3208,7 +3108,7 @@ int main(int argc, char ** argv) {
        // create and queue the task
        json responses;
        {
            std::vector<server_task> tasks = ctx_server.request_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_INFILL);
            std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(tasks);
@@ -3407,8 +3307,6 @@ int main(int argc, char ** argv) {

    ctx_server.queue_tasks.on_new_task(std::bind(
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
    ctx_server.queue_tasks.on_finish_multitask(std::bind(
        &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
    ctx_server.queue_tasks.on_update_slots(std::bind(
        &server_context::update_slots, &ctx_server));

@@ -3,6 +3,14 @@
#include "llama.h"
#include "common.h"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#include "httplib.h"

// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
@@ -355,6 +363,19 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
    return out;
}

static bool server_sent_event(httplib::DataSink & sink, const char * event, json & data) {
    const std::string str =
        std::string(event) + ": " +
        data.dump(-1, ' ', false, json::error_handler_t::replace) +
        "\n\n";

    LOG_VERBOSE("data stream", {
        { "to_send", str }
    });

    return sink.write(str.c_str(), str.size());
}

//
// OAI utils
//
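To make the framing concrete, a small standalone sketch of the event string server_sent_event() builds before writing it to the HTTP sink; it assumes json.hpp is available and prints the frame instead of using an httplib::DataSink.

#include <cstdio>
#include <string>
#include "json.hpp"

using json = nlohmann::ordered_json;

// builds the same "<event>: <json>\n\n" frame as server_sent_event(), minus the sink
static std::string make_sse_frame(const char * event, const json & data) {
    return std::string(event) + ": " +
           data.dump(-1, ' ', false, json::error_handler_t::replace) +
           "\n\n";
}

int main() {
    const json data = { {"content", "Hello"}, {"stop", false} };
    fputs(make_sse_frame("data", data).c_str(), stdout);
    // prints: data: {"content":"Hello","stop":false} followed by a blank line
    return 0;
}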