add request aggregation functionality
This commit is contained in:
parent
59f4db1088
commit
fb93f70533
4 changed files with 334 additions and 29 deletions
|
@ -1279,6 +1279,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.n_parallel = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_N_PARALLEL"));
|
||||
add_opt(common_arg(
|
||||
{"--aggregate", "-ag"},
|
||||
string_format("apply request aggregation (default: %s)", params.aggregate ? "enabled" : "disabled"),
|
||||
[](common_params & params) {
|
||||
params.aggregate = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_AGGREGATION"));
|
||||
add_opt(common_arg(
|
||||
{"-bs", "--buffer-size"}, "N",
|
||||
string_format("buffer size if aggregation is enabled (default: %d)", params.buffer_size),
|
||||
[](common_params & params, int value) {
|
||||
params.buffer_size = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_BUFFER_SIZE"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"-bks", "--block-size"}, "N",
|
||||
string_format("block size if aggregation is enabled and should be equal to or less than buffer_size (default: %d)", params.block_size),
|
||||
[](common_params & params, int value) {
|
||||
params.block_size = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_BLOCK_SIZE"));
|
||||
add_opt(common_arg(
|
||||
{"-ns", "--sequences"}, "N",
|
||||
string_format("number of sequences to decode (default: %d)", params.n_sequences),
|
||||
|
|
|
@ -191,6 +191,9 @@ struct common_params {
|
|||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||
float defrag_thold = 0.1f; // KV cache defragmentation threshold
|
||||
bool aggregate = false; // The aggregation feature essentially groups multiple requests over a specific time period before starting to process the prompts.
|
||||
int32_t buffer_size = 36; // We would wait until there are buffer_size requests or 50 ms before starting to process the requests.
|
||||
int32_t block_size = 12; // We group the requests in the buffer into blocks of block_size and process them as an array of prompts, similar to how /completions does.
|
||||
|
||||
// offload params
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
|
|
@ -170,6 +170,9 @@ The project is under active development, and we are [looking for feedback and co
|
|||
| `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model |
|
||||
| `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused) |
|
||||
| `-ag, --aggregate` | to enable request aggregation |
|
||||
| `-bs, --buffer-size N` | to specify buffer size of the aggregation |
|
||||
| `-bks,--block-size N` | to specify the block size (array size) of requests processed together when aggregation is enabled; it should be less than the buffer size. |
|
||||
|
||||
|
||||
Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.
|
||||
|
|
|
@ -32,6 +32,11 @@
|
|||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
using input_buffer_element = std::tuple<json, std::reference_wrapper<httplib::Response> , int>; // a single input element representation
|
||||
std::shared_ptr<std::vector< input_buffer_element >> input_buffer = std::make_shared<std::vector< input_buffer_element >>(); // input buffer
|
||||
std::atomic<int> request_id(0); // request id counter, always increasing
|
||||
std::shared_ptr<std::map<int, bool>> output_buffer = std::make_shared<std::map<int, bool>>(); // output buffer, only needs to hold the request id to know if the response is ready
|
||||
|
||||
enum stop_type {
|
||||
STOP_TYPE_FULL,
|
||||
STOP_TYPE_PARTIAL,
|
||||
|
@ -2417,10 +2422,128 @@ inline void signal_handler(int signal) {
|
|||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
|
||||
void process_aggregate_prompts(
|
||||
const std::function<void(server_task_inf_type, json &,
|
||||
const std::vector<std::reference_wrapper<httplib::Response>> & res, std::vector<int> ids)>
|
||||
handle_completions_generic_aggregate,
|
||||
const std::function<void(server_task_inf_type, json &, httplib::Response &, int)> handle_completions_generic,
|
||||
std::mutex & input_buffer_mutex, std::mutex & output_buffer_mutex, int buffer_size, int block_size, int n_predict) {
|
||||
while (true) {
|
||||
int i = 0;
|
||||
// Wait for the buffer to fill up or for 50ms
|
||||
while (input_buffer->size() <= buffer_size && i < 5) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
i++;
|
||||
}
|
||||
|
||||
std::vector<input_buffer_element> prompts_to_process;
|
||||
|
||||
if (!input_buffer->empty()) {
|
||||
std::unique_lock<std::mutex> lock(input_buffer_mutex);
|
||||
prompts_to_process = std::move(*input_buffer); // Move prompts out of buffer
|
||||
input_buffer->clear(); // Clear buffer after moving
|
||||
lock.unlock();
|
||||
}
|
||||
|
||||
if (!prompts_to_process.empty()) {
|
||||
if (prompts_to_process.size() > block_size) {
|
||||
// sort the prompts by length to create a more uniform distribution of prompt lengths for batching
|
||||
// and also to reduce the average response time
|
||||
std::sort(prompts_to_process.begin(), prompts_to_process.end(),
|
||||
[](const input_buffer_element & a, const input_buffer_element & b) {
|
||||
const std::string & prompt_a = std::get<0>(a)["prompt"];
|
||||
const std::string & prompt_b = std::get<0>(b)["prompt"];
|
||||
return prompt_a.length() < prompt_b.length();
|
||||
});
|
||||
}
|
||||
|
||||
if (block_size == 1) {
|
||||
// if block_size is 1, we process each prompt individually no nead to create blocks
|
||||
std::vector<std::future<void>> futures;
|
||||
for (int k = 0; k < prompts_to_process.size(); ++k) {
|
||||
json & json_data = std::get<0>((prompts_to_process)[k]);
|
||||
httplib::Response & response_object = (std::get<1>((prompts_to_process)[k])).get();
|
||||
int id = std::get<2>((prompts_to_process)[k]);
|
||||
|
||||
// we do not want for the completion to be processed in the same thread
|
||||
auto task =
|
||||
std::async(std::launch::async, handle_completions_generic, SERVER_TASK_INF_TYPE_COMPLETION,
|
||||
std::ref(json_data), std::ref(response_object), id);
|
||||
|
||||
futures.push_back(std::move(task));
|
||||
}
|
||||
for (auto & future : futures) {
|
||||
future.wait();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
else {
|
||||
std::vector<std::future<void>> futures;
|
||||
for (int k = 0; k < prompts_to_process.size(); k += block_size) {
|
||||
nlohmann::json prompt_array = nlohmann::json::array();
|
||||
std::vector<std::reference_wrapper<httplib::Response>> response_objects;
|
||||
std::vector<int> ids;
|
||||
|
||||
for (int j = 0; j < block_size; j++) {
|
||||
if ((k + j) == prompts_to_process.size()) {
|
||||
break;
|
||||
}
|
||||
// concatinate the prompts into a single array
|
||||
json & json_data = std::get<0>((prompts_to_process)[k + j]);
|
||||
std::string prompt = json_data["prompt"].get<std::string>();
|
||||
if (prompt == "") {
|
||||
continue;
|
||||
}
|
||||
int id = std::get<2>((prompts_to_process)[k + j]);
|
||||
response_objects.emplace_back(std::ref((std::get<1>((prompts_to_process)[k + j])).get()));
|
||||
|
||||
// Add the prompt to the array
|
||||
prompt_array.push_back(prompt);
|
||||
ids.push_back(id);
|
||||
}
|
||||
|
||||
if (prompt_array.empty()) {
|
||||
// Handle empty array case
|
||||
continue;
|
||||
}
|
||||
json json_result;
|
||||
json_result["prompt"] = prompt_array;
|
||||
// since multiple prompts are being processed, and the we need a common n_predict for all prompts
|
||||
// we can not use the n_predict from the prompt json itself, either we use the n_predict from the params or the default value
|
||||
if (n_predict == -1) {
|
||||
n_predict = 50;
|
||||
}
|
||||
json_result["n_predict"] = n_predict;
|
||||
|
||||
// to take advantage of multi-threading, we process the completions in parallel
|
||||
auto task = std::async(std::launch::async, [&handle_completions_generic_aggregate,
|
||||
json_result = std::move(json_result), &response_objects,
|
||||
const_ids = ids]() mutable {
|
||||
handle_completions_generic_aggregate(SERVER_TASK_INF_TYPE_COMPLETION, json_result,
|
||||
response_objects, const_ids);
|
||||
});
|
||||
|
||||
futures.push_back(std::move(task));
|
||||
}
|
||||
|
||||
for (auto & future : futures) {
|
||||
future.wait();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
// own arguments required by this example
|
||||
common_params params;
|
||||
|
||||
std::mutex input_buffer_mutex; // mutex for input buffer
|
||||
std::mutex output_buffer_mutex; // mutex for output buffer
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
|
||||
return 1;
|
||||
}
|
||||
|
@ -2880,9 +3003,13 @@ int main(int argc, char ** argv) {
|
|||
res_ok(res, {{ "success", true }});
|
||||
};
|
||||
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok, &output_buffer_mutex](
|
||||
server_task_inf_type inf_type, json & data, httplib::Response & res,
|
||||
int id) {
|
||||
if (ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res,
|
||||
format_error_response("This server does not support completions. Start it without `--embeddings`",
|
||||
ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2890,51 +3017,183 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
bool stream = json_value(data, "stream", false);
|
||||
const auto task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||
if (results.size() == 1) {
|
||||
// single result
|
||||
res_ok(res, results[0].data);
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
for (const auto & res : results) {
|
||||
arr.push_back(res.data);
|
||||
ctx_server.receive_cmpl_results(
|
||||
task_ids,
|
||||
[&](std::vector<server_task_result> & results) {
|
||||
if (results.size() == 1) {
|
||||
// single result
|
||||
res_ok(res, results[0].data);
|
||||
|
||||
// mark the request in the output buffer as ready
|
||||
std::unique_lock<std::mutex> lock(output_buffer_mutex);
|
||||
(*output_buffer)[id] = true;
|
||||
lock.unlock();
|
||||
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
for (const auto & res : results) {
|
||||
arr.push_back(res.data);
|
||||
}
|
||||
|
||||
res_ok(res, arr);
|
||||
// mark the request in the output buffer as ready
|
||||
std::unique_lock<std::mutex> lock(output_buffer_mutex);
|
||||
(*output_buffer)[id] = true;
|
||||
lock.unlock();
|
||||
}
|
||||
res_ok(res, arr);
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
});
|
||||
},
|
||||
[&](const json & error_data) { res_error(res, error_data); });
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
|
||||
return server_sent_event(sink, "data", result.data);
|
||||
}, [&](const json & error_data) {
|
||||
server_sent_event(sink, "error", error_data);
|
||||
});
|
||||
ctx_server.receive_cmpl_results_stream(
|
||||
task_ids,
|
||||
[&](const server_task_result & result) -> bool {
|
||||
return server_sent_event(sink, "data", result.data);
|
||||
},
|
||||
[&](const json & error_data) { server_sent_event(sink, "error", error_data); });
|
||||
sink.done();
|
||||
return false;
|
||||
};
|
||||
|
||||
auto on_complete = [task_ids, &ctx_server] (bool) {
|
||||
auto on_complete = [task_ids, &ctx_server](bool) {
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
};
|
||||
|
||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||
|
||||
std::unique_lock<std::mutex> lock(output_buffer_mutex);
|
||||
(*output_buffer)[id] = true;
|
||||
lock.unlock();
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
|
||||
};
|
||||
const auto handle_completions_generic_aggregate =
|
||||
[&ctx_server, &res_error, &res_ok, &output_buffer_mutex](
|
||||
server_task_inf_type inf_type, json & data,
|
||||
const std::vector<std::reference_wrapper<httplib::Response>> & res, std::vector<int> ids) {
|
||||
if (ctx_server.params.embedding) {
|
||||
res_error(res[0].get(), format_error_response(
|
||||
"This server does not support completions. Start it without `--embeddings`",
|
||||
ERROR_TYPE_NOT_SUPPORTED));
|
||||
|
||||
std::unique_lock<std::mutex> lock(output_buffer_mutex);
|
||||
(*output_buffer)[ids[0]] = true;
|
||||
lock.unlock();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
const auto task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(
|
||||
task_ids,
|
||||
[&](std::vector<server_task_result> & results) {
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
res_ok(res[i].get(), results[i].data);
|
||||
// mark the request in the output buffer as ready
|
||||
std::unique_lock<std::mutex> lock(output_buffer_mutex);
|
||||
(*output_buffer)[ids[i]] = true;
|
||||
lock.unlock();
|
||||
}
|
||||
},
|
||||
[&](const json & error_data) {
|
||||
res_error(res[0].get(), error_data);
|
||||
// mark the request in the output buffer as ready
|
||||
std::unique_lock<std::mutex> lock(output_buffer_mutex);
|
||||
(*output_buffer)[ids[0]] = true;
|
||||
lock.unlock();
|
||||
});
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
|
||||
ctx_server.receive_cmpl_results_stream(
|
||||
task_ids,
|
||||
[&](const server_task_result & result) -> bool {
|
||||
return server_sent_event(sink, "data", result.data);
|
||||
},
|
||||
[&](const json & error_data) { server_sent_event(sink, "error", error_data); });
|
||||
sink.done();
|
||||
return false;
|
||||
};
|
||||
|
||||
auto on_complete = [task_ids, &ctx_server](bool) {
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
};
|
||||
|
||||
res[0].get().set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||
|
||||
std::unique_lock<std::mutex> lock(output_buffer_mutex);
|
||||
(*output_buffer)[ids[0]] = true;
|
||||
lock.unlock();
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_completions = [¶ms, &handle_completions_generic, &handle_completions_generic_aggregate,
|
||||
&input_buffer_mutex,
|
||||
&output_buffer_mutex](const httplib::Request & req, httplib::Response & res) {
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
json data = json::parse(req.body);
|
||||
|
||||
if (!params.aggregate) {
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, -1);
|
||||
}
|
||||
|
||||
// get the request id, each request has a unique id
|
||||
++request_id;
|
||||
int id = request_id.load();
|
||||
|
||||
input_buffer_element new_element = { data, std::ref(res), id };
|
||||
|
||||
// mark the request in the output buffer as not ready
|
||||
std::unique_lock<std::mutex> lock2(output_buffer_mutex);
|
||||
(*output_buffer)[id] = false;
|
||||
lock2.unlock();
|
||||
|
||||
// add the request to the input buffer
|
||||
std::unique_lock<std::mutex> lock(input_buffer_mutex);
|
||||
input_buffer->emplace_back(new_element);
|
||||
lock.unlock();
|
||||
|
||||
while (true) {
|
||||
if ((*output_buffer)[id]) {
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> end =
|
||||
std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<long long, std::micro> duration =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
|
||||
|
||||
nlohmann::json json_data = nlohmann::json::parse(res.body);
|
||||
// for performance metrics, can be omitted
|
||||
json_data["duration"] = duration.count();
|
||||
|
||||
// Serialize the updated JSON object back to a string
|
||||
std::string updated_content = json_data.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace);
|
||||
|
||||
// Update the response content with the defined MIME type
|
||||
res.set_content(updated_content, MIMETYPE_JSON);
|
||||
return;
|
||||
}
|
||||
|
||||
//Sleep to conserve CPU resource instead of busy waiting.
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
}
|
||||
};
|
||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
// check model compatibility
|
||||
std::string err;
|
||||
|
@ -2982,7 +3241,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
||||
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res,-1);
|
||||
};
|
||||
|
||||
// TODO: maybe merge this function with "handle_completions_generic"
|
||||
|
@ -3413,7 +3672,25 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
if (params.aggregate) {
|
||||
std::thread batch_processing_thread(
|
||||
process_aggregate_prompts,
|
||||
std::ref(handle_completions_generic_aggregate),
|
||||
std::ref(handle_completions_generic),
|
||||
std::ref(input_buffer_mutex),
|
||||
std::ref(output_buffer_mutex),
|
||||
params.buffer_size,
|
||||
params.block_size,
|
||||
params.n_predict
|
||||
);
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
batch_processing_thread.join();
|
||||
}
|
||||
|
||||
else{
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue