Use Longest Common Prefix (LCP) instead of LCS
This commit is contained in:
parent
f1164112de
commit
36083dca2c
4 changed files with 25 additions and 67 deletions
|
@ -1460,12 +1460,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
params.chat_template = argv[i];
|
params.chat_template = argv[i];
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (arg == "--lcs-similarity") {
|
if (arg == "--lcp-similarity") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
params.lcs_similarity = std::stof(argv[i]);
|
params.lcp_similarity = std::stof(argv[i]);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (arg == "-pps") {
|
if (arg == "-pps") {
|
||||||
|
@ -1839,8 +1839,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||||
"set custom jinja chat template (default: template taken from model's metadata)\n"
|
"set custom jinja chat template (default: template taken from model's metadata)\n"
|
||||||
"only commonly used templates are accepted:\n"
|
"only commonly used templates are accepted:\n"
|
||||||
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
|
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
|
||||||
options.push_back({ "server", " --lcs-similarity SIMILARITY",
|
options.push_back({ "server", " --lcp-similarity SIMILARITY",
|
||||||
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcs_similarity });
|
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcp_similarity });
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
options.push_back({ "logging" });
|
options.push_back({ "logging" });
|
||||||
|
|
|
@ -202,7 +202,7 @@ struct gpt_params {
|
||||||
|
|
||||||
std::string slot_save_path;
|
std::string slot_save_path;
|
||||||
|
|
||||||
float lcs_similarity = 0.0f;
|
float lcp_similarity = 0.0f;
|
||||||
|
|
||||||
// batched-bench params
|
// batched-bench params
|
||||||
bool is_pp_shared = false;
|
bool is_pp_shared = false;
|
||||||
|
|
|
@ -647,8 +647,8 @@ struct server_context {
|
||||||
|
|
||||||
server_metrics metrics;
|
server_metrics metrics;
|
||||||
|
|
||||||
// Longest Common Substring similarity for slot selection
|
// Longest Common Prefix similarity for slot selection
|
||||||
float lcs_similarity = 0.0f;
|
float lcp_similarity = 0.0f;
|
||||||
|
|
||||||
~server_context() {
|
~server_context() {
|
||||||
if (ctx) {
|
if (ctx) {
|
||||||
|
@ -812,8 +812,8 @@ struct server_context {
|
||||||
server_slot * ret = nullptr;
|
server_slot * ret = nullptr;
|
||||||
|
|
||||||
// find the slot that has at least n% prompt similarity
|
// find the slot that has at least n% prompt similarity
|
||||||
if (ret == nullptr && lcs_similarity != 0.0f && !prompt.empty()) {
|
if (ret == nullptr && lcp_similarity != 0.0f && !prompt.empty()) {
|
||||||
int max_lcs_len = 0;
|
int max_lcp_len = 0;
|
||||||
float similarity = 0;
|
float similarity = 0;
|
||||||
|
|
||||||
for (server_slot & slot : slots) {
|
for (server_slot & slot : slots) {
|
||||||
|
@ -833,23 +833,23 @@ struct server_context {
|
||||||
// length of the current slot's prompt
|
// length of the current slot's prompt
|
||||||
int slot_prompt_len = slot_prompt.size();
|
int slot_prompt_len = slot_prompt.size();
|
||||||
|
|
||||||
// length of the longest common substring between the current slot's prompt and the input prompt
|
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||||
int lcs_len = lcs_length(slot_prompt, prompt);
|
int lcp_len = common_part(slot_prompt, prompt);
|
||||||
|
|
||||||
// fraction of the common substring length compared to the current slot's prompt length
|
// fraction of the common prefix length compared to the current slot's prompt length
|
||||||
similarity = static_cast<float>(lcs_len) / slot_prompt_len;
|
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
||||||
|
|
||||||
// select the current slot if the criteria match
|
// select the current slot if the criteria match
|
||||||
if (lcs_len > max_lcs_len && similarity > lcs_similarity) {
|
if (lcp_len > max_lcp_len && similarity > lcp_similarity) {
|
||||||
max_lcs_len = lcs_len;
|
max_lcp_len = lcp_len;
|
||||||
ret = &slot;
|
ret = &slot;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ret != nullptr) {
|
if (ret != nullptr) {
|
||||||
LOG_VERBOSE("selected slot by lcs similarity", {
|
LOG_VERBOSE("selected slot by lcp similarity", {
|
||||||
{"id_slot", ret->id},
|
{"id_slot", ret->id},
|
||||||
{"max_lcs_len", max_lcs_len},
|
{"max_lcp_len", max_lcp_len},
|
||||||
{"similarity", similarity},
|
{"similarity", similarity},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -2568,8 +2568,8 @@ int main(int argc, char ** argv) {
|
||||||
log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
|
log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Longest Common Substring similarity for slot selection
|
// Longest Common Prefix similarity for slot selection
|
||||||
ctx_server.lcs_similarity = params.lcs_similarity;
|
ctx_server.lcp_similarity = params.lcp_similarity;
|
||||||
|
|
||||||
// load the model
|
// load the model
|
||||||
if (!ctx_server.load_model(params)) {
|
if (!ctx_server.load_model(params)) {
|
||||||
|
|
|
@ -253,6 +253,13 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Length of the longest common prefix of two strings: the number of leading
// positions at which `a` and `b` hold the same character. The result never
// exceeds the size of the shorter string, and is 0 when either one is empty.
static size_t common_part(const std::string & a, const std::string & b) {
    // comparison can never run past the end of the shorter string
    const size_t limit = a.size() < b.size() ? a.size() : b.size();

    size_t n_common = 0;
    while (n_common < limit && a[n_common] == b[n_common]) {
        ++n_common;
    }

    return n_common;
}
|
||||||
|
|
||||||
// Returns true when `str` terminates with `suffix`. An empty suffix matches
// every string; a suffix longer than `str` never matches.
static bool ends_with(const std::string & str, const std::string & suffix) {
    // a suffix longer than the string cannot match; checking first also keeps
    // the start offset computed below from underflowing
    if (suffix.size() > str.size()) {
        return false;
    }

    const size_t tail_pos = str.size() - suffix.size();
    return str.compare(tail_pos, suffix.size(), suffix) == 0;
}
|
||||||
|
@ -646,52 +653,3 @@ static json format_error_response(const std::string & message, const enum error_
|
||||||
{"type", type_str},
|
{"type", type_str},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the length of the longest common *substring* of str1 and str2, i.e.
// the longest contiguous run of characters that appears in both. This is not
// the longest common subsequence: a mismatch resets the running length to 0.
// Returns 0 when either string is empty or when no character is shared.
//
// Dynamic programming: cell (i, j) holds the length of the common run that
// ends at str1[i - 1] and str2[j - 1]. Only the previous row is ever read, so
// two rows are kept instead of a full 2D matrix.
static int lcs_length(const std::string & str1, const std::string & str2) {
    // nothing can be shared with an empty string
    if (str1.empty() || str2.empty()) {
        return 0;
    }

    const int str1_len = static_cast<int>(str1.size());
    const int str2_len = static_cast<int>(str2.size());

    // longest common run seen so far
    int max_length = 0;

    // index 0 of each row is a sentinel that is never written and stays 0, and
    // prev_row starts out all zero; together these cover the "first character
    // of either string" case without a dedicated branch
    std::vector<int> prev_row(str2_len + 1, 0);
    std::vector<int> curr_row(str2_len + 1, 0);

    for (int i = 1; i <= str1_len; i++) {
        for (int j = 1; j <= str2_len; j++) {
            if (str1[i - 1] == str2[j - 1]) {
                // extend the run that ended on the previous character pair
                curr_row[j] = prev_row[j - 1] + 1;

                if (curr_row[j] > max_length) {
                    max_length = curr_row[j];
                }
            } else {
                // a mismatch breaks the contiguous run
                curr_row[j] = 0;
            }
        }

        // the finished row becomes the previous row; the stale contents left in
        // curr_row are harmless because every cell at j >= 1 is overwritten in
        // the next pass (O(1) swap instead of copying the whole row)
        prev_row.swap(curr_row);
    }

    return max_length;
}
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue