From 6e29f4c7257c37b483d88291b03300ad8c776dad Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sat, 20 Jan 2024 00:25:20 +0100
Subject: [PATCH] server: add llama_server_queue struct

---
 Makefile                       |   2 +-
 examples/server/CMakeLists.txt |   2 +-
 examples/server/server.cpp     | 367 +++++++--------------------------
 examples/server/utils.hpp      | 296 ++++++++++++++++++++++++++
 4 files changed, 372 insertions(+), 295 deletions(-)
 create mode 100644 examples/server/utils.hpp

diff --git a/Makefile b/Makefile
index a8658a596..f7f24ba9e 100644
--- a/Makefile
+++ b/Makefile
@@ -619,7 +619,7 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -Wno-cast-qual
 
 gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)

diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index 81709e448..b3772081f 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -1,7 +1,7 @@
 set(TARGET server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-add_executable(${TARGET} server.cpp json.hpp httplib.h)
+add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0462fbd24..c79ef7915 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 #include "grammar-parser.h"
+#include "utils.hpp"
 
 #include "../llava/clip.h"
 
@@ -28,10 +29,6 @@
 #include 
 #include 
 
-#ifndef SERVER_VERBOSE
-#define SERVER_VERBOSE 1
-#endif
-
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
 using json = nlohmann::json;
@@ -48,196 +45,9 @@ struct server_params
 
 static bool server_verbose = false;
 
-#if SERVER_VERBOSE != 1
-#define LOG_VERBOSE(MSG, ...)
-#else
-#define LOG_VERBOSE(MSG, ...)                                            \
-    do                                                                   \
-    {                                                                    \
-        if (server_verbose)                                              \
-        {                                                                \
-            server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
-        }                                                                \
-    } while (0)
-#endif
-
-#define LOG_ERROR(  MSG, ...) server_log("ERROR",   __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_INFO(   MSG, ...) server_log("INFO",    __func__, __LINE__, MSG, __VA_ARGS__)
-
 json oaicompat_completion_params_parse(const json &body);
 std::string format_chatml(std::vector<json> messages);
-
-//
-// base64 utils (TODO: move to common in the future)
-//
-
-static const std::string base64_chars =
-    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
-    "abcdefghijklmnopqrstuvwxyz"
-    "0123456789+/";
-
-static inline bool is_base64(uint8_t c)
-{
-    return (isalnum(c) || (c == '+') || (c == '/'));
-}
-
-static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
-{
-    int i = 0;
-    int j = 0;
-    int in_ = 0;
-
-    int in_len = encoded_string.size();
-
-    uint8_t char_array_4[4];
-    uint8_t char_array_3[3];
-
-    std::vector<uint8_t> ret;
-
-    while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
-    {
-        char_array_4[i++] = encoded_string[in_]; in_++;
-        if (i == 4)
-        {
-            for (i = 0; i <4; i++)
-            {
-                char_array_4[i] = base64_chars.find(char_array_4[i]);
-            }
-
-            char_array_3[0] = ((char_array_4[0]      ) << 2) + ((char_array_4[1] & 0x30) >> 4);
-            char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
-            char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
-
-            for (i = 0; (i < 3); i++)
-            {
-                ret.push_back(char_array_3[i]);
-            }
-            i = 0;
-        }
-    }
-
-    if (i)
-    {
-        for (j = i; j <4; j++)
-        {
-            char_array_4[j] = 0;
-        }
-
-        for (j = 0; j <4; j++)
-        {
-            char_array_4[j] = base64_chars.find(char_array_4[j]);
-        }
-
-        char_array_3[0] = ((char_array_4[0]      ) << 2) + ((char_array_4[1] & 0x30) >> 4);
-        char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
-        char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
-
-        for (j = 0; (j < i - 1); j++)
-        {
-            ret.push_back(char_array_3[j]);
-        }
-    }
-
-    return ret;
-}
-
-//
-// parallel
-//
-
-enum server_state {
-    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
-    SERVER_STATE_READY,          // Server is ready and model is loaded
-    SERVER_STATE_ERROR           // An error occurred, load_model failed
-};
-
-enum task_type {
-    TASK_TYPE_COMPLETION,
-    TASK_TYPE_CANCEL,
-};
-
-struct task_server {
-    int id;
-    int target_id;
-    task_type type;
-    json data;
-    bool infill_mode = false;
-    bool embedding_mode = false;
-    int multitask_id = -1;
-};
-
-struct task_result {
-    int id;
-    int multitask_id = -1;
-    bool stop;
-    bool error;
-    json result_json;
-};
-
-struct task_multi {
-    int id;
-    std::set<int> subtasks_remaining{};
-    std::vector<task_result> results{};
-};
-
-// TODO: can become bool if we can't find use of more states
-enum slot_state
-{
-    IDLE,
-    PROCESSING,
-};
-
-enum slot_command
-{
-    NONE,
-    LOAD_PROMPT,
-    RELEASE,
-};
-
-struct slot_params
-{
-    bool stream = true;
-    bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
-
-    uint32_t seed = -1; // RNG seed
-    int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_predict = -1; // new tokens to predict
-
-    std::vector<std::string> antiprompt;
-
-    json input_prefix;
-    json input_suffix;
-};
-
-struct slot_image
-{
-    int32_t id;
-
-    bool request_encode_image = false;
-    float * image_embedding = nullptr;
-    int32_t image_tokens = 0;
-
-    clip_image_u8 * img_data;
-
-    std::string prefix_prompt; // before of this image
-};
-
-// completion token output with probabilities
-struct completion_token_output
-{
-    struct token_prob
-    {
-        llama_token tok;
-        float prob;
-    };
-
-    std::vector<token_prob> probs;
-    llama_token tok;
-    std::string text_to_send;
-};
-
 static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
 {
     size_t i;
@@ -292,28 +102,6 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
     return ret;
 }
 
-static void server_log(const char *level, const char *function, int line,
-                       const char *message, const nlohmann::ordered_json &extra)
-{
-    nlohmann::ordered_json log
-    {
-        {"timestamp", time(nullptr)},
-        {"level",     level},
-        {"function",  function},
-        {"line",      line},
-        {"message",   message},
-    };
-
-    if (!extra.empty())
-    {
-        log.merge_patch(extra);
-    }
-
-    const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
-    printf("%.*s\n", (int)str.size(), str.data());
-    fflush(stdout);
-}
-
 // format incomplete utf-8 multibyte character for output
 static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
 {
@@ -539,7 +327,6 @@ struct llama_server_context
     bool all_slots_are_idle = false;
     bool add_bos_token = true;
 
-    int32_t id_gen;
     int32_t n_ctx; // total context for all clients / slots
 
     // system prompt
@@ -554,11 +341,11 @@ struct llama_server_context
     // slots / clients
     std::vector<llama_client_slot> slots;
 
-    std::vector<task_server> queue_tasks;
+    llama_server_queue<task_server> queue_tasks;
     std::vector<task_result> queue_results;
     std::vector<task_multi> queue_multitasks;
     std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
-    std::condition_variable condition_tasks;
+    // std::condition_variable condition_tasks;
     std::mutex mutex_results;
     std::condition_variable condition_results;
@@ -619,8 +406,6 @@ struct llama_server_context
     }
 
     void initialize() {
-        id_gen = 0;
-
         // create slots
         all_slots_are_idle = true;
@@ -1201,7 +986,7 @@ struct llama_server_context
         multi.id = id;
         std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
         queue_multitasks.push_back(multi);
-        condition_tasks.notify_one();
+        // TODO @ngxson : Do we need to notify the queue_tasks?
     }
 
     void update_multi_task(int multitask_id, int subtask_id, task_result& result)
@@ -1213,7 +998,7 @@
         {
             multitask.subtasks_remaining.erase(subtask_id);
             multitask.results.push_back(result);
-            condition_tasks.notify_one();
+            // TODO @ngxson : Do we need to notify the queue_tasks?
             }
         }
     }
 
@@ -1401,7 +1186,6 @@
     {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
-        task.id = id_gen++;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
@@ -1417,13 +1201,12 @@
         }
 
         // otherwise, it's a single-prompt task, we actually queue it
-        queue_tasks.push_back(task);
-        condition_tasks.notify_one();
-        return task.id;
+        return queue_tasks.post(task);
     }
 
     task_result next_result(int task_id)
     {
+        LOG_TEE("next_result %i \n", task_id);
         while (true)
         {
             std::unique_lock<std::mutex> lock(mutex_results);
@@ -1525,13 +1308,10 @@
 
     void request_cancel(int task_id)
     {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
-        task.id = id_gen++;
         task.type = TASK_TYPE_CANCEL;
         task.target_id = task_id;
-        queue_tasks.push_back(task);
-        condition_tasks.notify_one();
+        queue_tasks.post(task);
     }
 
     int split_multiprompt_task(task_server& multiprompt_task)
@@ -1539,7 +1319,7 @@
         int prompt_count = multiprompt_task.data.at("prompt").size();
         assert(prompt_count > 1);
 
-        int multitask_id = id_gen++;
+        int multitask_id = queue_tasks.get_next_id();
         std::vector<int> subtask_ids(prompt_count);
         for (int i = 0; i < prompt_count; i++)
         {
@@ -1555,73 +1335,67 @@
         return multitask_id;
     }
 
-    void process_tasks()
+    void process_single_task(task_server task)
     {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        std::vector<task_server> deferred_tasks;
-        while (!queue_tasks.empty())
+        switch (task.type)
         {
-            task_server task = queue_tasks.front();
-            queue_tasks.erase(queue_tasks.begin());
-            switch (task.type)
-            {
-                case TASK_TYPE_COMPLETION: {
-                    llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
-                    if (slot == nullptr)
-                    {
-                        // if no slot is available, we defer this task for processing later
-                        deferred_tasks.push_back(task);
+            case TASK_TYPE_COMPLETION: {
+                llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
+                if (slot == nullptr)
+                {
+                    // if no slot is available, we defer this task for processing later
+                    // deferred_tasks.push_back(task);
+                    LOG_INFO("no slot", {});
+                    break;
+                }
+
+                if (task.data.contains("system_prompt"))
+                {
+                    if (!all_slots_are_idle) {
+                        send_error(task, "system prompt can only be updated when all slots are idle");
                         break;
                     }
+                    process_system_prompt_data(task.data["system_prompt"]);
 
-                    if (task.data.contains("system_prompt"))
+                    // reset cache_tokens for all slots
+                    for (llama_client_slot &slot : slots)
                     {
-                        if (!all_slots_are_idle) {
-                            send_error(task, "system prompt can only be updated when all slots are idle");
-                            break;
-                        }
-                        process_system_prompt_data(task.data["system_prompt"]);
-
-                        // reset cache_tokens for all slots
-                        for (llama_client_slot &slot : slots)
-                        {
-                            slot.cache_tokens.clear();
-                        }
+                        slot.cache_tokens.clear();
                     }
+                }
 
-                    slot->reset();
+                slot->reset();
 
-                    slot->infill = task.infill_mode;
-                    slot->embedding = task.embedding_mode;
-                    slot->task_id = task.id;
-                    slot->multitask_id = task.multitask_id;
+                slot->infill = task.infill_mode;
+                slot->embedding = task.embedding_mode;
+                slot->task_id = task.id;
+                slot->multitask_id = task.multitask_id;
 
-                    if (!launch_slot_with_data(slot, task.data))
+                if (!launch_slot_with_data(slot, task.data))
+                {
+                    // send error result
+                    send_error(task, "internal_error");
+                    break;
+                }
+            } break;
+            case TASK_TYPE_CANCEL: { // release slot linked with the task id
+                for (auto & slot : slots)
+                {
+                    if (slot.task_id == task.target_id)
                     {
-                        // send error result
-                        send_error(task, "internal_error");
+                        slot.release();
                         break;
                     }
-                } break;
-                case TASK_TYPE_CANCEL: { // release slot linked with the task id
-                    for (auto & slot : slots)
-                    {
-                        if (slot.task_id == task.target_id)
-                        {
-                            slot.release();
-                            break;
-                        }
-                    }
-                } break;
-            }
-        }
-
-        // add all the deferred tasks back the the queue
-        for (task_server &task : deferred_tasks)
-        {
-            queue_tasks.push_back(task);
+                }
+            } break;
+            case TASK_TYPE_NEXT_RESPONSE: {
+                // do nothing
+            } break;
         }
+    }
 
+    void process_multitask()
+    {
         // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
         std::vector<task_result> agg_results;
         auto queue_iterator = queue_multitasks.begin();
@@ -1657,18 +1431,12 @@
             }
         }
 
-        // done with tasks, unlock
-        lock.unlock();
-
         // copy aggregate results of complete multi-tasks to the results queue
         std::lock_guard<std::mutex> lock_results(mutex_results);
         queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end());
     }
 
     bool update_slots() {
-        // attend tasks
-        process_tasks();
-
         if (system_need_update)
         {
             LOG_TEE("updating system prompt\n");
@@ -1684,10 +1452,12 @@
                 LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
                 kv_cache_clear();
             }
-            std::unique_lock<std::mutex> lock(mutex_tasks);
-            condition_tasks.wait(lock, [&]{
-                return !queue_tasks.empty();
-            });
+            return true;
+        } else {
+            task_server task;
+            task.type = TASK_TYPE_NEXT_RESPONSE;
+            task.target_id = -1;
+            queue_tasks.post(task);
         }
 
         for (llama_client_slot &slot : slots)
