diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 641418405..b82784510 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -148,8 +148,9 @@ struct server_slot {
     int32_t n_prompt_tokens_processed = 0;
 
     json prompt;
+    std::vector<llama_token> prompt_tokens;
+
     std::string generated_text;
-    llama_token sampled;
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
@@ -167,6 +168,7 @@ struct server_slot {
     std::string stopping_word;
 
     // sampling
+    llama_token sampled;
     struct llama_sampling_params sparams;
     llama_sampling_context * ctx_sampling = nullptr;
 
@@ -181,7 +183,7 @@ struct server_slot {
     size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
-    int64_t t_start_genereration;
+    int64_t t_start_generation;
 
     double t_prompt_processing; // ms
     double t_token_generation; // ms
@@ -232,13 +234,12 @@ struct server_slot {
         if (command == RELEASE) {
             return;
         }
-        cache_tokens.push_back(token.tok);
         generated_token_probs.push_back(token);
     }
 
     void release() {
         if (state == PROCESSING) {
-            t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
+            t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
             command = RELEASE;
         }
     }
@@ -968,6 +969,7 @@ struct llama_server_context {
         }
 
         slot.command = LOAD_PROMPT;
+        slot.prompt_tokens.clear();
 
         LOG_INFO("slot is processing task", {
             {"id_slot", slot.id},
@@ -1426,9 +1428,7 @@ struct llama_server_context {
                     if (task.data.contains("system_prompt")) {
                         system_prompt_set(task.data["system_prompt"]);
 
-                        // reset cache_tokens for all slots
                         for (server_slot & slot : slots) {
-                            slot.cache_tokens.clear();
                             slot.n_past    = 0;
                             slot.n_past_se = 0;
                         }
@@ -1614,7 +1614,7 @@ struct llama_server_context {
 
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
-                if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) {
+                if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
                     // Shift context
                     const int n_keep    = slot.params.n_keep + add_bos_token;
                     const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -1635,11 +1635,13 @@
                     llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
                     llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
-                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
-                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                    }
+                    if (slot.params.cache_prompt) {
+                        for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+                            slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+                        }
 
-                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                        slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                    }
 
                     slot.n_past -= n_discard;
 
@@ -1663,8 +1665,13 @@ struct llama_server_context {
             // TODO: we always have to take into account the "system_tokens"
             //       this is not great and needs to be improved somehow
             llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+
             slot.n_past += 1;
 
+            if (slot.params.cache_prompt) {
+                slot.cache_tokens.push_back(slot.sampled);
+            }
+
             LOG_VERBOSE("slot decode token", {
                 {"id_slot", slot.id},
                 {"id_task", slot.id_task},
@@ -1681,6 +1688,8 @@
 
         // assign workload to the slots
         if (params.cont_batching || batch.n_tokens == 0) {
+            int n_available = n_batch;
+
             for (auto & slot : slots) {
                 const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
 
@@ -1697,122 +1706,132 @@ struct llama_server_context {
                 // need process the prompt
                 if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
-                    slot.state = PROCESSING;
-                    slot.command = NONE;
+                    auto & prompt_tokens = slot.prompt_tokens;
 
-                    std::vector<llama_token> prompt_tokens;
-                    slot.t_start_process_prompt = ggml_time_us();
-                    slot.t_start_genereration = 0;
-
-                    if (slot.infill) {
-                        bool suff_rm_leading_spc = true;
-                        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
-                            params.input_suffix.erase(0, 1);
-                            suff_rm_leading_spc = false;
-                        }
-
-                        auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-                        auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-                        const int space_token = 29871; // TODO: this should not be hardcoded
-                        if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
-                            suffix_tokens.erase(suffix_tokens.begin());
-                        }
-
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
-                        prefix_tokens.insert(prefix_tokens.end(),   llama_token_suffix(model));
-                        prefix_tokens.insert(prefix_tokens.end(),   suffix_tokens.begin(), suffix_tokens.end());
-                        prefix_tokens.push_back(llama_token_middle(model));
-                        prompt_tokens = prefix_tokens;
-                    } else {
-                        prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token);  // add BOS if there isn't system prompt
-                    }
-
-                    slot.n_prompt_tokens = prompt_tokens.size();
-
-                    if (slot.params.n_keep < 0) {
-                        slot.params.n_keep = slot.n_prompt_tokens;
-                    }
-                    slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-                    // if input prompt is too big, truncate it, if group attention self-extend is disabled
-                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
-                        const int n_left = slot.n_ctx - slot.params.n_keep;
-                        const int n_block_size = n_left / 2;
-                        const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-                        std::vector<llama_token> new_tokens(
-                                prompt_tokens.begin(),
-                                prompt_tokens.begin() + slot.params.n_keep);
-
-                        new_tokens.insert(
-                                new_tokens.end(),
-                                prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
-                                prompt_tokens.end());
-
-                        LOG_VERBOSE("input truncated", {
-                            {"n_ctx",      slot.n_ctx},
-                            {"n_keep",     slot.params.n_keep},
-                            {"n_left",     n_left},
-                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                    if (prompt_tokens.empty()) {
+                        LOG_VERBOSE("tokenizing prompt", {
+                            {"id_slot", slot.id},
+                            {"id_task", slot.id_task}
                         });
 
-                        slot.truncated = true;
-                        prompt_tokens = new_tokens;
+                        slot.t_start_process_prompt = ggml_time_us();
+                        slot.t_start_generation = 0;
 
+                        if (slot.infill) {
+                            bool suff_rm_leading_spc = true;
+                            if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
+                                params.input_suffix.erase(0, 1);
+                                suff_rm_leading_spc = false;
+                            }
+
+                            auto prefix_tokens = tokenize(slot.params.input_prefix, false);
+                            auto suffix_tokens = tokenize(slot.params.input_suffix, false);
+
+                            const int space_token = 29871; // TODO: this should not be hardcoded
+                            if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
+                                suffix_tokens.erase(suffix_tokens.begin());
+                            }
+
+                            prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
+                            prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
+                            prefix_tokens.insert(prefix_tokens.end(),   llama_token_suffix(model));
+                            prefix_tokens.insert(prefix_tokens.end(),   suffix_tokens.begin(), suffix_tokens.end());
+                            prefix_tokens.push_back(llama_token_middle(model));
+                            prompt_tokens = prefix_tokens;
+                        } else {
+                            prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token);  // add BOS if there isn't system prompt
+                        }
+
+                        slot.n_past = 0;
                         slot.n_prompt_tokens = prompt_tokens.size();
 
-                        GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
-                    }
-
-                    if (!slot.params.cache_prompt) {
-                        llama_sampling_reset(slot.ctx_sampling);
+                        if (slot.embedding) {
+                            // this prompt is too large to process - discard it
+                            if (slot.n_prompt_tokens > n_batch) {
+                                slot.state = PROCESSING;
+                                slot.command = NONE;
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                continue;
+                            }
+                        } else {
+                            if (slot.params.n_keep < 0) {
+                                slot.params.n_keep = slot.n_prompt_tokens;
+                            }
+                            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-                        slot.n_past = 0;
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
+                            // if input prompt is too big, truncate it (if group attention self-extend is disabled)
+                            if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                                const int n_left = slot.n_ctx - slot.params.n_keep;
 
-                        slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
-                    } else {
-                        GGML_ASSERT(slot.ga_n == 1);
+                                const int n_block_size = n_left / 2;
+                                const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
 
-                        // push the prompt into the sampling context (do not apply grammar)
-                        for (auto & token : prompt_tokens) {
-                            llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
+                                std::vector<llama_token> new_tokens(
+                                        prompt_tokens.begin(),
+                                        prompt_tokens.begin() + slot.params.n_keep);
+
+                                new_tokens.insert(
+                                        new_tokens.end(),
+                                        prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                                        prompt_tokens.end());
+
+                                prompt_tokens = std::move(new_tokens);
+
+                                slot.truncated = true;
+                                slot.n_prompt_tokens = prompt_tokens.size();
+
+                                LOG_VERBOSE("input truncated", {
+                                    {"n_ctx",         slot.n_ctx},
+                                    {"n_keep",        slot.params.n_keep},
+                                    {"n_left",        n_left},
+                                    {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                                });
+
+                                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
+                            }
+
+                            llama_sampling_reset(slot.ctx_sampling);
+
+                            if (!slot.params.cache_prompt) {
+                                slot.n_past_se = 0;
+                                slot.ga_i      = 0;
+                            } else {
+                                GGML_ASSERT(slot.ga_n == 1);
+
+                                slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+
+                                // remove the non-common part from the cache
+                                slot.cache_tokens.resize(slot.n_past);
+
+                                // push the prompt into the sampling context (do not apply grammar)
+                                for (int i = 0; i < slot.n_past; ++i) {
+                                    llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+                                }
+                            }
                         }
                     }
 
-                        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                    if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
+                        // we have to evaluate at least 1 token to generate logits.
+ LOG_INFO("we have to evaluate at least 1 token to generate logits", { + { "id_slot", slot.id }, + { "id_task", slot.id_task } + }); - // the last token of the cache is not in the KV cache until the next call to llama_decode - // (it was sampled, pushed into the "cache_tokens", but not yet put in the context) - if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size()) { - slot.n_past -= 1; + slot.n_past--; + if (slot.ga_i > 0) { + slot.n_past_se--; + } } - slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past; + slot.n_prompt_tokens_processed = 0; } - LOG_INFO("slot progression", { - { "id_slot", slot.id }, - { "id_task", slot.id_task }, - { "n_past", slot.n_past }, - { "n_past_se", slot.n_past_se }, - { "ga_i", slot.ga_i }, - { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed } - }); - - slot.cache_tokens = prompt_tokens; - - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { - // we have to evaluate at least 1 token to generate logits. - LOG_INFO("we have to evaluate at least 1 token to generate logits", { - { "id_slot", slot.id }, - { "id_task", slot.id_task } - }); - - slot.n_past--; - if (slot.ga_i > 0) { - slot.n_past_se--; + if (slot.embedding) { + // cannot fit the prompt in the current batch - will try next iter + if (slot.n_prompt_tokens > n_available) { + continue; } } @@ -1825,19 +1844,13 @@ struct llama_server_context { { "p0", p0 } }); - LOG_VERBOSE("prompt ingested", { - {"n_past", slot.n_past}, - {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, - {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, - }); - int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; int32_t ga_i = slot.ga_i; int32_t ga_n = slot.ga_n; int32_t ga_w = slot.ga_w; - for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) { + for (; slot.n_past < slot.n_prompt_tokens && n_available > 0; ++slot.n_past, --n_available) { if (slot.ga_n != 1) { while (slot_npast >= ga_i + ga_w) { const int bd = (ga_w/ga_n)*(ga_n - 1); @@ -1848,16 +1861,45 @@ struct llama_server_context { llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false); + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); + } + + slot.n_prompt_tokens_processed++; slot_npast++; } - // extract the logits only for the last token - if (batch.n_tokens > 0) { - batch.logits[batch.n_tokens - 1] = true; - } + // entire prompt has been processed + if (slot.n_past == slot.n_prompt_tokens) { + slot.state = PROCESSING; + slot.command = NONE; - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + GGML_ASSERT(batch.n_tokens > 0); + + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + LOG_VERBOSE("prompt processed", { + {"id_slot", slot.id}, + {"n_past", slot.n_past}, + {"n_ctx", n_ctx}, + {"n_tokens", batch.n_tokens}, + }); + } else { + LOG_VERBOSE("prompt still processing", { + {"id_slot", slot.id}, + {"n_past", slot.n_past}, + {"n_ctx", n_ctx}, + {"n_tokens", batch.n_tokens}, + }); + } + } + + if (n_available == 0) { + break; } } } @@ -1868,6 +1910,10 @@ struct llama_server_context { return true; } + LOG_VERBOSE("decoding batch", { + {"n_tokens", batch.n_tokens}, + }); + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, 
@@ -1948,8 +1994,8 @@
 
             slot.n_decoded += 1;
             if (slot.n_decoded == 1) {
-                slot.t_start_genereration = ggml_time_us();
-                slot.t_prompt_processing  = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
+                slot.t_start_generation  = ggml_time_us();
+                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                 metrics.on_prompt_eval(slot);
             }
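---

Notes (not part of the patch):

The prompt-caching branch sets slot.n_past = common_part(slot.cache_tokens, prompt_tokens) and then truncates the cache to that length, so only the suffix that differs from the cached tokens is decoded again. common_part is an existing helper in server.cpp; a minimal sketch of the longest-common-prefix computation it performs (the exact upstream signature may differ):

    #include <vector>

    // llama_token comes from llama.h (an int32_t id in practice).
    // Sketch: number of leading tokens shared by the cached sequence and the
    // new prompt. Tokens before this point are already in the KV cache and
    // can be reused; everything after it has to be decoded again.
    static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }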
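The context-shift hunk keeps the first n_keep tokens, discards n_discard = n_left / 2 tokens after them, and slides the remaining tail down; that compaction is now applied to cache_tokens only when cache_prompt is enabled, since the cache is otherwise unused. A standalone sketch of the vector compaction with made-up values (the real code also updates the KV cache via llama_kv_cache_seq_rm and llama_kv_cache_seq_add):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using llama_token = int32_t; // stand-in for the llama.h typedef

    int main() {
        // toy cache: 10 tokens, always keep the first 2
        std::vector<llama_token> cache = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};

        const int n_keep    = 2;
        const int n_left    = (int) cache.size() - n_keep;
        const int n_discard = n_left / 2; // same halving rule as the server

        // slide the tail down over the discarded span, then shrink the vector
        for (size_t i = n_keep + n_discard; i < cache.size(); i++) {
            cache[i - n_discard] = cache[i];
        }
        cache.resize(cache.size() - n_discard);

        for (const llama_token t : cache) {
            printf("%d ", t); // prints: 0 1 6 7 8 9
        }
        printf("\n");

        return 0;
    }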
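With the new n_available budget, the assignment loop feeds at most n_batch prompt tokens per update_slots() pass and resumes where it left off on the next pass; a slot only flips to PROCESSING (and requests logits for its last token) once slot.n_past == slot.n_prompt_tokens. A toy standalone model of that budgeted draining (hypothetical code, not the server's actual types):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_batch = 8;                   // per-pass token budget
        std::vector<int> remaining = {5, 10, 3}; // prompt tokens left per slot

        for (int pass = 1;; pass++) {
            int n_available = n_batch;

            for (size_t i = 0; i < remaining.size(); i++) {
                // take tokens for this slot until it is done or the batch is full
                while (remaining[i] > 0 && n_available > 0) {
                    remaining[i]--;
                    n_available--;
                }
                if (n_available == 0) {
                    break;                       // batch full - stop assigning
                }
            }

            printf("pass %d: batched %d tokens\n", pass, n_batch - n_available);

            int left = 0;
            for (const int r : remaining) {
                left += r;                       // tokens waiting for a later pass
            }
            if (left == 0) {
                break;                           // every prompt fully ingested
            }
        }

        return 0;
    }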