server : Smart selection of available slot using Longest Common Substring
This commit is contained in:
parent
bde7cd3cd9
commit
1ecb6a6999
2 changed files with 177 additions and 17 deletions
|
@ -144,6 +144,7 @@ struct server_params {
|
|||
bool slots_endpoint = true;
|
||||
bool metrics_endpoint = false;
|
||||
std::string slot_save_path;
|
||||
float lcs_similarity = 0.0f;
|
||||
};
|
||||
|
||||
struct server_slot {
|
||||
|
@ -670,6 +671,9 @@ struct server_context {
|
|||
|
||||
server_metrics metrics;
|
||||
|
||||
// Longest Common Substring similarity for slot selection
|
||||
float lcs_similarity = 0.0f;
|
||||
|
||||
~server_context() {
|
||||
if (ctx) {
|
||||
llama_free(ctx);
|
||||
|
@ -818,24 +822,88 @@ struct server_context {
|
|||
return prompt_tokens;
|
||||
}
|
||||
|
||||
server_slot * get_slot(int id) {
|
||||
int64_t t_last = ggml_time_us();
|
||||
|
||||
server_slot * last_used = nullptr;
|
||||
|
||||
server_slot * get_slot_by_id(int id) {
|
||||
for (server_slot & slot : slots) {
|
||||
if (slot.id == id && slot.available()) {
|
||||
if (slot.id == id) {
|
||||
return &slot;
|
||||
}
|
||||
|
||||
// among all available slots, find the one that has been least recently used
|
||||
if (slot.available() && slot.t_last_used < t_last) {
|
||||
last_used = &slot;
|
||||
t_last = slot.t_last_used;
|
||||
}
|
||||
}
|
||||
|
||||
return last_used;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
server_slot * get_available_slot(const std::string & prompt) {
|
||||
server_slot * ret = nullptr;
|
||||
|
||||
// find the slot that has at least n% prompt similarity
|
||||
if (ret == nullptr && lcs_similarity != 0.0f && !prompt.empty()) {
|
||||
int max_lcs_len = 0;
|
||||
float similarity = 0;
|
||||
|
||||
for (server_slot & slot : slots) {
|
||||
// skip the slot if it is not available
|
||||
if (!slot.available()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// skip the slot if it does not contains prompt
|
||||
if (!slot.prompt.is_string()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// current slot's prompt
|
||||
std::string slot_prompt = slot.prompt.get<std::string>();
|
||||
|
||||
// length of the current slot's prompt
|
||||
int slot_prompt_len = slot_prompt.size();
|
||||
|
||||
// length of the longest common substring between the current slot's prompt and the input prompt
|
||||
int lcs_len = lcs_length(slot_prompt, prompt);
|
||||
|
||||
// fraction of the common substring length compared to the current slot's prompt length
|
||||
similarity = static_cast<float>(lcs_len) / slot_prompt_len;
|
||||
|
||||
// select the current slot if the criteria match
|
||||
if (lcs_len > max_lcs_len && similarity > lcs_similarity) {
|
||||
max_lcs_len = lcs_len;
|
||||
ret = &slot;
|
||||
}
|
||||
}
|
||||
|
||||
if (ret != nullptr) {
|
||||
LOG_VERBOSE("selected slot by lcs similarity", {
|
||||
{"id_slot", ret->id},
|
||||
{"max_lcs_len", max_lcs_len},
|
||||
{"similarity", similarity},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// find the slot that has been least recently used
|
||||
if (ret == nullptr) {
|
||||
int64_t t_last = ggml_time_us();
|
||||
for (server_slot & slot : slots) {
|
||||
// skip the slot if it is not available
|
||||
if (!slot.available()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// select the current slot if the criteria match
|
||||
if (slot.t_last_used < t_last) {
|
||||
t_last = slot.t_last_used;
|
||||
ret = &slot;
|
||||
}
|
||||
}
|
||||
|
||||
if (ret != nullptr) {
|
||||
LOG_VERBOSE("selected slot by lru", {
|
||||
{"id_slot", ret->id},
|
||||
{"t_last", t_last},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
||||
|
@ -1538,13 +1606,29 @@ struct server_context {
|
|||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
{
|
||||
server_slot * slot = get_slot(json_value(task.data, "id_slot", -1));
|
||||
int id_slot = json_value(task.data, "id_slot", -1);
|
||||
std::string prompt = json_value(task.data, "prompt", std::string());
|
||||
|
||||
server_slot * slot;
|
||||
|
||||
if (id_slot != -1) {
|
||||
slot = get_slot_by_id(id_slot);
|
||||
} else {
|
||||
slot = get_available_slot(prompt);
|
||||
}
|
||||
|
||||
if (slot == nullptr) {
|
||||
// if no slot is available, we defer this task for processing later
|
||||
LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
|
||||
queue_tasks.defer(task);
|
||||
break;
|
||||
}
|
||||
if (!slot->available()) {
|
||||
// if requested slot is unavailable, we defer this task for processing later
|
||||
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
||||
queue_tasks.defer(task);
|
||||
break;
|
||||
}
|
||||
|
||||
if (task.data.contains("system_prompt")) {
|
||||
std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
|
||||
|
@ -1661,11 +1745,17 @@ struct server_context {
|
|||
case SERVER_TASK_TYPE_SLOT_SAVE:
|
||||
{
|
||||
int id_slot = task.data.at("id_slot");
|
||||
server_slot * slot = get_slot(id_slot);
|
||||
server_slot * slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||
break;
|
||||
}
|
||||
if (!slot->available()) {
|
||||
// if requested slot is unavailable, we defer this task for processing later
|
||||
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
||||
queue_tasks.defer(task);
|
||||
break;
|
||||
}
|
||||
|
||||
const size_t token_count = slot->cache_tokens.size();
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
@ -1696,11 +1786,17 @@ struct server_context {
|
|||
case SERVER_TASK_TYPE_SLOT_RESTORE:
|
||||
{
|
||||
int id_slot = task.data.at("id_slot");
|
||||
server_slot * slot = get_slot(id_slot);
|
||||
server_slot * slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||
break;
|
||||
}
|
||||
if (!slot->available()) {
|
||||
// if requested slot is unavailable, we defer this task for processing later
|
||||
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
||||
queue_tasks.defer(task);
|
||||
break;
|
||||
}
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
|
@ -1738,11 +1834,17 @@ struct server_context {
|
|||
case SERVER_TASK_TYPE_SLOT_ERASE:
|
||||
{
|
||||
int id_slot = task.data.at("id_slot");
|
||||
server_slot * slot = get_slot(id_slot);
|
||||
server_slot * slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||
break;
|
||||
}
|
||||
if (!slot->available()) {
|
||||
// if requested slot is unavailable, we defer this task for processing later
|
||||
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
||||
queue_tasks.defer(task);
|
||||
break;
|
||||
}
|
||||
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->cache_tokens.size();
|
||||
|
@ -2868,6 +2970,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
|||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
} else if (arg == "--lcs-similarity") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.lcs_similarity = std::stof(argv[i]);
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
server_print_usage(argv[0], default_params, default_sparams);
|
||||
|
@ -3039,6 +3147,9 @@ int main(int argc, char ** argv) {
|
|||
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
|
||||
}
|
||||
|
||||
// Longest Common Substring similarity for slot selection
|
||||
ctx_server.lcs_similarity = sparams.lcs_similarity;
|
||||
|
||||
// load the model
|
||||
if (!ctx_server.load_model(params)) {
|
||||
state.store(SERVER_STATE_ERROR);
|
||||
|
|
|
@ -653,3 +653,52 @@ static json format_error_response(const std::string & message, const enum error_
|
|||
{"type", type_str},
|
||||
};
|
||||
}
|
||||
|
||||
static int lcs_length(const std::string & str1, const std::string & str2) {
|
||||
// check for empty strings
|
||||
if (str1.empty() || str2.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// get the lengths of the input strings
|
||||
int str1_len = str1.size();
|
||||
int str2_len = str2.size();
|
||||
|
||||
// initialize the maximum length of the longest common subsequence (LCS)
|
||||
int max_length = 0;
|
||||
|
||||
// use two rows instead of a 2D matrix to optimize space
|
||||
std::vector<int> prev_row(str2_len + 1, 0);
|
||||
std::vector<int> curr_row(str2_len + 1, 0);
|
||||
|
||||
// iterate through the characters of str1
|
||||
for (int i = 1; i <= str1_len; i++) {
|
||||
// iterate through the characters of str2
|
||||
for (int j = 1; j <= str2_len; j++) {
|
||||
// if characters at the current positions match
|
||||
if (str1[i - 1] == str2[j - 1]) {
|
||||
// if it's the first character of either string, set LCS length to 1
|
||||
if (i == 1 || j == 1) {
|
||||
curr_row[j] = 1;
|
||||
} else {
|
||||
// increment LCS length by 1 compared to the previous character
|
||||
curr_row[j] = prev_row[j - 1] + 1;
|
||||
}
|
||||
|
||||
// update max_length if necessary
|
||||
if (curr_row[j] > max_length) {
|
||||
max_length = curr_row[j];
|
||||
}
|
||||
} else {
|
||||
// reset LCS length if characters don't match
|
||||
curr_row[j] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// update the previous row for the next iteration
|
||||
prev_row = curr_row;
|
||||
}
|
||||
|
||||
// return the maximum length of the LCS
|
||||
return max_length;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue