From 36083dca2c92c92c2d36acaaf766872d0e37730c Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Fri, 7 Jun 2024 14:09:11 +0000 Subject: [PATCH] Use Longest Common Prefix (LCP) instead of LCS --- common/common.cpp | 8 +++--- common/common.h | 2 +- examples/server/server.cpp | 26 +++++++++--------- examples/server/utils.hpp | 56 +++++--------------------------------- 4 files changed, 25 insertions(+), 67 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 42c594bbd..65448c918 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1460,12 +1460,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.chat_template = argv[i]; return true; } - if (arg == "--lcs-similarity") { + if (arg == "--lcp-similarity") { if (++i >= argc) { invalid_param = true; return true; } - params.lcs_similarity = std::stof(argv[i]); + params.lcp_similarity = std::stof(argv[i]); return true; } if (arg == "-pps") { @@ -1839,8 +1839,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "set custom jinja chat template (default: template taken from model's metadata)\n" "only commonly used templates are accepted:\n" "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); - options.push_back({ "server", " --lcs-similarity SIMILARITY", - "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcs_similarity }); + options.push_back({ "server", " --lcp-similarity SIMILARITY", + "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcp_similarity }); #ifndef LOG_DISABLE_LOGS options.push_back({ "logging" }); diff --git a/common/common.h b/common/common.h index 0c9c592a4..0a8a9c073 100644 --- a/common/common.h +++ b/common/common.h @@ -202,7 +202,7 @@ struct gpt_params { std::string slot_save_path; - float lcs_similarity = 0.0f; + float lcp_similarity = 0.0f; // batched-bench params bool is_pp_shared = false; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b90a0b8f3..802c660c7 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -647,8 +647,8 @@ struct server_context { server_metrics metrics; - // Longest Common Substring similarity for slot selection - float lcs_similarity = 0.0f; + // Longest Common Prefix similarity for slot selection + float lcp_similarity = 0.0f; ~server_context() { if (ctx) { @@ -812,8 +812,8 @@ struct server_context { server_slot * ret = nullptr; // find the slot that has at least n% prompt similarity - if (ret == nullptr && lcs_similarity != 0.0f && !prompt.empty()) { - int max_lcs_len = 0; + if (ret == nullptr && lcp_similarity != 0.0f && !prompt.empty()) { + int max_lcp_len = 0; float similarity = 0; for (server_slot & slot : slots) { @@ -833,23 +833,23 @@ struct server_context { // length of the current slot's prompt int slot_prompt_len = slot_prompt.size(); - // length of the longest common substring between the current slot's prompt and the input prompt - int lcs_len = lcs_length(slot_prompt, prompt); + // length of the Longest Common Prefix between the current slot's prompt and the input prompt + int lcp_len = common_part(slot_prompt, prompt); // fraction of the common substring length compared to the current slot's prompt length - similarity = static_cast(lcs_len) / slot_prompt_len; + similarity = static_cast(lcp_len) / slot_prompt_len; // select the current slot if the criteria match - if (lcs_len > max_lcs_len && similarity > lcs_similarity) { - max_lcs_len = lcs_len; + if (lcp_len > max_lcp_len && similarity > lcp_similarity) { + max_lcp_len = lcp_len; ret = &slot; } } if (ret != nullptr) { - LOG_VERBOSE("selected slot by lcs similarity", { + LOG_VERBOSE("selected slot by lcp similarity", { {"id_slot", ret->id}, - {"max_lcs_len", max_lcs_len}, + {"max_lcp_len", max_lcp_len}, {"similarity", similarity}, }); } @@ -2568,8 +2568,8 @@ int main(int argc, char ** argv) { log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; } - // Longest Common Substring similarity for slot selection - ctx_server.lcs_similarity = params.lcs_similarity; + // Longest Common Prefix similarity for slot selection + ctx_server.lcp_similarity = params.lcp_similarity; // load the model if (!ctx_server.load_model(params)) { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 904f5e3c0..63fde9c9f 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -253,6 +253,13 @@ static size_t common_part(const std::vector & a, const std::vector< return i; } +static size_t common_part(const std::string & a, const std::string & b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + + return i; +} + static bool ends_with(const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } @@ -646,52 +653,3 @@ static json format_error_response(const std::string & message, const enum error_ {"type", type_str}, }; } - -static int lcs_length(const std::string & str1, const std::string & str2) { - // check for empty strings - if (str1.empty() || str2.empty()) { - return 0; - } - - // get the lengths of the input strings - int str1_len = str1.size(); - int str2_len = str2.size(); - - // initialize the maximum length of the longest common subsequence (LCS) - int max_length = 0; - - // use two rows instead of a 2D matrix to optimize space - std::vector prev_row(str2_len + 1, 0); - std::vector curr_row(str2_len + 1, 0); - - // iterate through the characters of str1 - for (int i = 1; i <= str1_len; i++) { - // iterate through the characters of str2 - for (int j = 1; j <= str2_len; j++) { - // if characters at the current positions match - if (str1[i - 1] == str2[j - 1]) { - // if it's the first character of either string, set LCS length to 1 - if (i == 1 || j == 1) { - curr_row[j] = 1; - } else { - // increment LCS length by 1 compared to the previous character - curr_row[j] = prev_row[j - 1] + 1; - } - - // update max_length if necessary - if (curr_row[j] > max_length) { - max_length = curr_row[j]; - } - } else { - // reset LCS length if characters don't match - curr_row[j] = 0; - } - } - - // update the previous row for the next iteration - prev_row = curr_row; - } - - // return the maximum length of the LCS - return max_length; -}