server : remove multitask from server_task
parent 8f1d81a0b6
commit 012d8d8cc0
2 changed files with 165 additions and 186 deletions
@@ -44,6 +44,7 @@
 #include <thread>
 #include <signal.h>
 #include <memory>
+#include <unordered_set>

 using json = nlohmann::ordered_json;

@@ -82,21 +83,24 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };

+enum server_task_cmpl_type {
+    SERVER_TASK_CMPL_TYPE_NORMAL,
+    SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_INFILL,
+};
+
 struct server_task {
     int id = -1; // to be filled by server_queue
-    int id_multi = -1;
-    int id_target = -1;
+    int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL

     server_task_type type;
     json data;

-    bool infill = false;
-    bool embedding = false;
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
 };

 struct server_task_result {
     int id = -1;
-    int id_multi = -1;

     json data;

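The per-task infill/embedding flags are folded into the single cmpl_type field above. A minimal sketch of the difference when tagging a task, using only the fields defined in these structs (the payload value is an assumed example):

    // before this commit: two independent flags
    //     task.infill    = false;
    //     task.embedding = true;
    // after: one completion type
    server_task task;
    task.type      = SERVER_TASK_TYPE_COMPLETION;
    task.cmpl_type = SERVER_TASK_CMPL_TYPE_EMBEDDING;
    task.data      = json{{"prompt", "hello"}}; // assumed example payload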
@@ -128,7 +132,6 @@ struct slot_params {
 struct server_slot {
     int id;
     int id_task = -1;
-    int id_multi = -1;

     struct slot_params params;

@@ -158,8 +161,7 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;

-    bool infill = false;
-    bool embedding = false;
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;

@@ -204,7 +206,7 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        infill = false;
+        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
         ga_i = 0;
         n_past_se = 0;

@@ -408,6 +410,14 @@ struct server_queue {
         return task.id;
     }

+    // multi-task version of post()
+    int post(std::vector<server_task> & tasks) {
+        for (auto & task : tasks) {
+            post(task);
+        }
+        return 0;
+    }
+
     // Add a new task, but defer until one slot is available
     void defer(server_task task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
@@ -516,36 +526,9 @@ struct server_queue {
             }
         }
     }
-
-    //
-    // functions to manage multitasks
-    //
-
-    // add a multitask by specifying the id of all subtask (subtask is a server_task)
-    void add_multitask(int id_multi, std::vector<int> & sub_ids) {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        server_task_multi multi;
-        multi.id = id_multi;
-        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-        queue_multitasks.push_back(multi);
-    }
-
-    // updatethe remaining subtasks, while appending results to multitask
-    void update_multitask(int id_multi, int id_sub, server_task_result & result) {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        for (auto & multitask : queue_multitasks) {
-            if (multitask.id == id_multi) {
-                multitask.subtasks_remaining.erase(id_sub);
-                multitask.results.push_back(result);
-            }
-        }
-    }
 };

 struct server_response {
-    typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
-    callback_multitask_t callback_update_multitask;
-
     // for keeping track of all tasks waiting for the result
     std::set<int> waiting_task_ids;

@@ -563,6 +546,12 @@ struct server_response {
         waiting_task_ids.insert(id_task);
     }

+    void add_waiting_tasks(std::vector<server_task> & tasks) {
+        for (const auto & t : tasks) {
+            add_waiting_task_id(t.id);
+        }
+    }
+
     // when the request is finished, we can remove task associated with it
     void remove_waiting_task_id(int id_task) {
         LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}});
@@ -571,8 +560,14 @@ struct server_response {
         waiting_task_ids.erase(id_task);
     }

-    // This function blocks the thread until there is a response for this id_task
-    server_task_result recv(int id_task) {
+    void remove_waiting_tasks(std::vector<server_task> & tasks) {
+        for (const auto & t : tasks) {
+            remove_waiting_task_id(t.id);
+        }
+    }
+
+    // This function blocks the thread until there is a response for one of the id_tasks
+    server_task_result recv(std::unordered_set<int> & id_tasks) {
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_results);
             condition_results.wait(lock, [&]{

@@ -580,8 +575,7 @@ struct server_response {
             });

             for (int i = 0; i < (int) queue_results.size(); i++) {
-                if (queue_results[i].id == id_task) {
-                    assert(queue_results[i].id_multi == -1);
+                if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
                     server_task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;

@@ -592,9 +586,20 @@ struct server_response {
         // should never reach here
     }

-    // Register the function to update multitask
-    void on_multitask_update(callback_multitask_t callback) {
-        callback_update_multitask = std::move(callback);
+    // single-task version of recv()
+    server_task_result recv(int id_task) {
+        std::unordered_set<int> id_tasks = {id_task};
+        return recv(id_tasks);
+    }
+
+    // multi-task version of recv()
+    server_task_result recv(std::vector<server_task> & tasks) {
+        std::unordered_set<int> id_tasks;
+        id_tasks.reserve(tasks.size());
+        for (const auto & t : tasks) {
+            id_tasks.insert(t.id);
+        }
+        return recv(id_tasks);
     }

     // Send a new result to a waiting id_task
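With the queue and response changes above, a request handler now follows one pattern: create the tasks, register them as waiting, post them, receive, and finally unregister. A minimal sketch of that sequence for a non-streaming request, assuming ctx_server is the server_context instance from main() (the HTTP handlers further down do exactly this):

    std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
    ctx_server.queue_results.add_waiting_tasks(tasks); // register before posting so no result is missed
    ctx_server.queue_tasks.post(tasks);

    server_task_result result = ctx_server.queue_results.recv(tasks); // blocks until one of the tasks finishes
    // ... send result.data to the client ...

    ctx_server.queue_results.remove_waiting_tasks(tasks);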
@@ -603,14 +608,6 @@ struct server_response {

         std::unique_lock<std::mutex> lock(mutex_results);
         for (const auto & id_task : waiting_task_ids) {
-            // LOG_TEE("waiting task id %i \n", id_task);
-            // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
-            if (result.id_multi == id_task) {
-                LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
-                callback_update_multitask(id_task, result.id, result);
-                continue;
-            }
-
             if (result.id == id_task) {
                 LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
                 queue_results.push_back(result);
@@ -966,7 +963,7 @@ struct server_context {
         slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);

         // get prompt
-        if (!task.infill) {
+        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
                 send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);

@@ -1359,23 +1356,21 @@ struct server_context {
     }

     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(task.id, task.id_multi, error, type);
+        send_error(task.id, error, type);
     }

     void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(slot.id_task, slot.id_multi, error, type);
+        send_error(slot.id_task, error, type);
     }

-    void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         LOG_ERROR("task error", {
-            {"id_multi", id_multi},
             {"id_task", id_task},
             {"error", error},
         });

         server_task_result res;
         res.id = id_task;
-        res.id_multi = id_multi;
         res.stop = false;
         res.error = true;
         res.data = format_error_response(error, type);

@@ -1386,7 +1381,6 @@ struct server_context {
     void send_partial_response(server_slot & slot, completion_token_output tkn) {
         server_task_result res;
         res.id = slot.id_task;
-        res.id_multi = slot.id_multi;
         res.error = false;
         res.stop = false;
         res.data = json {

@@ -1423,7 +1417,6 @@ struct server_context {
     void send_final_response(const server_slot & slot) {
         server_task_result res;
         res.id = slot.id_task;
-        res.id_multi = slot.id_multi;
         res.error = false;
         res.stop = true;
         res.data = json {

@@ -1473,7 +1466,6 @@ struct server_context {
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         server_task_result res;
         res.id = slot.id_task;
-        res.id_multi = slot.id_multi;
         res.error = false;
         res.stop = true;

|
|||
queue_results.send(res);
|
||||
}
|
||||
|
||||
void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) {
|
||||
server_task task;
|
||||
task.id = id_task;
|
||||
task.id_multi = id_multi;
|
||||
task.id_target = 0;
|
||||
task.data = std::move(data);
|
||||
task.infill = infill;
|
||||
task.embedding = embedding;
|
||||
task.type = SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
||||
// otherwise, it's a single-prompt task, we actually queue it
|
||||
// if there's numbers in the prompt array it will be treated as an array of tokens
|
||||
if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
|
||||
bool numbers = false;
|
||||
for (const auto & e : task.data.at("prompt")) {
|
||||
if (e.is_number()) {
|
||||
numbers = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
|
||||
// it will completely stall the server. I don't know where the bug for this is.
|
||||
//
|
||||
// if there are numbers, it needs to be treated like a single prompt,
|
||||
// queue_tasks handles a mix of strings and numbers just fine.
|
||||
if (numbers) {
|
||||
queue_tasks.post(task);
|
||||
// functions to create new task(s)
|
||||
//
|
||||
|
||||
std::vector<server_task> request_completion(json data, server_task_cmpl_type cmpl_type) {
|
||||
std::vector<server_task> tasks;
|
||||
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
|
||||
server_task task;
|
||||
task.id = queue_tasks.get_new_id();
|
||||
task.cmpl_type = cmpl_type;
|
||||
task.type = SERVER_TASK_TYPE_COMPLETION;
|
||||
if (replace_prompt) {
|
||||
task.data = task_data;
|
||||
task.data["prompt"] = prompt;
|
||||
} else {
|
||||
split_multiprompt_task(id_task, task);
|
||||
}
|
||||
} else {
|
||||
queue_tasks.post(task);
|
||||
task.data = std::move(task_data);
|
||||
}
|
||||
tasks.push_back(std::move(task));
|
||||
};
|
||||
|
||||
static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
|
||||
if (!data.contains("prompt")) {
|
||||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
void request_cancel(int id_task) {
|
||||
json prompt = data.at("prompt");
|
||||
|
||||
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
|
||||
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
|
||||
create_task(data, false, nullptr);
|
||||
}
|
||||
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
||||
else if (prompt.is_array()) {
|
||||
for (auto const & e : prompt) {
|
||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
||||
create_task(data, true, e);
|
||||
} else {
|
||||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
// invalid case
|
||||
else {
|
||||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
return tasks;
|
||||
}
|
||||
|
||||
void request_cancel(std::vector<server_task> & tasks) {
|
||||
for (const auto & t : tasks) {
|
||||
server_task task;
|
||||
task.type = SERVER_TASK_TYPE_CANCEL;
|
||||
task.id_target = id_task;
|
||||
|
||||
task.id_target = t.id;
|
||||
queue_tasks.post(task);
|
||||
}
|
||||
|
||||
void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) {
|
||||
const int prompt_count = multiprompt_task.data.at("prompt").size();
|
||||
if (prompt_count <= 1) {
|
||||
send_error(multiprompt_task, "error while handling multiple prompts");
|
||||
return;
|
||||
}
|
||||
|
||||
// generate all the ID for subtask
|
||||
std::vector<int> subtask_ids(prompt_count);
|
||||
for (int i = 0; i < prompt_count; i++) {
|
||||
subtask_ids[i] = queue_tasks.get_new_id();
|
||||
}
|
||||
|
||||
// queue up the multitask so we can track its subtask progression
|
||||
queue_tasks.add_multitask(id_multi, subtask_ids);
|
||||
|
||||
// add subtasks
|
||||
for (int i = 0; i < prompt_count; i++) {
|
||||
json subtask_data = multiprompt_task.data;
|
||||
subtask_data["prompt"] = subtask_data.at("prompt")[i];
|
||||
|
||||
// subtasks inherit everything else (infill mode, embedding mode, etc.)
|
||||
request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
|
||||
}
|
||||
}
|
||||
//
|
||||
// functions to process the task
|
||||
//
|
||||
|
||||
void process_single_task(const server_task & task) {
|
||||
switch (task.type) {
|
||||
|
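For illustration, a hedged sketch of how the splitting rule in request_completion() plays out on a few assumed payloads (values are examples, not from the diff; ctx_server is the server_context instance from main()):

    json single = {{"prompt", "hello"}};                      // -> 1 task, data passed through unchanged
    json tokens = {{"prompt", {1, 2, 3}}};                    // -> 1 task, an array of token ids counts as one prompt
    json multi  = {{"prompt", {"hello", "world"}}};           // -> 2 tasks, one per element, "prompt" replaced per task
    json mixed  = {{"prompt", {json::array({1, 2}), "bye"}}}; // -> 2 tasks, token array + string
    // anything else throws: "prompt" must be a string, an array of token ids or an array of prompts

    std::vector<server_task> tasks = ctx_server.request_completion(multi, SERVER_TASK_CMPL_TYPE_NORMAL); // tasks.size() == 2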
@@ -1630,9 +1613,7 @@ struct server_context {
                 slot->reset();

                 slot->id_task = task.id;
-                slot->id_multi = task.id_multi;
-                slot->infill = task.infill;
-                slot->embedding = task.embedding;
+                slot->cmpl_type = task.cmpl_type;

                 if (!launch_slot_with_task(*slot, task)) {
                     LOG_ERROR("error while launching slot", task.data);

@@ -1699,7 +1680,6 @@ struct server_context {

                     server_task_result res;
                     res.id = task.id;
-                    res.id_multi = task.id_multi;
                     res.stop = true;
                     res.error = false;
                     res.data = {

@@ -2038,7 +2018,7 @@ struct server_context {
                    slot.t_start_process_prompt = ggml_time_us();
                    slot.t_start_generation = 0;

-                   if (slot.infill) {
+                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
                        const bool add_bos = llama_add_bos_token(model);
                        bool suff_rm_leading_spc = true;
                        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {

@@ -2101,7 +2081,7 @@ struct server_context {
                        continue;
                    }

-                   if (slot.embedding) {
+                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                        // this prompt is too large to process - discard it
                        if (slot.n_prompt_tokens > n_ubatch) {
                            slot.state = SLOT_STATE_PROCESSING;

@@ -2184,7 +2164,7 @@ struct server_context {
                    slot.n_prompt_tokens_processed = 0;
                }

-               if (slot.embedding) {
+               if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                    // cannot fit the prompt in the current batch - will try next iter
                    if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                        continue;

@@ -2192,7 +2172,7 @@ struct server_context {
                }

                // check that we are in the right batch_type, if not defer the slot
-               bool slot_type = slot.embedding ? 1 : 0;
+               bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
                if (batch_type == -1) {
                    batch_type = slot_type;
                } else if (batch_type != slot_type) {

@@ -2385,7 +2365,7 @@ struct server_context {
                    }

                    // prompt evaluated for embedding
-                   if (slot.embedding) {
+                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                        send_embedding(slot, batch_view);
                        slot.release();
                        slot.i_batch = -1;

@@ -2707,8 +2687,6 @@ int main(int argc, char ** argv) {
        // request slots data using task queue
        server_task task;
        task.id = ctx_server.queue_tasks.get_new_id();
-       task.id_multi = -1;
-       task.id_target = -1;
        task.type = SERVER_TASK_TYPE_METRICS;

        ctx_server.queue_results.add_waiting_task_id(task.id);

@@ -2740,7 +2718,6 @@ int main(int argc, char ** argv) {
        // request slots data using task queue
        server_task task;
        task.id = ctx_server.queue_tasks.get_new_id();
-       task.id_multi = -1;
        task.id_target = -1;
        task.type = SERVER_TASK_TYPE_METRICS;
        task.data.push_back({{"reset_bucket", true}});
@@ -2963,24 +2940,23 @@ int main(int argc, char ** argv) {

        json data = json::parse(req.body);

-       const int id_task = ctx_server.queue_tasks.get_new_id();
-
-       ctx_server.queue_results.add_waiting_task_id(id_task);
-       ctx_server.request_completion(id_task, -1, data, false, false);
+       std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+       ctx_server.queue_results.add_waiting_tasks(tasks);
+       ctx_server.queue_tasks.post(tasks);

        if (!json_value(data, "stream", false)) {
-           server_task_result result = ctx_server.queue_results.recv(id_task);
+           server_task_result result = ctx_server.queue_results.recv(tasks);
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
            } else {
                res_error(res, result.data);
            }

-           ctx_server.queue_results.remove_waiting_task_id(id_task);
+           ctx_server.queue_results.remove_waiting_tasks(tasks);
        } else {
-           const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
+           const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
                while (true) {
-                   server_task_result result = ctx_server.queue_results.recv(id_task);
+                   server_task_result result = ctx_server.queue_results.recv(tasks);
                    if (!result.error) {
                        const std::string str =
                            "data: " +
@@ -2992,7 +2968,7 @@ int main(int argc, char ** argv) {
                            });

                        if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_task_id(id_task);
+                           ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }

@@ -3010,7 +2986,7 @@ int main(int argc, char ** argv) {
                            });

                        if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_task_id(id_task);
+                           ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }

@@ -3018,16 +2994,16 @@ int main(int argc, char ** argv) {
                    }
                }

-               ctx_server.queue_results.remove_waiting_task_id(id_task);
+               ctx_server.queue_results.remove_waiting_tasks(tasks);
                sink.done();

                return true;
            };

-           auto on_complete = [id_task, &ctx_server] (bool) {
+           auto on_complete = [tasks, &ctx_server](bool) mutable {
                // cancel
-               ctx_server.request_cancel(id_task);
-               ctx_server.queue_results.remove_waiting_task_id(id_task);
+               ctx_server.request_cancel(tasks);
+               ctx_server.queue_results.remove_waiting_tasks(tasks);
            };

            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
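One detail worth noting in the handlers above: tasks is captured by value, while recv(), request_cancel() and remove_waiting_tasks() all take non-const std::vector<server_task> & parameters. A by-value capture is const inside a non-mutable lambda, so the captures are marked mutable; a minimal sketch of the pattern, using the names from this diff:

    auto on_complete = [tasks, &ctx_server](bool) mutable {
        // without 'mutable' the captured copy of 'tasks' would be const and could not bind
        // to the non-const references taken by request_cancel() and remove_waiting_tasks()
        ctx_server.request_cancel(tasks);
        ctx_server.queue_results.remove_waiting_tasks(tasks);
    };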
@@ -3058,14 +3034,13 @@ int main(int argc, char ** argv) {
        }
        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

-       const int id_task = ctx_server.queue_tasks.get_new_id();
-
-       ctx_server.queue_results.add_waiting_task_id(id_task);
-       ctx_server.request_completion(id_task, -1, data, false, false);
+       std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+       ctx_server.queue_results.add_waiting_tasks(tasks);
+       ctx_server.queue_tasks.post(tasks);

        const auto completion_id = gen_chatcmplid();
        if (!json_value(data, "stream", false)) {
-           server_task_result result = ctx_server.queue_results.recv(id_task);
+           server_task_result result = ctx_server.queue_results.recv(tasks);

            if (!result.error && result.stop) {
                json result_oai = format_final_response_oaicompat(data, result.data, completion_id);

@@ -3074,11 +3049,11 @@ int main(int argc, char ** argv) {
            } else {
                res_error(res, result.data);
            }
-           ctx_server.queue_results.remove_waiting_task_id(id_task);
+           ctx_server.queue_results.remove_waiting_tasks(tasks);
        } else {
-           const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+           const auto chunked_content_provider = [tasks, &ctx_server, completion_id](size_t, httplib::DataSink & sink) mutable {
                while (true) {
-                   server_task_result result = ctx_server.queue_results.recv(id_task);
+                   server_task_result result = ctx_server.queue_results.recv(tasks);
                    if (!result.error) {
                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);

@@ -3090,7 +3065,7 @@ int main(int argc, char ** argv) {
                            "\n\n";
                        LOG_VERBOSE("data stream", {{"to_send", str}});
                        if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_task_id(id_task);
+                           ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }
                    }

@@ -3105,21 +3080,21 @@ int main(int argc, char ** argv) {
                            "\n\n";
                        LOG_VERBOSE("data stream", {{"to_send", str}});
                        if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_task_id(id_task);
+                           ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }
                        break;
                    }
                }
                sink.done();
-               ctx_server.queue_results.remove_waiting_task_id(id_task);
+               ctx_server.queue_results.remove_waiting_tasks(tasks);
                return true;
            };

-           auto on_complete = [id_task, &ctx_server](bool) {
+           auto on_complete = [tasks, &ctx_server](bool) mutable {
                // cancel request
-               ctx_server.request_cancel(id_task);
-               ctx_server.queue_results.remove_waiting_task_id(id_task);
+               ctx_server.request_cancel(tasks);
+               ctx_server.queue_results.remove_waiting_tasks(tasks);
            };

            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -3134,24 +3109,23 @@ int main(int argc, char ** argv) {

        json data = json::parse(req.body);

-       const int id_task = ctx_server.queue_tasks.get_new_id();
-
-       ctx_server.queue_results.add_waiting_task_id(id_task);
-       ctx_server.request_completion(id_task, -1, data, true, false);
+       std::vector<server_task> tasks = ctx_server.request_completion(data, SERVER_TASK_CMPL_TYPE_INFILL);
+       ctx_server.queue_results.add_waiting_tasks(tasks);
+       ctx_server.queue_tasks.post(tasks);

        if (!json_value(data, "stream", false)) {
-           server_task_result result = ctx_server.queue_results.recv(id_task);
+           server_task_result result = ctx_server.queue_results.recv(tasks);
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
            } else {
                res_error(res, result.data);
            }

-           ctx_server.queue_results.remove_waiting_task_id(id_task);
+           ctx_server.queue_results.remove_waiting_tasks(tasks);
        } else {
-           const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
+           const auto chunked_content_provider = [tasks, &ctx_server](size_t, httplib::DataSink & sink) mutable {
                while (true) {
-                   server_task_result result = ctx_server.queue_results.recv(id_task);
+                   server_task_result result = ctx_server.queue_results.recv(tasks);
                    if (!result.error) {
                        const std::string str =
                            "data: " +

@@ -3163,7 +3137,7 @@ int main(int argc, char ** argv) {
                            });

                        if (!sink.write(str.c_str(), str.size())) {
-                           ctx_server.queue_results.remove_waiting_task_id(id_task);
+                           ctx_server.queue_results.remove_waiting_tasks(tasks);
                            return false;
                        }

@@ -3175,14 +3149,14 @@ int main(int argc, char ** argv) {
                    }
                }

-               ctx_server.queue_results.remove_waiting_task_id(id_task);
+               ctx_server.queue_results.remove_waiting_tasks(tasks);
                sink.done();

                return true;
            };

-           auto on_complete = [id_task, &ctx_server] (bool) {
-               ctx_server.request_cancel(id_task);
+           auto on_complete = [tasks, &ctx_server](bool) mutable {
+               ctx_server.request_cancel(tasks);
            };

            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -3234,13 +3208,13 @@ int main(int argc, char ** argv) {
            // create and queue the task
            json responses;
            {
-               const int id_task = ctx_server.queue_tasks.get_new_id();
-               ctx_server.queue_results.add_waiting_task_id(id_task);
-               ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
+               std::vector<server_task> tasks = ctx_server.request_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_INFILL);
+               ctx_server.queue_results.add_waiting_tasks(tasks);
+               ctx_server.queue_tasks.post(tasks);

                // get the result
-               server_task_result result = ctx_server.queue_results.recv(id_task);
-               ctx_server.queue_results.remove_waiting_task_id(id_task);
+               server_task_result result = ctx_server.queue_results.recv(tasks);
+               ctx_server.queue_results.remove_waiting_tasks(tasks);
                if (!result.error) {
                    if (result.data.count("results")) {
                        // result for multi-task
@@ -3437,13 +3411,6 @@ int main(int argc, char ** argv) {
        &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
    ctx_server.queue_tasks.on_update_slots(std::bind(
        &server_context::update_slots, &ctx_server));
-   ctx_server.queue_results.on_multitask_update(std::bind(
-       &server_queue::update_multitask,
-       &ctx_server.queue_tasks,
-       std::placeholders::_1,
-       std::placeholders::_2,
-       std::placeholders::_3
-   ));

    shutdown_handler = [&](int) {
        ctx_server.queue_tasks.terminate();
@@ -279,6 +279,18 @@ static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
     return std::string::npos;
 }

+static bool json_is_array_of_numbers(json data) {
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            if (!e.is_number()) {
+                return false;
+            }
+        }
+        return true;
+    }
+    return false;
+}
+
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
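A few illustrative inputs for the new helper (assumed values, not part of the diff); this check is what decides whether a "prompt" array is treated as one tokenized prompt or split into several prompts:

    json_is_array_of_numbers(json::parse("[1, 2, 3]"));      // true  -> a single prompt given as token ids
    json_is_array_of_numbers(json::parse("[]"));             // true  -> an empty array has no non-number element
    json_is_array_of_numbers(json::parse("[\"a\", \"b\"]")); // false -> treated as multiple prompts
    json_is_array_of_numbers(json::parse("\"hello\""));      // false -> not an array at all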