Kalab Yibeltal Assefa 2024-12-06 08:24:54 +05:30 committed by GitHub
commit 52e281e78e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 334 additions and 29 deletions


@@ -1279,6 +1279,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_parallel = value;
}
).set_env("LLAMA_ARG_N_PARALLEL"));
add_opt(common_arg(
{"--aggregate", "-ag"},
string_format("apply request aggregation (default: %s)", params.aggregate ? "enabled" : "disabled"),
[](common_params & params) {
params.aggregate = true;
}
).set_env("LLAMA_ARG_AGGREGATION"));
add_opt(common_arg(
{"-bs", "--buffer-size"}, "N",
string_format("buffer size if aggregation is enabled (default: %d)", params.buffer_size),
[](common_params & params, int value) {
params.buffer_size = value;
}
).set_env("LLAMA_ARG_BUFFER_SIZE"));
add_opt(common_arg(
{"-bks", "--block-size"}, "N",
string_format("block size if aggregation is enabled and should be equal to or less than buffer_size (default: %d)", params.block_size),
[](common_params & params, int value) {
params.block_size = value;
}
).set_env("LLAMA_ARG_BLOCK_SIZE"));
add_opt(common_arg(
{"-ns", "--sequences"}, "N",
string_format("number of sequences to decode (default: %d)", params.n_sequences),


@@ -191,6 +191,9 @@ struct common_params {
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = 0.1f; // KV cache defragmentation threshold
bool aggregate = false; // group multiple requests over a short time window and process their prompts together
int32_t buffer_size = 36; // wait until there are buffer_size requests, or 50 ms have elapsed, before processing the buffered requests
int32_t block_size = 12; // split the buffered requests into blocks of block_size and process each block as an array of prompts, as the /completions endpoint does
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
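As a rough illustration of the comments above (a standalone sketch, not part of the patch): with the default block_size of 12, draining 30 buffered requests yields three blocks of 12, 12, and 6 prompts. The request count here is hypothetical.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int block_size = 12; // default block_size from common_params above
    const int n_buffered = 30; // hypothetical number of requests drained from the buffer
    std::vector<int> block_sizes;
    for (int k = 0; k < n_buffered; k += block_size) {
        // the final block holds the remainder, so it can be smaller than block_size
        block_sizes.push_back(std::min(block_size, n_buffered - k));
    }
    for (int b : block_sizes) {
        printf("%d ", b); // prints: 12 12 6
    }
    printf("\n");
    return 0;
}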


@@ -170,6 +170,9 @@ The project is under active development, and we are [looking for feedback and co
| `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model |
| `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused) |
| `-ag, --aggregate` | enable request aggregation (group incoming requests before processing) |
| `-bs, --buffer-size N` | buffer size (maximum number of requests to collect) when aggregation is enabled |
| `-bks, --block-size N` | block size (number of requests processed together as one array of prompts) when aggregation is enabled; should be less than or equal to the buffer size |
Note: If both the command-line argument and the environment variable are set for the same param, the argument takes precedence over the env var.
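For example, aggregation could be enabled with something like `llama-server -m model.gguf --aggregate --buffer-size 36 --block-size 12` (the buffer and block values shown here are the defaults).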


@@ -32,6 +32,11 @@
using json = nlohmann::ordered_json;
using input_buffer_element = std::tuple<json, std::reference_wrapper<httplib::Response>, int>; // a single input element representation
std::shared_ptr<std::vector<input_buffer_element>> input_buffer = std::make_shared<std::vector<input_buffer_element>>(); // input buffer
std::atomic<int> request_id(0); // request id counter, always increasing
std::shared_ptr<std::map<int, bool>> output_buffer = std::make_shared<std::map<int, bool>>(); // output buffer, only needs to hold the request id to know if the response is ready
enum stop_type {
STOP_TYPE_FULL,
STOP_TYPE_PARTIAL,
@@ -2442,10 +2447,128 @@ inline void signal_handler(int signal) {
shutdown_handler(signal);
}
void process_aggregate_prompts(
const std::function<void(server_task_inf_type, json &,
const std::vector<std::reference_wrapper<httplib::Response>> & res, std::vector<int> ids)>
handle_completions_generic_aggregate,
const std::function<void(server_task_inf_type, json &, httplib::Response &, int)> handle_completions_generic,
std::mutex & input_buffer_mutex, std::mutex & output_buffer_mutex, int buffer_size, int block_size, int n_predict) {
while (true) {
int i = 0;
// wait for the buffer to fill up, or for ~50 ms (5 polls of 10 ms)
while (i < 5) {
std::unique_lock<std::mutex> lock(input_buffer_mutex); // size() must be read under the lock: the HTTP threads append concurrently
bool full = (int) input_buffer->size() >= buffer_size;
lock.unlock();
if (full) { break; }
std::this_thread::sleep_for(std::chrono::milliseconds(10));
i++;
}
std::vector<input_buffer_element> prompts_to_process;
{
std::unique_lock<std::mutex> lock(input_buffer_mutex); // check emptiness under the lock as well
if (!input_buffer->empty()) {
prompts_to_process = std::move(*input_buffer); // move the prompts out of the buffer
input_buffer->clear(); // the moved-from vector is in an unspecified state, so reset it explicitly
}
}
if (!prompts_to_process.empty()) {
if ((int) prompts_to_process.size() > block_size) {
// sort the prompts by length to create a more uniform distribution of prompt lengths for batching
// and also to reduce the average response time
std::sort(prompts_to_process.begin(), prompts_to_process.end(),
[](const input_buffer_element & a, const input_buffer_element & b) {
const std::string & prompt_a = std::get<0>(a)["prompt"];
const std::string & prompt_b = std::get<0>(b)["prompt"];
return prompt_a.length() < prompt_b.length();
});
}
if (block_size == 1) {
// if block_size is 1, process each prompt individually; no need to create blocks
std::vector<std::future<void>> futures;
for (size_t k = 0; k < prompts_to_process.size(); ++k) {
json & json_data = std::get<0>(prompts_to_process[k]);
httplib::Response & response_object = std::get<1>(prompts_to_process[k]).get();
int id = std::get<2>(prompts_to_process[k]);
// we do not want the completion to be processed on this thread
auto task =
std::async(std::launch::async, handle_completions_generic, SERVER_TASK_INF_TYPE_COMPLETION,
std::ref(json_data), std::ref(response_object), id);
futures.push_back(std::move(task));
}
for (auto & future : futures) {
future.wait();
}
}
else {
std::vector<std::future<void>> futures;
for (size_t k = 0; k < prompts_to_process.size(); k += block_size) {
nlohmann::json prompt_array = nlohmann::json::array();
std::vector<std::reference_wrapper<httplib::Response>> response_objects;
std::vector<int> ids;
for (int j = 0; j < block_size; j++) {
if (k + j >= prompts_to_process.size()) {
break;
}
// concatenate the prompts into a single array
json & json_data = std::get<0>(prompts_to_process[k + j]);
std::string prompt = json_data["prompt"].get<std::string>();
if (prompt.empty()) {
continue;
}
int id = std::get<2>(prompts_to_process[k + j]);
response_objects.emplace_back(std::ref(std::get<1>(prompts_to_process[k + j]).get()));
// Add the prompt to the array
prompt_array.push_back(prompt);
ids.push_back(id);
}
if (prompt_array.empty()) {
// Handle empty array case
continue;
}
json json_result;
json_result["prompt"] = prompt_array;
// since multiple prompts are processed together, we need a common n_predict for all of them:
// we cannot use the n_predict from each prompt's json, so we use the one from the params or a default value
if (n_predict == -1) {
n_predict = 50;
}
json_result["n_predict"] = n_predict;
// to take advantage of multi-threading, we process the completions in parallel
// move response_objects and ids into the lambda: they are loop-locals and would dangle if captured by reference
auto task = std::async(std::launch::async, [&handle_completions_generic_aggregate,
json_result = std::move(json_result), response_objects = std::move(response_objects),
ids = std::move(ids)]() mutable {
handle_completions_generic_aggregate(SERVER_TASK_INF_TYPE_COMPLETION, json_result,
response_objects, ids);
});
futures.push_back(std::move(task));
}
for (auto & future : futures) {
future.wait();
}
}
}
}
}
int main(int argc, char ** argv) {
// own arguments required by this example
common_params params;
std::mutex input_buffer_mutex; // mutex for input buffer
std::mutex output_buffer_mutex; // mutex for output buffer
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
return 1;
}
@@ -2905,9 +3028,13 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }});
};
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok, &output_buffer_mutex](
server_task_inf_type inf_type, json & data, httplib::Response & res,
int id) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
res_error(res,
format_error_response("This server does not support completions. Start it without `--embeddings`",
ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -2915,51 +3042,183 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
const auto task_ids = server_task::get_list_id(tasks);
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
if (results.size() == 1) {
// single result
res_ok(res, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
ctx_server.receive_cmpl_results(
task_ids,
[&](std::vector<server_task_result> & results) {
if (results.size() == 1) {
// single result
res_ok(res, results[0].data);
// mark the request in the output buffer as ready
std::unique_lock<std::mutex> lock(output_buffer_mutex);
(*output_buffer)[id] = true;
lock.unlock();
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
}
res_ok(res, arr);
// mark the request in the output buffer as ready
std::unique_lock<std::mutex> lock(output_buffer_mutex);
(*output_buffer)[id] = true;
lock.unlock();
}
res_ok(res, arr);
}
}, [&](const json & error_data) {
res_error(res, error_data);
});
},
[&](const json & error_data) { res_error(res, error_data); });
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
return server_sent_event(sink, "data", result.data);
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
ctx_server.receive_cmpl_results_stream(
task_ids,
[&](const server_task_result & result) -> bool {
return server_sent_event(sink, "data", result.data);
},
[&](const json & error_data) { server_sent_event(sink, "error", error_data); });
sink.done();
return false;
};
auto on_complete = [task_ids, &ctx_server](bool) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
std::unique_lock<std::mutex> lock(output_buffer_mutex);
(*output_buffer)[id] = true;
lock.unlock();
}
};
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
};
const auto handle_completions_generic_aggregate =
[&ctx_server, &res_error, &res_ok, &output_buffer_mutex](
server_task_inf_type inf_type, json & data,
const std::vector<std::reference_wrapper<httplib::Response>> & res, std::vector<int> ids) {
if (ctx_server.params_base.embedding) {
res_error(res[0].get(), format_error_response(
"This server does not support completions. Start it without `--embeddings`",
ERROR_TYPE_NOT_SUPPORTED));
std::unique_lock<std::mutex> lock(output_buffer_mutex);
(*output_buffer)[ids[0]] = true;
lock.unlock();
return;
}
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
const auto task_ids = server_task::get_list_id(tasks);
if (!stream) {
ctx_server.receive_cmpl_results(
task_ids,
[&](std::vector<server_task_result> & results) {
for (size_t i = 0; i < results.size(); i++) {
res_ok(res[i].get(), results[i].data);
// mark the request in the output buffer as ready
std::unique_lock<std::mutex> lock(output_buffer_mutex);
(*output_buffer)[ids[i]] = true;
lock.unlock();
}
},
[&](const json & error_data) {
res_error(res[0].get(), error_data);
// mark the request in the output buffer as ready
std::unique_lock<std::mutex> lock(output_buffer_mutex);
(*output_buffer)[ids[0]] = true;
lock.unlock();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(
task_ids,
[&](const server_task_result & result) -> bool {
return server_sent_event(sink, "data", result.data);
},
[&](const json & error_data) { server_sent_event(sink, "error", error_data); });
sink.done();
return false;
};
auto on_complete = [task_ids, &ctx_server](bool) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
};
res[0].get().set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
std::unique_lock<std::mutex> lock(output_buffer_mutex);
(*output_buffer)[ids[0]] = true;
lock.unlock();
}
};
const auto handle_completions = [&params, &handle_completions_generic, &handle_completions_generic_aggregate,
&input_buffer_mutex,
&output_buffer_mutex](const httplib::Request & req, httplib::Response & res) {
std::chrono::time_point<std::chrono::high_resolution_clock> start = std::chrono::high_resolution_clock::now();
json data = json::parse(req.body);
if (!params.aggregate) {
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, -1);
}
// atomically assign a unique id to this request (a separate increment and load could race with other requests)
const int id = ++request_id;
input_buffer_element new_element = { data, std::ref(res), id };
// mark the request in the output buffer as not ready
std::unique_lock<std::mutex> lock2(output_buffer_mutex);
(*output_buffer)[id] = false;
lock2.unlock();
// add the request to the input buffer
std::unique_lock<std::mutex> lock(input_buffer_mutex);
input_buffer->emplace_back(new_element);
lock.unlock();
while (true) {
std::unique_lock<std::mutex> lock3(output_buffer_mutex); // the worker threads write this flag, so read it under the lock
bool ready = (*output_buffer)[id];
lock3.unlock();
if (ready) {
std::chrono::time_point<std::chrono::high_resolution_clock> end =
std::chrono::high_resolution_clock::now();
std::chrono::duration<long long, std::micro> duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
nlohmann::json json_data = nlohmann::json::parse(res.body);
// for performance metrics, can be omitted
json_data["duration"] = duration.count();
// Serialize the updated JSON object back to a string
std::string updated_content = json_data.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace);
// Update the response content with the defined MIME type
res.set_content(updated_content, MIMETYPE_JSON);
return;
}
// sleep to conserve CPU instead of busy-waiting
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err;
@@ -3007,7 +3266,7 @@ int main(int argc, char ** argv) {
}
data["input_extra"] = input_extra; // default to empty array if it's not exist
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res, -1);
};
// TODO: maybe merge this function with "handle_completions_generic"
@@ -3438,7 +3697,25 @@ int main(int argc, char ** argv) {
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
ctx_server.queue_tasks.start_loop();
if (params.aggregate) {
std::thread batch_processing_thread(
process_aggregate_prompts,
std::ref(handle_completions_generic_aggregate),
std::ref(handle_completions_generic),
std::ref(input_buffer_mutex),
std::ref(output_buffer_mutex),
params.buffer_size,
params.block_size,
params.n_predict
);
ctx_server.queue_tasks.start_loop();
batch_processing_thread.join();
}
else {
ctx_server.queue_tasks.start_loop();
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;