@@ -1997,6 +1767,11 @@
         }
         return true;
     }
+
+    void run_on_all_tasks_finished() {
+        process_multitask();
+        update_slots();
+    }
 };
 
 static void server_print_usage(const char *argv0, const gpt_params &params,
@@ -3360,15 +3135,21 @@ int main(int argc, char **argv)
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
     // "Bus error: 10" - this is on macOS, it does not crash on Linux
     //std::thread t2([&]()
-    {
+    /*{
         bool running = true;
         while (running)
        {
             running = llama.update_slots();
         }
-    }
+    }*/
     //);
 
+    llama.queue_tasks.on_new_task(std::bind(
+        &llama_server_context::process_single_task, &llama, std::placeholders::_1));
+    llama.queue_tasks.on_all_tasks_finished(std::bind(
+        &llama_server_context::run_on_all_tasks_finished, &llama));
+    llama.queue_tasks.start_loop();
+
     t.join();
 
     llama_backend_free();

diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
new file mode 100644
index 000000000..cf1925dde
--- /dev/null
+++ b/examples/server/utils.hpp
@@ -0,0 +1,296 @@
+#pragma once
+
+#include <string>
+#include <vector>
+#include <set>
+#include <mutex>
+#include <condition_variable>
+
+#include "json.hpp"
+
+#include "../llava/clip.h"
+
+using json = nlohmann::json;
+
+#ifndef SERVER_VERBOSE
+#define SERVER_VERBOSE 1
+#endif
+
+#if SERVER_VERBOSE != 1
+#define LOG_VERBOSE(MSG, ...)
+#else
+#define LOG_VERBOSE(MSG, ...)                                            \
+    do                                                                   \
+    {                                                                    \
+        if (server_verbose)                                              \
+        {                                                                \
+            server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
+        }                                                                \
+    } while (0)
+#endif
+
+#define LOG_ERROR(  MSG, ...) server_log("ERROR",   __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_INFO(   MSG, ...) server_log("INFO",    __func__, __LINE__, MSG, __VA_ARGS__)
server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) + +// +// parallel +// + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded + SERVER_STATE_ERROR // An error occurred, load_model failed +}; + +enum task_type { + TASK_TYPE_COMPLETION, + TASK_TYPE_CANCEL, + TASK_TYPE_NEXT_RESPONSE +}; + +struct task_server { + int id = -1; // to be filled by llama_server_queue + int target_id; + task_type type; + json data; + bool infill_mode = false; + bool embedding_mode = false; + int multitask_id = -1; +}; + +struct task_result { + int id; + int multitask_id = -1; + bool stop; + bool error; + json result_json; +}; + +struct task_multi { + int id; + std::set subtasks_remaining{}; + std::vector results{}; +}; + +// TODO: can become bool if we can't find use of more states +enum slot_state +{ + IDLE, + PROCESSING, +}; + +enum slot_command +{ + NONE, + LOAD_PROMPT, + RELEASE, +}; + +struct slot_params +{ + bool stream = true; + bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt + + uint32_t seed = -1; // RNG seed + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_predict = -1; // new tokens to predict + + std::vector antiprompt; + + json input_prefix; + json input_suffix; +}; + +struct slot_image +{ + int32_t id; + + bool request_encode_image = false; + float * image_embedding = nullptr; + int32_t image_tokens = 0; + + clip_image_u8 * img_data; + + std::string prefix_prompt; // before of this image +}; + +// completion token output with probabilities +struct completion_token_output +{ + struct token_prob + { + llama_token tok; + float prob; + }; + + std::vector probs; + llama_token tok; + std::string text_to_send; +}; + +static inline void server_log(const char *level, const char *function, int line, + const char *message, const nlohmann::ordered_json &extra) +{ + nlohmann::ordered_json log + { + {"timestamp", time(nullptr)}, + {"level", level}, + {"function", function}, + {"line", line}, + {"message", message}, + }; + + if (!extra.empty()) + { + log.merge_patch(extra); + } + + const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace); + printf("%.*s\n", (int)str.size(), str.data()); + fflush(stdout); +} + +// +// work queue utils +// + +template +struct llama_server_queue { + int id = 0; + std::mutex mutex_tasks; + std::vector queue_tasks; + std::condition_variable condition_tasks; + std::function callback_new_task; + std::function callback_all_task_finished; + + int post(T task) { + LOG_INFO("post", {}); + std::unique_lock lock(mutex_tasks); + task.id = id++; + queue_tasks.push_back(std::move(task)); + condition_tasks.notify_one(); + return task.id; + } + + int get_next_id() { + std::unique_lock lock(mutex_tasks); + return id++; + } + + void on_new_task(std::function callback) { + callback_new_task = callback; + } + + void on_all_tasks_finished(std::function callback) { + callback_all_task_finished = callback; + } + + void start_loop() { + while (true) { + // new task arrived + LOG_INFO("have new task", {}); + { + while (true) + { + std::unique_lock lock(mutex_tasks); + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + task_server task = queue_tasks.front(); + queue_tasks.erase(queue_tasks.begin()); + lock.unlock(); + LOG_INFO("callback_new_task", {}); + callback_new_task(task); + } + LOG_INFO("callback_all_task_finished", {}); + callback_all_task_finished(); + } + LOG_INFO("wait for new 
task", {}); + // wait for new task + { + std::unique_lock lock(mutex_tasks); + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return !queue_tasks.empty(); + }); + } + } + } + } +}; + +// +// base64 utils (TODO: move to common in the future) +// + +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static inline bool is_base64(uint8_t c) +{ + return (isalnum(c) || (c == '+') || (c == '/')); +} + +static inline std::vector base64_decode(const std::string & encoded_string) +{ + int i = 0; + int j = 0; + int in_ = 0; + + int in_len = encoded_string.size(); + + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + + std::vector ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) + { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) + { + for (i = 0; i <4; i++) + { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) + { + ret.push_back(char_array_3[i]); + } + i = 0; + } + } + + if (i) + { + for (j = i; j <4; j++) + { + char_array_4[j] = 0; + } + + for (j = 0; j <4; j++) + { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; (j < i - 1); j++) + { + ret.push_back(char_array_3[j]); + } + } + + return ret; +} \ No newline at end of file