Merge remote-tracking branch 'origin/master' into bins

This commit is contained in:
Olivier Chafik 2024-06-08 12:04:52 +01:00
commit fe93cc96cc
11 changed files with 219 additions and 95 deletions

View file

@ -279,7 +279,7 @@ node index.js
`id_slot`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot. Default: `-1`
`cache_prompt`: Re-use previously cached prompt from the last request if possible. This may prevent re-caching the prompt from scratch. Default: `false`
`cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `false`
`system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)

View file

@ -647,6 +647,9 @@ struct server_context {
server_metrics metrics;
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
~server_context() {
if (ctx) {
llama_free(ctx);
@ -795,24 +798,88 @@ struct server_context {
return prompt_tokens;
}
server_slot * get_slot(int id) {
int64_t t_last = ggml_time_us();
server_slot * last_used = nullptr;
server_slot * get_slot_by_id(int id) {
for (server_slot & slot : slots) {
if (slot.id == id && slot.available()) {
if (slot.id == id) {
return &slot;
}
// among all available slots, find the one that has been least recently used
if (slot.available() && slot.t_last_used < t_last) {
last_used = &slot;
t_last = slot.t_last_used;
}
}
return last_used;
return nullptr;
}
server_slot * get_available_slot(const std::string & prompt) {
server_slot * ret = nullptr;
// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
int max_lcp_len = 0;
float similarity = 0;
for (server_slot & slot : slots) {
// skip the slot if it is not available
if (!slot.available()) {
continue;
}
// skip the slot if it does not contains prompt
if (!slot.prompt.is_string()) {
continue;
}
// current slot's prompt
std::string slot_prompt = slot.prompt.get<std::string>();
// length of the current slot's prompt
int slot_prompt_len = slot_prompt.size();
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
int lcp_len = common_part(slot_prompt, prompt);
// fraction of the common substring length compared to the current slot's prompt length
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
// select the current slot if the criteria match
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
max_lcp_len = lcp_len;
ret = &slot;
}
}
if (ret != nullptr) {
LOG_VERBOSE("selected slot by lcp similarity", {
{"id_slot", ret->id},
{"max_lcp_len", max_lcp_len},
{"similarity", similarity},
});
}
}
// find the slot that has been least recently used
if (ret == nullptr) {
int64_t t_last = ggml_time_us();
for (server_slot & slot : slots) {
// skip the slot if it is not available
if (!slot.available()) {
continue;
}
// select the current slot if the criteria match
if (slot.t_last_used < t_last) {
t_last = slot.t_last_used;
ret = &slot;
}
}
if (ret != nullptr) {
LOG_VERBOSE("selected slot by lru", {
{"id_slot", ret->id},
{"t_last", t_last},
});
}
}
return ret;
}
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
@ -888,7 +955,7 @@ struct server_context {
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
// get prompt
{
if (!task.infill) {
const auto & prompt = data.find("prompt");
if (prompt == data.end()) {
send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@ -1515,13 +1582,29 @@ struct server_context {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
{
server_slot * slot = get_slot(json_value(task.data, "id_slot", -1));
int id_slot = json_value(task.data, "id_slot", -1);
std::string prompt = json_value(task.data, "prompt", std::string());
server_slot * slot;
if (id_slot != -1) {
slot = get_slot_by_id(id_slot);
} else {
slot = get_available_slot(prompt);
}
if (slot == nullptr) {
// if no slot is available, we defer this task for processing later
LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
queue_tasks.defer(task);
break;
}
if (!slot->available()) {
// if requested slot is unavailable, we defer this task for processing later
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
queue_tasks.defer(task);
break;
}
if (task.data.contains("system_prompt")) {
std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
@ -1638,11 +1721,17 @@ struct server_context {
case SERVER_TASK_TYPE_SLOT_SAVE:
{
int id_slot = task.data.at("id_slot");
server_slot * slot = get_slot(id_slot);
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
if (!slot->available()) {
// if requested slot is unavailable, we defer this task for processing later
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
queue_tasks.defer(task);
break;
}
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();
@ -1673,11 +1762,17 @@ struct server_context {
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
int id_slot = task.data.at("id_slot");
server_slot * slot = get_slot(id_slot);
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
if (!slot->available()) {
// if requested slot is unavailable, we defer this task for processing later
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
queue_tasks.defer(task);
break;
}
const int64_t t_start = ggml_time_us();
@ -1715,11 +1810,17 @@ struct server_context {
case SERVER_TASK_TYPE_SLOT_ERASE:
{
int id_slot = task.data.at("id_slot");
server_slot * slot = get_slot(id_slot);
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
if (!slot->available()) {
// if requested slot is unavailable, we defer this task for processing later
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
queue_tasks.defer(task);
break;
}
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
@ -2467,6 +2568,9 @@ int main(int argc, char ** argv) {
log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
}
// Necessary similarity of prompt for slot selection
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);

View file

@ -253,6 +253,13 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
return i;
}
static size_t common_part(const std::string & a, const std::string & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}