diff --git a/common/arg.cpp b/common/arg.cpp
index 078c75384..1e02a588a 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1279,6 +1279,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_parallel = value;
         }
     ).set_env("LLAMA_ARG_N_PARALLEL"));
+    add_opt(common_arg(
+        {"--aggregate", "-ag"},
+        string_format("enable request aggregation (default: %s)", params.aggregate ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.aggregate = true;
+        }
+    ).set_env("LLAMA_ARG_AGGREGATION"));
+    add_opt(common_arg(
+        {"-bs", "--buffer-size"}, "N",
+        string_format("buffer size used when aggregation is enabled (default: %d)", params.buffer_size),
+        [](common_params & params, int value) {
+            params.buffer_size = value;
+        }
+    ).set_env("LLAMA_ARG_BUFFER_SIZE"));
+
+    add_opt(common_arg(
+        {"-bks", "--block-size"}, "N",
+        string_format("block size used when aggregation is enabled; must be less than or equal to buffer_size (default: %d)", params.block_size),
+        [](common_params & params, int value) {
+            params.block_size = value;
+        }
+    ).set_env("LLAMA_ARG_BLOCK_SIZE"));
     add_opt(common_arg(
         {"-ns", "--sequences"}, "N",
         string_format("number of sequences to decode (default: %d)", params.n_sequences),
diff --git a/common/common.h b/common/common.h
index 0373fd3ea..bd5eba001 100644
--- a/common/common.h
+++ b/common/common.h
@@ -191,6 +191,9 @@ struct common_params {
     float   yarn_beta_slow = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx  = 0;     // YaRN original context length
     float   defrag_thold   = 0.1f;  // KV cache defragmentation threshold
+    bool    aggregate      = false; // group incoming requests over a short time window before starting to process their prompts
+    int32_t buffer_size    = 36;    // start processing once buffer_size requests have accumulated, or after 50 ms, whichever comes first
+    int32_t block_size     = 12;    // buffered requests are grouped into blocks of block_size and processed as an array of prompts, as /completions does

     // offload params
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
diff --git a/examples/server/README.md b/examples/server/README.md
index b2dd7b65a..3130d374f 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -170,6 +170,9 @@ The project is under active development, and we are [looking for feedback and co
 | `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
 | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model |
 | `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused) |
+| `-ag, --aggregate` | enable request aggregation (default: disabled) |
+| `-bs, --buffer-size N` | buffer size used when aggregation is enabled (default: 36) |
+| `-bks, --block-size N` | block size (array size) of requests processed together when aggregation is enabled; must be less than or equal to the buffer size (default: 12) |

 Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.

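For illustration only (not part of the patch): a rough usage sketch, assuming a `llama-server` build, a local GGUF model at `./model.gguf`, and the server's default port 8080. With aggregation enabled, a common `n_predict` is applied to each aggregated block and non-streaming responses gain an extra `duration` timing field:

```sh
# start the server with aggregation: collect up to 36 requests (or wait at most ~50 ms),
# then process them in blocks of 12 prompts at a time
./llama-server -m ./model.gguf --aggregate --buffer-size 36 --block-size 12

# fire a burst of concurrent, non-streaming completion requests
for i in $(seq 1 24); do
    curl -s http://127.0.0.1:8080/completion -d '{"prompt": "Hello", "n_predict": 16}' &
done
wait
```
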
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 9bca3f30e..dd0f577ba 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -32,6 +32,11 @@
 using json = nlohmann::ordered_json;

+using input_buffer_element = std::tuple<json, std::reference_wrapper<httplib::Response>, int>; // a single input element representation
+std::shared_ptr<std::vector<input_buffer_element>> input_buffer = std::make_shared<std::vector<input_buffer_element>>(); // input buffer
+std::atomic<int> request_id(0); // request id counter, always increasing
+std::shared_ptr<std::unordered_map<int, bool>> output_buffer = std::make_shared<std::unordered_map<int, bool>>(); // output buffer; only needs to map each request id to whether its response is ready
+
 enum stop_type {
     STOP_TYPE_FULL,
     STOP_TYPE_PARTIAL,
@@ -2417,10 +2422,128 @@ inline void signal_handler(int signal) {
     shutdown_handler(signal);
 }

+
+void process_aggregate_prompts(
+    const std::function<void(server_task_inf_type inf_type, json & data,
+                             const std::vector<std::reference_wrapper<httplib::Response>> & res, std::vector<int> ids)>
+        handle_completions_generic_aggregate,
+    const std::function<void(server_task_inf_type inf_type, json & data, httplib::Response & res, int id)> handle_completions_generic,
+    std::mutex & input_buffer_mutex, std::mutex & output_buffer_mutex, int buffer_size, int block_size, int n_predict) {
+    while (true) {
+        int i = 0;
+        // wait for the buffer to fill up, or for at most 50 ms
+        while ((int) input_buffer->size() <= buffer_size && i < 5) {
+            std::this_thread::sleep_for(std::chrono::milliseconds(10));
+            i++;
+        }
+
+        std::vector<input_buffer_element> prompts_to_process;
+
+        if (!input_buffer->empty()) {
+            std::unique_lock<std::mutex> lock(input_buffer_mutex);
+            prompts_to_process = std::move(*input_buffer); // move the prompts out of the buffer
+            input_buffer->clear();                         // clear the buffer after moving
+            lock.unlock();
+        }
+
+        if (!prompts_to_process.empty()) {
+            if ((int) prompts_to_process.size() > block_size) {
+                // sort the prompts by length to create a more uniform distribution of prompt lengths for batching
+                // and also to reduce the average response time
+                std::sort(prompts_to_process.begin(), prompts_to_process.end(),
+                          [](const input_buffer_element & a, const input_buffer_element & b) {
+                              const std::string & prompt_a = std::get<0>(a)["prompt"];
+                              const std::string & prompt_b = std::get<0>(b)["prompt"];
+                              return prompt_a.length() < prompt_b.length();
+                          });
+            }
+
+            if (block_size == 1) {
+                // if block_size is 1, process each prompt individually; there is no need to create blocks
+                std::vector<std::future<void>> futures;
+                for (int k = 0; k < (int) prompts_to_process.size(); ++k) {
+                    json & json_data = std::get<0>(prompts_to_process[k]);
+                    httplib::Response & response_object = std::get<1>(prompts_to_process[k]).get();
+                    int id = std::get<2>(prompts_to_process[k]);
+
+                    // the completion must not be processed in the same thread
+                    auto task =
+                        std::async(std::launch::async, handle_completions_generic, SERVER_TASK_INF_TYPE_COMPLETION,
+                                   std::ref(json_data), std::ref(response_object), id);
+
+                    futures.push_back(std::move(task));
+                }
+                for (auto & future : futures) {
+                    future.wait();
+                }
+            } else {
+                std::vector<std::future<void>> futures;
+                for (int k = 0; k < (int) prompts_to_process.size(); k += block_size) {
+                    nlohmann::json prompt_array = nlohmann::json::array();
+                    std::vector<std::reference_wrapper<httplib::Response>> response_objects;
+                    std::vector<int> ids;
+
+                    for (int j = 0; j < block_size; j++) {
+                        if ((k + j) == (int) prompts_to_process.size()) {
+                            break;
+                        }
+                        // concatenate the prompts into a single array
+                        json & json_data = std::get<0>(prompts_to_process[k + j]);
+                        std::string prompt = json_data["prompt"].get<std::string>();
+                        if (prompt.empty()) {
+                            continue;
+                        }
+                        int id = std::get<2>(prompts_to_process[k + j]);
+                        response_objects.emplace_back(std::ref(std::get<1>(prompts_to_process[k + j]).get()));
+
+                        // add the prompt to the array
+                        prompt_array.push_back(prompt);
+                        ids.push_back(id);
+                    }
+
+                    if (prompt_array.empty()) {
+                        // nothing to process in this block
+                        continue;
+                    }
+                    json json_result;
+                    json_result["prompt"] = prompt_array;
+                    // multiple prompts are processed together, so a common n_predict is needed for all of them;
+                    // the per-prompt n_predict cannot be used, so fall back to the value from the params or a default
+                    if (n_predict == -1) {
+                        n_predict = 50;
+                    }
+                    json_result["n_predict"] = n_predict;
+
+                    // process the completions in parallel to take advantage of multi-threading
+                    auto task = std::async(std::launch::async, [&handle_completions_generic_aggregate,
+                                                                json_result = std::move(json_result),
+                                                                response_objects = std::move(response_objects),
+                                                                const_ids = ids]() mutable {
+                        handle_completions_generic_aggregate(SERVER_TASK_INF_TYPE_COMPLETION, json_result,
+                                                             response_objects, const_ids);
+                    });
+
+                    futures.push_back(std::move(task));
+                }
+
+                for (auto & future : futures) {
+                    future.wait();
+                }
+            }
+        }
+    }
+}
+
+
 int main(int argc, char ** argv) {
     // own arguments required by this example
     common_params params;
+    std::mutex input_buffer_mutex;  // mutex for input buffer
+    std::mutex output_buffer_mutex; // mutex for output buffer
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
         return 1;
     }
@@ -2880,9 +3003,13 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };

-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok, &output_buffer_mutex](
+                                                server_task_inf_type inf_type, json & data, httplib::Response & res,
+                                                int id) {
         if (ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res,
+                      format_error_response("This server does not support completions. Start it without `--embeddings`",
+                                            ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -2890,51 +3017,183 @@ int main(int argc, char ** argv) {
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);

-        bool stream = json_value(data, "stream", false);
+        bool stream         = json_value(data, "stream", false);
         const auto task_ids = server_task::get_list_id(tasks);

         if (!stream) {
-            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
-                if (results.size() == 1) {
-                    // single result
-                    res_ok(res, results[0].data);
-                } else {
-                    // multiple results (multitask)
-                    json arr = json::array();
-                    for (const auto & res : results) {
-                        arr.push_back(res.data);
+            ctx_server.receive_cmpl_results(
+                task_ids,
+                [&](std::vector<server_task_result> & results) {
+                    if (results.size() == 1) {
+                        // single result
+                        res_ok(res, results[0].data);
+
+                        // mark the request in the output buffer as ready
+                        std::unique_lock<std::mutex> lock(output_buffer_mutex);
+                        (*output_buffer)[id] = true;
+                        lock.unlock();
+
+                    } else {
+                        // multiple results (multitask)
+                        json arr = json::array();
+                        for (const auto & res : results) {
+                            arr.push_back(res.data);
+                        }
+
+                        res_ok(res, arr);
+                        // mark the request in the output buffer as ready
+                        std::unique_lock<std::mutex> lock(output_buffer_mutex);
+                        (*output_buffer)[id] = true;
+                        lock.unlock();
                     }
-                    res_ok(res, arr);
-                }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            });
+                },
+                [&](const json & error_data) { res_error(res, error_data); });

             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+
         } else {
             const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    return server_sent_event(sink, "data", result.data);
-                }, [&](const json & error_data) {
-                    server_sent_event(sink, "error", error_data);
-                });
+                ctx_server.receive_cmpl_results_stream(
+                    task_ids,
+                    [&](const server_task_result & result) -> bool {
+                        return server_sent_event(sink, "data", result.data);
+                    },
+                    [&](const json & error_data) { server_sent_event(sink, "error", error_data); });
                 sink.done();
                 return false;
             };

-            auto on_complete = [task_ids, &ctx_server] (bool) {
+            auto on_complete = [task_ids, &ctx_server](bool) {
                 ctx_server.queue_results.remove_waiting_task_ids(task_ids);
             };

             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+
+            std::unique_lock<std::mutex> lock(output_buffer_mutex);
+            (*output_buffer)[id] = true;
+            lock.unlock();
         }
     };

-    const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
-        json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
-    };
+    const auto handle_completions_generic_aggregate =
+        [&ctx_server, &res_error, &res_ok, &output_buffer_mutex](
+            server_task_inf_type inf_type, json & data,
+            const std::vector<std::reference_wrapper<httplib::Response>> & res, std::vector<int> ids) {
+            if (ctx_server.params_base.embedding) {
+                res_error(res[0].get(), format_error_response(
+                                            "This server does not support completions. Start it without `--embeddings`",
+                                            ERROR_TYPE_NOT_SUPPORTED));
+                std::unique_lock<std::mutex> lock(output_buffer_mutex);
+                (*output_buffer)[ids[0]] = true;
+                lock.unlock();
+
+                return;
+            }
+
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            bool stream = json_value(data, "stream", false);
+            const auto task_ids = server_task::get_list_id(tasks);
+
+            if (!stream) {
+                ctx_server.receive_cmpl_results(
+                    task_ids,
+                    [&](std::vector<server_task_result> & results) {
+                        for (int i = 0; i < (int) results.size(); i++) {
+                            res_ok(res[i].get(), results[i].data);
+                            // mark the request in the output buffer as ready
+                            std::unique_lock<std::mutex> lock(output_buffer_mutex);
+                            (*output_buffer)[ids[i]] = true;
+                            lock.unlock();
+                        }
+                    },
+                    [&](const json & error_data) {
+                        res_error(res[0].get(), error_data);
+                        // mark the request in the output buffer as ready
+                        std::unique_lock<std::mutex> lock(output_buffer_mutex);
+                        (*output_buffer)[ids[0]] = true;
+                        lock.unlock();
+                    });

+                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+
+            } else {
+                const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
+                    ctx_server.receive_cmpl_results_stream(
+                        task_ids,
+                        [&](const server_task_result & result) -> bool {
+                            return server_sent_event(sink, "data", result.data);
+                        },
+                        [&](const json & error_data) { server_sent_event(sink, "error", error_data); });
+                    sink.done();
+                    return false;
+                };
+
+                auto on_complete = [task_ids, &ctx_server](bool) {
+                    ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+                };
+
+                res[0].get().set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+
+                std::unique_lock<std::mutex> lock(output_buffer_mutex);
+                (*output_buffer)[ids[0]] = true;
+                lock.unlock();
+            }
+        };
+
+    const auto handle_completions = [&params, &handle_completions_generic, &handle_completions_generic_aggregate,
+                                     &input_buffer_mutex,
+                                     &output_buffer_mutex](const httplib::Request & req, httplib::Response & res) {
+        std::chrono::time_point<std::chrono::high_resolution_clock> start = std::chrono::high_resolution_clock::now();
+
+        json data = json::parse(req.body);
+
+        if (!params.aggregate) {
+            return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, -1);
+        }
+
+        // get the request id; each request gets a unique, always increasing id
+        const int id = ++request_id;
+
+        input_buffer_element new_element = { data, std::ref(res), id };
+
+        // mark the request in the output buffer as not ready
+        std::unique_lock<std::mutex> lock2(output_buffer_mutex);
+        (*output_buffer)[id] = false;
+        lock2.unlock();
+
+        // add the request to the input buffer
+        std::unique_lock<std::mutex> lock(input_buffer_mutex);
+        input_buffer->emplace_back(new_element);
+        lock.unlock();
+
+        while (true) {
+            // check (under the mutex) whether the response for this request is ready
+            std::unique_lock<std::mutex> lock3(output_buffer_mutex);
+            const bool ready = (*output_buffer)[id];
+            lock3.unlock();
+
+            if (ready) {
+                std::chrono::time_point<std::chrono::high_resolution_clock> end =
+                    std::chrono::high_resolution_clock::now();
+                std::chrono::duration<double, std::milli> duration =
+                    std::chrono::duration_cast<std::chrono::duration<double, std::milli>>(end - start);
+
+                nlohmann::json json_data = nlohmann::json::parse(res.body);
+                // for performance metrics, can be omitted
+                json_data["duration"] = duration.count();
+
+                // serialize the updated JSON object back to a string
+                std::string updated_content = json_data.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace);
+
+                // update the response content with the defined MIME type
+                res.set_content(updated_content, MIMETYPE_JSON);
+                return;
+            }
+
+            // sleep to conserve CPU resources instead of busy waiting
+            std::this_thread::sleep_for(std::chrono::milliseconds(10));
+        }
+    };

     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         // check model compatibility
         std::string err;
@@ -2982,7 +3241,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist

-        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res, -1);
     };

     // TODO: maybe merge this function with "handle_completions_generic"
@@ -3413,7 +3672,25 @@ int main(int argc, char ** argv) {

     LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);

-    ctx_server.queue_tasks.start_loop();
+    if (params.aggregate) {
+        std::thread batch_processing_thread(
+            process_aggregate_prompts,
+            std::ref(handle_completions_generic_aggregate),
+            std::ref(handle_completions_generic),
+            std::ref(input_buffer_mutex),
+            std::ref(output_buffer_mutex),
+            params.buffer_size,
+            params.block_size,
+            params.n_predict
+        );
+
+        ctx_server.queue_tasks.start_loop();
+        batch_processing_thread.join();
+    } else {
+        ctx_server.queue_tasks.start_loop();
+    }

 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
     struct sigaction sigint_action;