From 12829b2e64afdf113f627064a6d57d3f64b31e19 Mon Sep 17 00:00:00 2001 From: ngxson Date: Sun, 21 Jan 2024 14:45:28 +0100 Subject: [PATCH] server: add llama_server_response_event --- Makefile | 2 +- examples/server/CMakeLists.txt | 2 +- examples/server/oai.hpp | 14 ++++ examples/server/server.cpp | 125 ++++++++++++--------------------- examples/server/utils.hpp | 104 +++++++++++++++++++++++++-- 5 files changed, 161 insertions(+), 86 deletions(-) create mode 100644 examples/server/oai.hpp diff --git a/Makefile b/Makefile index f7f24ba9e..f9a933b25 100644 --- a/Makefile +++ b/Makefile @@ -619,7 +619,7 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) +server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -Wno-cast-qual gguf: examples/gguf/gguf.cpp ggml.o $(OBJS) diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index b3772081f..cc13b2d63 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -1,7 +1,7 @@ set(TARGET server) option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) -add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h) +add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h) install(TARGETS ${TARGET} RUNTIME) target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$ diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp new file mode 100644 index 000000000..b06b7efb4 --- /dev/null +++ b/examples/server/oai.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "json.hpp" +#include "utils.hpp" + +using json = nlohmann::json; + diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c79ef7915..a853d4a69 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -43,7 +43,7 @@ struct server_params int32_t write_timeout = 600; }; -static bool server_verbose = false; +bool server_verbose = false; json oaicompat_completion_params_parse(const json &body); std::string format_chatml(std::vector messages); @@ -279,7 +279,7 @@ struct llama_client_slot } void release() { - if (state == IDLE || state == PROCESSING) + if (state == PROCESSING) { t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; command = RELEASE; @@ -342,12 +342,9 @@ struct llama_server_context std::vector slots; llama_server_queue queue_tasks; - std::vector queue_results; + llama_server_response_event queue_results; std::vector queue_multitasks; - std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks - // std::condition_variable condition_tasks; - std::mutex 
mutex_results; - std::condition_variable condition_results; + std::mutex mutex_multitasks; ~llama_server_context() { @@ -968,20 +965,18 @@ struct llama_server_context void send_error(task_server& task, const std::string &error) { LOG_TEE("task %i - error: %s\n", task.id, error.c_str()); - std::unique_lock lock(mutex_results); task_result res; res.id = task.id; res.multitask_id = task.multitask_id; res.stop = false; res.error = true; res.result_json = { { "content", error } }; - queue_results.push_back(res); - condition_results.notify_all(); + queue_results.send(res); } - void add_multi_task(int id, std::vector& sub_ids) + void add_multitask(int id, std::vector& sub_ids) { - std::lock_guard lock(mutex_tasks); + std::lock_guard lock(mutex_multitasks); task_multi multi; multi.id = id; std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); @@ -989,9 +984,9 @@ struct llama_server_context // TODO @ngxson : Do we need to notify the queue_tasks? } - void update_multi_task(int multitask_id, int subtask_id, task_result& result) + void update_multitask(int multitask_id, int subtask_id, task_result& result) { - std::lock_guard lock(mutex_tasks); + std::lock_guard lock(mutex_multitasks); for (auto& multitask : queue_multitasks) { if (multitask.id == multitask_id) @@ -1046,7 +1041,6 @@ struct llama_server_context void send_partial_response(llama_client_slot &slot, completion_token_output tkn) { - std::unique_lock lock(mutex_results); task_result res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -1081,13 +1075,11 @@ struct llama_server_context res.result_json["model"] = slot.oaicompat_model; } - queue_results.push_back(res); - condition_results.notify_all(); + queue_results.send(res); } void send_final_response(llama_client_slot &slot) { - std::unique_lock lock(mutex_results); task_result res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -1136,22 +1128,17 @@ struct llama_server_context res.result_json["model"] = slot.oaicompat_model; } - queue_results.push_back(res); - condition_results.notify_all(); - - // done with results, unlock - lock.unlock(); + queue_results.send(res); // parent multitask, if any, needs to be updated if (slot.multitask_id != -1) { - update_multi_task(slot.multitask_id, slot.task_id, res); + update_multitask(slot.multitask_id, slot.task_id, res); } } void send_embedding(llama_client_slot &slot) { - std::unique_lock lock(mutex_results); task_result res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -1178,13 +1165,12 @@ struct llama_server_context {"embedding", embedding }, }; } - queue_results.push_back(res); - condition_results.notify_all(); + queue_results.send(res); } int request_completion(json data, bool infill, bool embedding, int multitask_id) { - std::unique_lock lock(mutex_tasks); + std::unique_lock lock(mutex_multitasks); task_server task; task.target_id = 0; task.data = std::move(data); @@ -1204,40 +1190,6 @@ struct llama_server_context return queue_tasks.post(task); } - task_result next_result(int task_id) - { - LOG_TEE("next_result %i \n", task_id); - while (true) - { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - return !queue_results.empty(); - }); - - for (int i = 0; i < (int) queue_results.size(); i++) - { - // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result - if (queue_results[i].multitask_id == task_id) - { - update_multi_task(task_id, queue_results[i].id, 
queue_results[i]); - queue_results.erase(queue_results.begin() + i); - continue; - } - - if (queue_results[i].id == task_id) - { - assert(queue_results[i].multitask_id == -1); - task_result res = queue_results[i]; - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // never reached - //return task_result{-1, false, false, {}}; - } - // for multiple images processing bool ingest_images(llama_client_slot &slot, int n_batch) { @@ -1331,7 +1283,7 @@ struct llama_server_context } // queue up the multitask so we can track its subtask progression - add_multi_task(multitask_id, subtask_ids); + add_multitask(multitask_id, subtask_ids); return multitask_id; } @@ -1344,8 +1296,8 @@ struct llama_server_context if (slot == nullptr) { // if no slot is available, we defer this task for processing later - // deferred_tasks.push_back(task); - LOG_INFO("no slot", {}); + LOG_TEE("no slot\n"); + queue_tasks.defer(task); break; } @@ -1417,12 +1369,7 @@ struct llama_server_context aggregate_result.error = aggregate_result.error && subres.error; } aggregate_result.result_json = json{ "results", result_jsons }; - - agg_results.push_back(aggregate_result); - - condition_results.notify_all(); - queue_iterator = queue_multitasks.erase(queue_iterator); } else @@ -1432,8 +1379,9 @@ struct llama_server_context } // copy aggregate results of complete multi-tasks to the results queue - std::lock_guard lock_results(mutex_results); - queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end()); + for (auto& res : agg_results) { + queue_results.send(res); + } } bool update_slots() { @@ -2845,9 +2793,10 @@ int main(int argc, char **argv) } json data = json::parse(req.body); const int task_id = llama.request_completion(data, false, false, -1); + llama.queue_results.add_waiting_task_id(task_id); if (!json_value(data, "stream", false)) { std::string completion_text; - task_result result = llama.next_result(task_id); + task_result result = llama.queue_results.recv(task_id); if (!result.error && result.stop) { res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); } @@ -2855,6 +2804,7 @@ int main(int argc, char **argv) { res.status = 404; res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); + llama.queue_results.remove_waiting_task_id(task_id); return; } } else { @@ -2862,7 +2812,7 @@ int main(int argc, char **argv) { while (true) { - task_result result = llama.next_result(task_id); + task_result result = llama.queue_results.recv(task_id); if (!result.error) { const std::string str = "data: " + @@ -2873,6 +2823,7 @@ int main(int argc, char **argv) }); if (!sink.write(str.c_str(), str.size())) { + llama.queue_results.remove_waiting_task_id(task_id); return false; } if (result.stop) { @@ -2888,12 +2839,14 @@ int main(int argc, char **argv) }); if (!sink.write(str.c_str(), str.size())) { + llama.queue_results.remove_waiting_task_id(task_id); return false; } break; } } sink.done(); + llama.queue_results.remove_waiting_task_id(task_id); return true; }; @@ -2901,6 +2854,7 @@ int main(int argc, char **argv) { // cancel llama.request_cancel(task_id); + llama.queue_results.remove_waiting_task_id(task_id); }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); @@ -2938,10 +2892,11 @@ int main(int argc, char **argv) json data = oaicompat_completion_params_parse(json::parse(req.body)); const int task_id = llama.request_completion(data, false, false, -1); + 
llama.queue_results.add_waiting_task_id(task_id); if (!json_value(data, "stream", false)) { std::string completion_text; - task_result result = llama.next_result(task_id); + task_result result = llama.queue_results.recv(task_id); if (!result.error && result.stop) { json oaicompat_result = format_final_response_oaicompat(data, result); @@ -2952,12 +2907,13 @@ int main(int argc, char **argv) } else { res.status = 500; res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); + llama.queue_results.remove_waiting_task_id(task_id); return; } } else { const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) { while (true) { - task_result llama_result = llama.next_result(task_id); + task_result llama_result = llama.queue_results.recv(task_id); if (!llama_result.error) { std::vector result_array = format_partial_response_oaicompat( llama_result); @@ -2970,6 +2926,7 @@ int main(int argc, char **argv) "\n\n"; LOG_VERBOSE("data stream", {{"to_send", str}}); if (!sink.write(str.c_str(), str.size())) { + llama.queue_results.remove_waiting_task_id(task_id); return false; } } @@ -2985,18 +2942,21 @@ int main(int argc, char **argv) "\n\n"; LOG_VERBOSE("data stream", {{"to_send", str}}); if (!sink.write(str.c_str(), str.size())) { + llama.queue_results.remove_waiting_task_id(task_id); return false; } break; } } sink.done(); + llama.queue_results.remove_waiting_task_id(task_id); return true; }; auto on_complete = [task_id, &llama](bool) { // cancel request llama.request_cancel(task_id); + llama.queue_results.remove_waiting_task_id(task_id); }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); @@ -3013,7 +2973,7 @@ int main(int argc, char **argv) const int task_id = llama.request_completion(data, true, false, -1); if (!json_value(data, "stream", false)) { std::string completion_text; - task_result result = llama.next_result(task_id); + task_result result = llama.queue_results.recv(task_id); if (!result.error && result.stop) { res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); @@ -3028,7 +2988,7 @@ int main(int argc, char **argv) const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { while (true) { - task_result result = llama.next_result(task_id); + task_result result = llama.queue_results.recv(task_id); if (!result.error) { const std::string str = "data: " + @@ -3128,7 +3088,7 @@ int main(int argc, char **argv) } const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1); - task_result result = llama.next_result(task_id); + task_result result = llama.queue_results.recv(task_id); return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); }); @@ -3149,6 +3109,13 @@ int main(int argc, char **argv) llama.queue_tasks.on_all_tasks_finished(std::bind( &llama_server_context::run_on_all_tasks_finished, &llama)); llama.queue_tasks.start_loop(); + llama.queue_results.on_multitask_update(std::bind( + &llama_server_context::update_multitask, + &llama, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3 + )); t.join(); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index cf1925dde..a38698762 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "json.hpp" @@ -12,6 +13,8 @@ using json = nlohmann::json; 
+extern bool server_verbose; + #ifndef SERVER_VERBOSE #define SERVER_VERBOSE 1 #endif @@ -160,12 +163,12 @@ struct llama_server_queue { int id = 0; std::mutex mutex_tasks; std::vector queue_tasks; + std::vector queue_tasks_deferred; std::condition_variable condition_tasks; std::function callback_new_task; std::function callback_all_task_finished; int post(T task) { - LOG_INFO("post", {}); std::unique_lock lock(mutex_tasks); task.id = id++; queue_tasks.push_back(std::move(task)); @@ -173,6 +176,11 @@ struct llama_server_queue { return task.id; } + void defer(T task) { + std::unique_lock lock(mutex_tasks); + queue_tasks_deferred.push_back(std::move(task)); + } + int get_next_id() { std::unique_lock lock(mutex_tasks); return id++; @@ -189,7 +197,7 @@ struct llama_server_queue { void start_loop() { while (true) { // new task arrived - LOG_INFO("have new task", {}); + LOG_VERBOSE("have new task", {}); { while (true) { @@ -201,13 +209,27 @@ struct llama_server_queue { task_server task = queue_tasks.front(); queue_tasks.erase(queue_tasks.begin()); lock.unlock(); - LOG_INFO("callback_new_task", {}); + LOG_VERBOSE("callback_new_task", {}); callback_new_task(task); } - LOG_INFO("callback_all_task_finished", {}); + // move deferred tasks back to main loop + { + std::unique_lock lock(mutex_tasks); + //queue_tasks.insert( + // queue_tasks.end(), + // std::make_move_iterator(queue_tasks_deferred.begin()), + // std::make_move_iterator(queue_tasks_deferred.end()) + //); + for (auto & task : queue_tasks_deferred) { + queue_tasks.push_back(task); + } + queue_tasks_deferred.clear(); + lock.unlock(); + } + LOG_VERBOSE("callback_all_task_finished", {}); callback_all_task_finished(); } - LOG_INFO("wait for new task", {}); + LOG_VERBOSE("wait for new task", {}); // wait for new task { std::unique_lock lock(mutex_tasks); @@ -221,6 +243,78 @@ struct llama_server_queue { } }; +struct llama_server_response_event { + typedef std::function callback_multitask_t; + std::vector queue_results; + std::mutex mutex_task_ids; + std::set waiting_task_ids; + std::mutex mutex_results; + std::condition_variable condition_results; + callback_multitask_t callback_update_multitask; + + void add_waiting_task_id(int task_id) { + std::unique_lock lock(mutex_task_ids); + waiting_task_ids.insert(task_id); + } + + void remove_waiting_task_id(int task_id) { + std::unique_lock lock(mutex_task_ids); + waiting_task_ids.erase(task_id); + } + + task_result recv(int task_id) { + while (true) + { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + return !queue_results.empty(); + }); + LOG_VERBOSE("condition_results unblock", {}); + + for (int i = 0; i < (int) queue_results.size(); i++) + { + if (queue_results[i].id == task_id) + { + assert(queue_results[i].multitask_id == -1); + task_result res = queue_results[i]; + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + + void on_multitask_update(callback_multitask_t callback) { + callback_update_multitask = callback; + } + + void send(task_result result) { + std::unique_lock lock(mutex_results); + std::unique_lock lock1(mutex_task_ids); + LOG_VERBOSE("send new result", {}); + for (auto& task_id : waiting_task_ids) { + LOG_TEE("waiting task id %i \n", task_id); + // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result + if (result.multitask_id == task_id) + { + LOG_VERBOSE("callback_update_multitask", {}); + callback_update_multitask(task_id, result.id, result); + 
continue; + } + + if (result.id == task_id) + { + LOG_VERBOSE("queue_results.push_back", {}); + queue_results.push_back(result); + condition_results.notify_one(); + return; + } + } + } +}; + // // base64 utils (TODO: move to common in the future) //
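Note on the result-passing flow this patch introduces (not part of the patch itself): with llama_server_response_event, an HTTP handler must register its task id before it starts waiting for results and must unregister it on every exit path, otherwise send() keeps trying to match results against a dead id. The sketch below only illustrates the calling convention using identifiers that appear in the diff above; the handler body, streaming logic, and error handling are elided and illustrative, not the exact server.cpp code.

    // assumes `llama` is the llama_server_context instance from server.cpp (illustrative fragment)
    const int task_id = llama.request_completion(data, false, false, -1);
    llama.queue_results.add_waiting_task_id(task_id);     // register before waiting, so send() can match this id

    while (true) {
        // blocks on condition_results until send() queues a result whose id == task_id
        task_result result = llama.queue_results.recv(task_id);
        if (result.error) {
            break;                                         // error result: stop waiting
        }
        // ... forward result.result_json to the client (streaming case) ...
        if (result.stop) {
            break;                                         // final result for this task
        }
    }

    llama.queue_results.remove_waiting_task_id(task_id);  // always unregister, including on early returns

Results that belong to a multitask never reach recv(): send() hands them to the callback registered via on_multitask_update() (bound to llama_server_context::update_multitask in main()), and only plain results with multitask_id == -1 are pushed onto queue_results.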