From 234ab58af18d035a8f6ab50113669f514aa2ce00 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 7 Mar 2024 10:36:39 +0200
Subject: [PATCH] server : rename server structs

---
 examples/server/server.cpp | 398 +++++++++++++++++++------------------
 1 file changed, 200 insertions(+), 198 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 835e1f858..e1adbcee4 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -34,20 +34,19 @@ bool server_verbose = false;
 bool server_log_json = true;
 
 enum stop_type {
-    STOP_FULL,
-    STOP_PARTIAL,
+    STOP_TYPE_FULL,
+    STOP_TYPE_PARTIAL,
 };
 
-// TODO: can become bool if we can't find use of more states
 enum slot_state {
-    IDLE,
-    PROCESSING,
+    SLOT_STATE_IDLE,
+    SLOT_STATE_PROCESSING,
 };
 
 enum slot_command {
-    NONE,
-    LOAD_PROMPT,
-    RELEASE,
+    SLOT_COMMAND_NONE,
+    SLOT_COMMAND_LOAD_PROMPT,
+    SLOT_COMMAND_RELEASE,
 };
 
 enum server_state {
@@ -56,26 +55,26 @@ enum server_state {
     SERVER_STATE_ERROR // An error occurred, load_model failed
 };
 
-enum task_type {
-    TASK_TYPE_COMPLETION,
-    TASK_TYPE_CANCEL,
-    TASK_TYPE_NEXT_RESPONSE,
-    TASK_TYPE_METRICS
+enum server_task_type {
+    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_CANCEL,
+    SERVER_TASK_TYPE_NEXT_RESPONSE,
+    SERVER_TASK_TYPE_METRICS
 };
 
-struct task_server {
-    int id = -1; // to be filled by llama_server_queue
+struct server_task {
+    int id = -1; // to be filled by server_queue
     int id_multi = -1;
     int id_target = -1;
 
-    task_type type;
+    server_task_type type;
     json data;
 
     bool infill = false;
     bool embedding = false;
 };
 
-struct task_result {
+struct server_task_result {
     int id = -1;
     int id_multi = -1;
 
@@ -85,11 +84,11 @@ struct task_result {
     bool error;
 };
 
-struct task_multi {
+struct server_task_multi {
     int id = -1;
 
     std::set<int> subtasks_remaining;
-    std::vector<task_result> results;
+    std::vector<server_task_result> results;
 };
 
 struct slot_params {
@@ -130,8 +129,8 @@ struct server_slot {
 
     struct slot_params params;
 
-    slot_state state = IDLE;
-    slot_command command = NONE;
+    slot_state state = SLOT_STATE_IDLE;
+    slot_command command = SLOT_COMMAND_NONE;
 
     // used to determine the slot that has been used the longest
     int64_t t_last_used = -1;
@@ -223,24 +222,24 @@ struct server_slot {
     }
 
     bool available() const {
-        return state == IDLE && command == NONE;
+        return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
     }
 
     bool is_processing() const {
-        return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING;
+        return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
     }
 
     void add_token_string(const completion_token_output & token) {
-        if (command == RELEASE) {
+        if (command == SLOT_COMMAND_RELEASE) {
             return;
         }
         generated_token_probs.push_back(token);
     }
 
     void release() {
-        if (state == PROCESSING) {
+        if (state == SLOT_STATE_PROCESSING) {
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
-            command = RELEASE;
+            command = SLOT_COMMAND_RELEASE;
         }
     }
@@ -264,7 +263,7 @@ struct server_slot {
         for (const std::string & word : params.antiprompt) {
             size_t pos;
 
-            if (type == STOP_FULL) {
+            if (type == STOP_TYPE_FULL) {
                 const size_t tmp = word.size() + last_token_size;
                 const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
@@ -274,7 +273,7 @@ struct server_slot {
             }
 
             if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
-                if (type == STOP_FULL) {
+                if (type == STOP_TYPE_FULL) {
                     stopped_word = true;
                     stopping_word = word;
                     has_next_token = false;
@@ -363,25 +362,26 @@ struct server_metrics {
     }
 };
 
-struct llama_server_queue {
+struct server_queue {
     int id = 0;
     bool running;
 
     // queues
-    std::vector<task_server> queue_tasks;
-    std::vector<task_server> queue_tasks_deferred;
-    std::vector<task_multi>  queue_multitasks;
+    std::vector<server_task> queue_tasks;
+    std::vector<server_task> queue_tasks_deferred;
+
+    std::vector<server_task_multi> queue_multitasks;
 
     std::mutex mutex_tasks;
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(task_server&)> callback_new_task;
-    std::function<void(task_multi&)> callback_finish_multitask;
-    std::function<void(void)> callback_run_slots;
+    std::function<void(server_task &)> callback_new_task;
+    std::function<void(server_task_multi &)> callback_finish_multitask;
+    std::function<void(void)> callback_run_slots;
 
     // Add a new task to the end of the queue
-    int post(task_server task) {
+    int post(server_task task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         if (task.id == -1) {
             task.id = id++;
@@ -393,7 +393,7 @@ struct llama_server_queue {
     }
 
     // Add a new task, but defer until one slot is available
-    void defer(task_server task) {
+    void defer(server_task task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         queue_tasks_deferred.push_back(std::move(task));
     }
@@ -407,12 +407,12 @@ struct llama_server_queue {
     }
 
     // Register function to process a new task
-    void on_new_task(std::function<void(task_server&)> callback) {
+    void on_new_task(std::function<void(server_task &)> callback) {
         callback_new_task = std::move(callback);
     }
 
     // Register function to process a multitask when it is finished
-    void on_finish_multitask(std::function<void(task_multi&)> callback) {
+    void on_finish_multitask(std::function<void(server_task_multi &)> callback) {
         callback_finish_multitask = std::move(callback);
     }
 
@@ -458,7 +458,7 @@ struct llama_server_queue {
                     lock.unlock();
                     break;
                 }
-                task_server task = queue_tasks.front();
+                server_task task = queue_tasks.front();
                 queue_tasks.erase(queue_tasks.begin());
                 lock.unlock();
                 LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
@@ -472,7 +472,7 @@ struct llama_server_queue {
             while (queue_iterator != queue_multitasks.end()) {
                 if (queue_iterator->subtasks_remaining.empty()) {
                     // all subtasks done == multitask is done
-                    task_multi current_multitask = *queue_iterator;
+                    server_task_multi current_multitask = *queue_iterator;
                     callback_finish_multitask(current_multitask);
                     // remove this multitask
                     queue_iterator = queue_multitasks.erase(queue_iterator);
@@ -506,17 +506,17 @@ struct llama_server_queue {
     // functions to manage multitasks
     //
 
-    // add a multitask by specifying the id of all subtask (subtask is a task_server)
-    void add_multitask(int id_multi, std::vector<int>& sub_ids) {
+    // add a multitask by specifying the id of all subtask (subtask is a server_task)
+    void add_multitask(int id_multi, std::vector<int> & sub_ids) {
         std::lock_guard<std::mutex> lock(mutex_tasks);
-        task_multi multi;
+        server_task_multi multi;
         multi.id = id_multi;
         std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
         queue_multitasks.push_back(multi);
     }
 
     // updatethe remaining subtasks, while appending results to multitask
-    void update_multitask(int id_multi, int id_sub, task_result& result) {
+    void update_multitask(int id_multi, int id_sub, server_task_result & result) {
         std::lock_guard<std::mutex> lock(mutex_tasks);
         for (auto & multitask : queue_multitasks) {
             if (multitask.id == id_multi) {
@@ -527,15 +527,15 @@ struct llama_server_queue {
     }
 };
 
-struct llama_server_response {
-    typedef std::function<void(int, int, task_result&)> callback_multitask_t;
+struct server_response {
+    typedef std::function<void(int, int, server_task_result&)> callback_multitask_t;
     callback_multitask_t callback_update_multitask;
 
     // for keeping track of all tasks waiting for the result
     std::set<int> waiting_task_ids;
 
     // the main result queue
-    std::vector<task_result> queue_results;
+    std::vector<server_task_result> queue_results;
 
     std::mutex mutex_results;
     std::condition_variable condition_results;
@@ -557,7 +557,7 @@ struct llama_server_response {
     }
 
     // This function blocks the thread until there is a response for this id_task
-    task_result recv(int id_task) {
+    server_task_result recv(int id_task) {
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_results);
             condition_results.wait(lock, [&]{
@@ -567,7 +567,7 @@ struct llama_server_response {
             for (int i = 0; i < (int) queue_results.size(); i++) {
                 if (queue_results[i].id == id_task) {
                     assert(queue_results[i].id_multi == -1);
-                    task_result res = queue_results[i];
+                    server_task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
                 }
@@ -583,7 +583,7 @@ struct llama_server_response {
     }
 
     // Send a new result to a waiting id_task
-    void send(task_result result) {
+    void send(server_task_result result) {
         LOG_VERBOSE("send new result", {{"id_task", result.id}});
 
         std::unique_lock<std::mutex> lock(mutex_results);
@@ -606,7 +606,7 @@ struct llama_server_response {
     }
 };
 
-struct llama_server_context {
+struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
 
@@ -632,12 +632,12 @@ struct llama_server_context {
     std::vector<server_slot> slots;
     json default_generation_settings_for_props;
 
-    llama_server_queue queue_tasks;
-    llama_server_response queue_results;
+    server_queue queue_tasks;
+    server_response queue_results;
 
     server_metrics metrics;
 
-    ~llama_server_context() {
+    ~server_context() {
         if (ctx) {
             llama_free(ctx);
             ctx = nullptr;
@@ -956,7 +956,7 @@ struct llama_server_context {
             llama_set_rng_seed(ctx, slot.params.seed);
         }
 
-        slot.command = LOAD_PROMPT;
+        slot.command = SLOT_COMMAND_LOAD_PROMPT;
         slot.prompt_tokens.clear();
 
         LOG_INFO("slot is processing task", {
@@ -1081,7 +1081,7 @@ struct llama_server_context {
             const std::string str_test = slot.generated_text.substr(pos);
             bool is_stop_full = false;
 
-            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_FULL);
+            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
             if (stop_pos != std::string::npos) {
                 is_stop_full = true;
                 slot.generated_text.erase(
@@ -1090,7 +1090,7 @@ struct llama_server_context {
                 pos = std::min(slot.n_sent_text, slot.generated_text.size());
             } else {
                 is_stop_full = false;
-                stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL);
+                stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
             }
 
             // check if there is any token to predict
@@ -1191,10 +1191,10 @@ struct llama_server_context {
         };
     }
 
-    void send_error(const task_server & task, const std::string & error) {
+    void send_error(const server_task & task, const std::string & error) {
         LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
 
-        task_result res;
+        server_task_result res;
         res.id = task.id;
         res.id_multi = task.id_multi;
         res.stop = false;
@@ -1205,7 +1205,7 @@ struct llama_server_context {
     }
 
     void send_partial_response(server_slot & slot, completion_token_output tkn) {
-        task_result res;
+        server_task_result res;
        res.id = slot.id_task;
        res.id_multi = slot.id_multi;
        res.error = false;
@@ -1242,7 +1242,7 @@ struct llama_server_context {
     }
 
     void send_final_response(const server_slot & slot) {
-        task_result res;
+        server_task_result res;
         res.id = slot.id_task;
         res.id_multi = slot.id_multi;
         res.error = false;
@@ -1291,7 +1291,7 @@ struct llama_server_context {
     }
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
-        task_result res;
+        server_task_result res;
         res.id = slot.id_task;
         res.id_multi = slot.id_multi;
         res.error = false;
@@ -1331,14 +1331,14 @@ struct llama_server_context {
     }
 
     void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) {
-        task_server task;
+        server_task task;
         task.id = id_task;
         task.id_multi = id_multi;
         task.id_target = 0;
         task.data = std::move(data);
         task.infill = infill;
         task.embedding = embedding;
-        task.type = TASK_TYPE_COMPLETION;
+        task.type = SERVER_TASK_TYPE_COMPLETION;
 
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         // otherwise, it's a single-prompt task, we actually queue it
@@ -1368,14 +1368,14 @@ struct llama_server_context {
     }
 
     void request_cancel(int id_task) {
-        task_server task;
-        task.type = TASK_TYPE_CANCEL;
+        server_task task;
+        task.type = SERVER_TASK_TYPE_CANCEL;
         task.id_target = id_task;
 
         queue_tasks.post(task);
     }
 
-    void split_multiprompt_task(int id_multi, const task_server & multiprompt_task) {
+    void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) {
         const int prompt_count = multiprompt_task.data.at("prompt").size();
         if (prompt_count <= 1) {
             send_error(multiprompt_task, "error while handling multiple prompts");
@@ -1401,9 +1401,9 @@ struct llama_server_context {
         }
     }
 
-    void process_single_task(const task_server & task) {
+    void process_single_task(const server_task & task) {
         switch (task.type) {
-            case TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_COMPLETION:
                 {
                     server_slot * slot = get_slot(json_value(task.data, "id_slot", -1));
                     if (slot == nullptr) {
@@ -1435,7 +1435,7 @@ struct llama_server_context {
                         break;
                     }
                 } break;
-            case TASK_TYPE_CANCEL:
+            case SERVER_TASK_TYPE_CANCEL:
                {
                    // release slot linked with the task id
                    for (auto & slot : slots) {
@@ -1445,11 +1445,11 @@ struct llama_server_context {
                        }
                    }
                } break;
-            case TASK_TYPE_NEXT_RESPONSE:
+            case SERVER_TASK_TYPE_NEXT_RESPONSE:
                {
                    // do nothing
                } break;
-            case TASK_TYPE_METRICS:
+            case SERVER_TASK_TYPE_METRICS:
                {
                    json slots_data = json::array();
 
@@ -1472,7 +1472,7 @@ struct llama_server_context {
                            {"stopping_word", slot.stopping_word},
                        };
 
-                        if (slot_data["state"] == IDLE) {
+                        if (slot_data["state"] == SLOT_STATE_IDLE) {
                            n_idle_slots++;
                        } else {
                            n_processing_slots++;
@@ -1493,7 +1493,7 @@ struct llama_server_context {
                        {"slots", slots_data}
                    });
 
-                    task_result res;
+                    server_task_result res;
                    res.id = task.id;
                    res.id_multi = task.id_multi;
                    res.stop = true;
@@ -1523,9 +1523,9 @@ struct llama_server_context {
        }
    }
 
-    void on_finish_multitask(const task_multi & multitask) {
+    void on_finish_multitask(const server_task_multi & multitask) {
        // all subtasks done == multitask is done
-        task_result result;
+        server_task_result result;
        result.id = multitask.id;
        result.stop = true;
        result.error = false;
@@ -1550,9 +1550,9 @@ struct llama_server_context {
 
        // release slots
        for (auto & slot : slots) {
-            if (slot.command == RELEASE) {
-                slot.state = IDLE;
-                slot.command = NONE;
+            if (slot.command == SLOT_COMMAND_RELEASE) {
+                slot.state = SLOT_STATE_IDLE;
+                slot.command = SLOT_COMMAND_NONE;
                slot.t_last_used = ggml_time_us();
 
                LOG_INFO("slot released", {
@@ -1574,7 +1574,7 @@ struct llama_server_context {
        bool all_idle = true;
 
        for (auto & slot : slots) {
-            if (slot.state != IDLE || slot.command != NONE) {
+            if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
                all_idle = false;
                break;
            }
@@ -1593,8 +1593,8 @@ struct llama_server_context {
        {
            LOG_VERBOSE("posting NEXT_RESPONSE", {});
 
-            task_server task;
-            task.type = TASK_TYPE_NEXT_RESPONSE;
+            server_task task;
+            task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
            task.id_target = -1;
 
            queue_tasks.post(task);
@@ -1642,7 +1642,7 @@ struct llama_server_context {
 
        // decode any currently ongoing sequences
        for (auto & slot : slots) {
-            if (slot.state == IDLE) {
+            if (slot.state == SLOT_STATE_IDLE) {
                continue;
            }
 
@@ -1681,9 +1681,9 @@ struct llama_server_context {
 
            // empty prompt passed -> release the slot and send empty response
            // note: infill mode allows empty prompt
-            if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill) {
-                slot.state = PROCESSING;
-                slot.command = NONE;
+            if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT && !has_prompt && !slot.infill) {
+                slot.state = SLOT_STATE_PROCESSING;
+                slot.command = SLOT_COMMAND_NONE;
                slot.release();
                slot.print_timings();
                send_final_response(slot);
@@ -1691,7 +1691,7 @@ struct llama_server_context {
            }
 
            // need process the prompt
-            if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
+            if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
                auto & prompt_tokens = slot.prompt_tokens;
 
                if (prompt_tokens.empty()) {
@@ -1734,8 +1734,8 @@ struct llama_server_context {
                if (slot.embedding) {
                    // this prompt is too large to process - discard it
                    if (slot.n_prompt_tokens > n_batch) {
-                        slot.state = PROCESSING;
-                        slot.command = NONE;
+                        slot.state = SLOT_STATE_PROCESSING;
+                        slot.command = SLOT_COMMAND_NONE;
                        slot.release();
                        slot.print_timings();
                        send_final_response(slot);
@@ -1857,8 +1857,8 @@ struct llama_server_context {
 
                // entire prompt has been processed - start decoding new tokens
                if (slot.n_past == slot.n_prompt_tokens) {
-                    slot.state = PROCESSING;
-                    slot.command = NONE;
+                    slot.state = SLOT_STATE_PROCESSING;
+                    slot.command = SLOT_COMMAND_NONE;
 
                    GGML_ASSERT(batch.n_tokens > 0);
 
@@ -1961,7 +1961,7 @@ struct llama_server_context {
        }
 
        for (auto & slot : slots) {
-            if (slot.state != PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
+            if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                continue;
            }
 
@@ -2574,12 +2574,12 @@ int main(int argc, char ** argv) {
    server_params sparams;
 
    // struct that contains llama context and inference
-    llama_server_context llama;
+    server_context ctx_server;
 
    server_params_parse(argc, argv, sparams, params);
 
    if (!sparams.system_prompt.empty()) {
-        llama.system_prompt_set(json::parse(sparams.system_prompt));
+        ctx_server.system_prompt_set(json::parse(sparams.system_prompt));
    }
 
    if (params.model_alias == "unknown") {
@@ -2621,17 +2621,17 @@ int main(int argc, char ** argv) {
            case SERVER_STATE_READY:
                {
                    // request slots data using task queue
-                    task_server task;
-                    task.id = llama.queue_tasks.get_new_id();
-                    task.type = TASK_TYPE_METRICS;
+                    server_task task;
+                    task.id = ctx_server.queue_tasks.get_new_id();
+                    task.type = SERVER_TASK_TYPE_METRICS;
                    task.id_target = -1;
 
-                    llama.queue_results.add_waiting_task_id(task.id);
-                    llama.queue_tasks.post(task);
+                    ctx_server.queue_results.add_waiting_task_id(task.id);
+                    ctx_server.queue_tasks.post(task);
 
                    // get the result
-                    task_result result = llama.queue_results.recv(task.id);
-                    llama.queue_results.remove_waiting_task_id(task.id);
+                    server_task_result result = ctx_server.queue_results.recv(task.id);
+                    ctx_server.queue_results.remove_waiting_task_id(task.id);
 
                    const int n_idle_slots = result.data["idle"];
                    const int n_processing_slots = result.data["processing"];
@@ -2673,18 +2673,18 @@ int main(int argc, char ** argv) {
    if (sparams.slots_endpoint) {
        svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
            // request slots data using task queue
-            task_server task;
-            task.id = llama.queue_tasks.get_new_id();
+            server_task task;
+            task.id = ctx_server.queue_tasks.get_new_id();
            task.id_multi = -1;
            task.id_target = -1;
-            task.type = TASK_TYPE_METRICS;
+            task.type = SERVER_TASK_TYPE_METRICS;
 
-            llama.queue_results.add_waiting_task_id(task.id);
-            llama.queue_tasks.post(task);
+            ctx_server.queue_results.add_waiting_task_id(task.id);
+            ctx_server.queue_tasks.post(task);
 
            // get the result
-            task_result result = llama.queue_results.recv(task.id);
-            llama.queue_results.remove_waiting_task_id(task.id);
+            server_task_result result = ctx_server.queue_results.recv(task.id);
+            ctx_server.queue_results.remove_waiting_task_id(task.id);
 
            res.set_content(result.data["slots"].dump(), "application/json");
            res.status = 200; // HTTP OK
@@ -2694,18 +2694,18 @@ int main(int argc, char ** argv) {
    if (sparams.metrics_endpoint) {
        svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
            // request slots data using task queue
-            task_server task;
-            task.id = llama.queue_tasks.get_new_id();
+            server_task task;
+            task.id = ctx_server.queue_tasks.get_new_id();
            task.id_multi = -1;
            task.id_target = -1;
-            task.type = TASK_TYPE_METRICS;
+            task.type = SERVER_TASK_TYPE_METRICS;
 
-            llama.queue_results.add_waiting_task_id(task.id);
-            llama.queue_tasks.post(task);
+            ctx_server.queue_results.add_waiting_task_id(task.id);
+            ctx_server.queue_tasks.post(task);
 
            // get the result
-            task_result result = llama.queue_results.recv(task.id);
-            llama.queue_results.remove_waiting_task_id(task.id);
+            server_task_result result = ctx_server.queue_results.recv(task.id);
+            ctx_server.queue_results.remove_waiting_task_id(task.id);
 
            json data = result.data;
@@ -2831,20 +2831,20 @@ int main(int argc, char ** argv) {
    }
 
    // load the model
-    if (!llama.load_model(params)) {
+    if (!ctx_server.load_model(params)) {
        state.store(SERVER_STATE_ERROR);
        return 1;
    } else {
-        llama.initialize();
+        ctx_server.initialize();
        state.store(SERVER_STATE_READY);
    }
 
    LOG_INFO("model loaded", {});
 
-    const auto model_meta = llama.model_meta();
+    const auto model_meta = ctx_server.model_meta();
 
    if (sparams.chat_template.empty()) { // custom chat template is not supplied
-        if (!llama.validate_model_chat_template()) {
+        if (!ctx_server.validate_model_chat_template()) {
            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
            sparams.chat_template = "chatml";
        }
    }
@@ -2901,19 +2901,19 @@ int main(int argc, char ** argv) {
        return false;
    });
 
-    svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response & res) {
+    svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        json data = {
-            { "user_name", llama.name_user.c_str() },
-            { "assistant_name", llama.name_assistant.c_str() },
-            { "default_generation_settings", llama.default_generation_settings_for_props },
-            { "total_slots", llama.params.n_parallel }
+            { "user_name", ctx_server.name_user.c_str() },
+            { "assistant_name", ctx_server.name_assistant.c_str() },
+            { "default_generation_settings", ctx_server.default_generation_settings_for_props },
+            { "total_slots", ctx_server.params.n_parallel }
        };
 
        res.set_content(data.dump(), "application/json; charset=utf-8");
    });
 
-    svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    svr.Post("/completion", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        if (!validate_api_key(req, res)) {
            return;
@@ -2921,13 +2921,13 @@ int main(int argc, char ** argv) {
 
        json data = json::parse(req.body);
 
-        const int id_task = llama.queue_tasks.get_new_id();
+        const int id_task = ctx_server.queue_tasks.get_new_id();
 
-        llama.queue_results.add_waiting_task_id(id_task);
-        llama.request_completion(id_task, -1, data, false, false);
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+        ctx_server.request_completion(id_task, -1, data, false, false);
 
        if (!json_value(data, "stream", false)) {
-            task_result result = llama.queue_results.recv(id_task);
+            server_task_result result = ctx_server.queue_results.recv(id_task);
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
            } else {
@@ -2935,11 +2935,11 @@ int main(int argc, char ** argv) {
                res.set_content(result.data["content"], "text/plain; charset=utf-8");
            }
 
-            llama.queue_results.remove_waiting_task_id(id_task);
+            ctx_server.queue_results.remove_waiting_task_id(id_task);
        } else {
-            const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
                while (true) {
-                    task_result result = llama.queue_results.recv(id_task);
+                    server_task_result result = ctx_server.queue_results.recv(id_task);
                    if (!result.error) {
                        const std::string str =
                            "data: " +
@@ -2951,7 +2951,7 @@ int main(int argc, char ** argv) {
                            });
 
                        if (!sink.write(str.c_str(), str.size())) {
-                            llama.queue_results.remove_waiting_task_id(id_task);
+                            ctx_server.queue_results.remove_waiting_task_id(id_task);
                            return false;
                        }
 
@@ -2969,7 +2969,7 @@ int main(int argc, char ** argv) {
                            });
 
                        if (!sink.write(str.c_str(), str.size())) {
-                            llama.queue_results.remove_waiting_task_id(id_task);
+                            ctx_server.queue_results.remove_waiting_task_id(id_task);
                            return false;
                        }
 
@@ -2977,16 +2977,16 @@ int main(int argc, char ** argv) {
                    }
                }
 
-                llama.queue_results.remove_waiting_task_id(id_task);
+                ctx_server.queue_results.remove_waiting_task_id(id_task);
                sink.done();
 
                return true;
            };
 
-            auto on_complete = [id_task, &llama] (bool) {
+            auto on_complete = [id_task, &ctx_server] (bool) {
                // cancel
-                llama.request_cancel(id_task);
-                llama.queue_results.remove_waiting_task_id(id_task);
+                ctx_server.request_cancel(id_task);
+                ctx_server.queue_results.remove_waiting_task_id(id_task);
            };
 
            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -3012,21 +3012,21 @@ int main(int argc, char ** argv) {
        res.set_content(models.dump(), "application/json; charset=utf-8");
    });
 
-    const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
+    const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        if (!validate_api_key(req, res)) {
            return;
        }
 
-        json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
+        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
 
-        const int id_task = llama.queue_tasks.get_new_id();
+        const int id_task = ctx_server.queue_tasks.get_new_id();
 
-        llama.queue_results.add_waiting_task_id(id_task);
-        llama.request_completion(id_task, -1, data, false, false);
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+        ctx_server.request_completion(id_task, -1, data, false, false);
 
        if (!json_value(data, "stream", false)) {
-            task_result result = llama.queue_results.recv(id_task);
+            server_task_result result = ctx_server.queue_results.recv(id_task);
 
            if (!result.error && result.stop) {
                json result_oai = format_final_response_oaicompat(data, result.data);
@@ -3036,11 +3036,11 @@ int main(int argc, char ** argv) {
                res.status = 500;
                res.set_content(result.data["content"], "text/plain; charset=utf-8");
            }
-            llama.queue_results.remove_waiting_task_id(id_task);
+            ctx_server.queue_results.remove_waiting_task_id(id_task);
        } else {
-            const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
                while (true) {
-                    task_result result = llama.queue_results.recv(id_task);
+                    server_task_result result = ctx_server.queue_results.recv(id_task);
                    if (!result.error) {
                        std::vector<json> result_array = format_partial_response_oaicompat(result.data);
@@ -3052,7 +3052,7 @@ int main(int argc, char ** argv) {
                                "\n\n";
                            LOG_VERBOSE("data stream", {{"to_send", str}});
                            if (!sink.write(str.c_str(), str.size())) {
-                                llama.queue_results.remove_waiting_task_id(id_task);
+                                ctx_server.queue_results.remove_waiting_task_id(id_task);
                                return false;
                            }
                        }
@@ -3067,21 +3067,21 @@ int main(int argc, char ** argv) {
                            "\n\n";
                        LOG_VERBOSE("data stream", {{"to_send", str}});
                        if (!sink.write(str.c_str(), str.size())) {
-                            llama.queue_results.remove_waiting_task_id(id_task);
+                            ctx_server.queue_results.remove_waiting_task_id(id_task);
                            return false;
                        }
                        break;
                    }
                }
                sink.done();
-                llama.queue_results.remove_waiting_task_id(id_task);
+                ctx_server.queue_results.remove_waiting_task_id(id_task);
                return true;
            };
 
-            auto on_complete = [id_task, &llama](bool) {
+            auto on_complete = [id_task, &ctx_server](bool) {
                // cancel request
-                llama.request_cancel(id_task);
-                llama.queue_results.remove_waiting_task_id(id_task);
+                ctx_server.request_cancel(id_task);
+                ctx_server.queue_results.remove_waiting_task_id(id_task);
            };
 
            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -3091,7 +3091,7 @@ int main(int argc, char ** argv) {
svr.Post("/chat/completions", chat_completions); svr.Post("/v1/chat/completions", chat_completions); - svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request & req, httplib::Response & res) { + svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3099,13 +3099,13 @@ int main(int argc, char ** argv) { json data = json::parse(req.body); - const int id_task = llama.queue_tasks.get_new_id(); + const int id_task = ctx_server.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, data, true, false); + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, data, true, false); if (!json_value(data, "stream", false)) { - task_result result = llama.queue_results.recv(id_task); + server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error && result.stop) { res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); } else { @@ -3113,11 +3113,11 @@ int main(int argc, char ** argv) { res.set_content(result.data["content"], "text/plain; charset=utf-8"); } - llama.queue_results.remove_waiting_task_id(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); } else { - const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink & sink) { + const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) { while (true) { - task_result result = llama.queue_results.recv(id_task); + server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error) { const std::string str = "data: " + @@ -3129,7 +3129,7 @@ int main(int argc, char ** argv) { }); if (!sink.write(str.c_str(), str.size())) { - llama.queue_results.remove_waiting_task_id(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); return false; } @@ -3141,14 +3141,14 @@ int main(int argc, char ** argv) { } } - llama.queue_results.remove_waiting_task_id(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); sink.done(); return true; }; - auto on_complete = [id_task, &llama] (bool) { - llama.request_cancel(id_task); + auto on_complete = [id_task, &ctx_server] (bool) { + ctx_server.request_cancel(id_task); }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); @@ -3159,33 +3159,33 @@ int main(int argc, char ** argv) { return res.set_content("", "application/json; charset=utf-8"); }); - svr.Post("/tokenize", [&llama](const httplib::Request & req, httplib::Response & res) { + svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; if (body.count("content") != 0) { - tokens = llama.tokenize(body["content"], false); + tokens = ctx_server.tokenize(body["content"], false); } const json data = format_tokenizer_response(tokens); return res.set_content(data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/detokenize", [&llama](const httplib::Request & req, httplib::Response & res) { + svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", 
req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; if (body.count("tokens") != 0) { const std::vector tokens = body["tokens"]; - content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend()); + content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); } const json data = format_detokenized_response(content); return res.set_content(data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/embedding", [¶ms, &llama](const httplib::Request & req, httplib::Response & res) { + svr.Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!params.embedding) { res.status = 501; @@ -3203,20 +3203,20 @@ int main(int argc, char ** argv) { } // create and queue the task - const int id_task = llama.queue_tasks.get_new_id(); + const int id_task = ctx_server.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true); + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true); // get the result - task_result result = llama.queue_results.recv(id_task); - llama.queue_results.remove_waiting_task_id(id_task); + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); // send the result return res.set_content(result.data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/v1/embeddings", [¶ms, &llama](const httplib::Request & req, httplib::Response & res) { + svr.Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!params.embedding) { res.status = 501; @@ -3234,14 +3234,14 @@ int main(int argc, char ** argv) { int i = 0; for (const json & elem : prompt) { - const int id_task = llama.queue_tasks.get_new_id(); + const int id_task = ctx_server.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true); + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true); // get the result - task_result result = llama.queue_results.recv(id_task); - llama.queue_results.remove_waiting_task_id(id_task); + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); json embedding = json{ {"embedding", json_value(result.data, "embedding", json::array())}, @@ -3261,14 +3261,14 @@ int main(int argc, char ** argv) { } // create and queue the task - const int id_task = llama.queue_tasks.get_new_id(); + const int id_task = ctx_server.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true); + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true); // get the result - task_result result = llama.queue_results.recv(id_task); - llama.queue_results.remove_waiting_task_id(id_task); + server_task_result result = ctx_server.queue_results.recv(id_task); 
+            ctx_server.queue_results.remove_waiting_task_id(id_task);
 
            json data = json::array({json{
                {"embedding", json_value(result.data, "embedding", json::array())},
@@ -3301,22 +3301,22 @@ int main(int argc, char ** argv) {
        return 0;
    });
 
-    llama.queue_tasks.on_new_task(std::bind(
-        &llama_server_context::process_single_task, &llama, std::placeholders::_1));
-    llama.queue_tasks.on_finish_multitask(std::bind(
-        &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
-    llama.queue_tasks.on_run_slots(std::bind(
-        &llama_server_context::update_slots, &llama));
-    llama.queue_results.on_multitask_update(std::bind(
-        &llama_server_queue::update_multitask,
-        &llama.queue_tasks,
+    ctx_server.queue_tasks.on_new_task(std::bind(
+        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+    ctx_server.queue_tasks.on_finish_multitask(std::bind(
+        &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
+    ctx_server.queue_tasks.on_run_slots(std::bind(
+        &server_context::update_slots, &ctx_server));
+    ctx_server.queue_results.on_multitask_update(std::bind(
+        &server_queue::update_multitask,
+        &ctx_server.queue_tasks,
        std::placeholders::_1,
        std::placeholders::_2,
        std::placeholders::_3
    ));
 
    shutdown_handler = [&](int) {
-        llama.queue_tasks.terminate();
+        ctx_server.queue_tasks.terminate();
    };
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
@@ -3331,7 +3331,9 @@ int main(int argc, char ** argv) {
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif
-    llama.queue_tasks.start_loop();
+
+    ctx_server.queue_tasks.start_loop();
+
    svr.stop();
    t.join();
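
Illustrative note (not part of the patch): the rename does not change the task-queue flow itself. Below is a minimal sketch of how the renamed types are driven, assuming an initialized server_context named ctx_server as set up in main(); the helper name run_metrics_task is hypothetical and only mirrors the post/recv pattern used by the /slots and /metrics handlers above.

    // Post a metrics task and block for its result, following the handler pattern in this patch.
    static json run_metrics_task(server_context & ctx_server) {
        server_task task;
        task.id        = ctx_server.queue_tasks.get_new_id();  // allocate a fresh task id
        task.id_target = -1;
        task.type      = SERVER_TASK_TYPE_METRICS;

        ctx_server.queue_results.add_waiting_task_id(task.id); // register interest in the result
        ctx_server.queue_tasks.post(task);                     // hand off to the processing loop

        // blocks until server_context::process_single_task() sends a matching server_task_result
        server_task_result result = ctx_server.queue_results.recv(task.id);
        ctx_server.queue_results.remove_waiting_task_id(task.id);

        return result.data;
    }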