server : fix empty prompt handling + all slots idle logic

Georgi Gerganov 2024-03-05 14:33:12 +02:00
parent ef7eb33937
commit 134f5fec22


@@ -20,11 +20,11 @@
#include "completion.js.hpp"
#include "json-schema-to-grammar.mjs.hpp"
#include <cstddef>
#include <thread>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <atomic>
#include <cstddef>
#include <thread>
#include <signal.h>
using json = nlohmann::json;
@@ -201,7 +201,7 @@ struct server_slot {
}
}
json get_formated_timings() {
json get_formated_timings() const {
return json {
{"prompt_n", n_prompt_tokens_processed},
{"prompt_ms", t_prompt_processing},
@@ -215,6 +215,34 @@ struct server_slot {
};
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) {
size_t stop_pos = std::string::npos;
for (const std::string & word : params.antiprompt) {
size_t pos;
if (type == STOP_FULL) {
const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
} else {
pos = find_partial_stop_string(word, text);
}
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
if (type == STOP_FULL) {
stopped_word = true;
stopping_word = word;
has_next_token = false;
}
stop_pos = pos;
}
}
return stop_pos;
}
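
Since find_stopping_strings is now a member of server_slot, it can set stopped_word, stopping_word and has_next_token directly instead of taking the slot by reference. A minimal standalone sketch of the same search logic; find_partial_stop below is a simplified stand-in for the server's find_partial_stop_string:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    enum stop_type { STOP_FULL, STOP_PARTIAL };

    // simplified stand-in: position of the longest proper prefix of `word`
    // that `text` ends with, or npos
    static size_t find_partial_stop(const std::string & word, const std::string & text) {
        for (size_t len = std::min(word.size() - 1, text.size()); len > 0; --len) {
            if (text.compare(text.size() - len, len, word, 0, len) == 0) {
                return text.size() - len;
            }
        }
        return std::string::npos;
    }

    static size_t find_stopping_strings(const std::string & text, const size_t last_token_size,
                                        const stop_type type, const std::vector<std::string> & antiprompt) {
        size_t stop_pos = std::string::npos;
        for (const std::string & word : antiprompt) {
            size_t pos;
            if (type == STOP_FULL) {
                // only the tail that the last token could have completed needs scanning
                const size_t tmp      = word.size() + last_token_size;
                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
                pos = text.find(word, from_pos);
            } else {
                pos = find_partial_stop(word, text);
            }
            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
                stop_pos = pos; // keep the earliest match across all stop words
            }
        }
        return stop_pos;
    }

    int main() {
        const std::vector<std::string> stops = { "###" };
        std::cout << find_stopping_strings("answer ###", 3, STOP_FULL,    stops) << "\n"; // 7
        std::cout << find_stopping_strings("answer ##",  2, STOP_PARTIAL, stops) << "\n"; // 7
    }
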
void print_timings() const {
char buffer[512];
@@ -302,7 +330,6 @@ struct llama_server_context {
llama_batch batch;
bool clean_kv_cache = true;
bool all_slots_are_idle = false;
bool add_bos_token = true;
int32_t n_ctx; // total context for all clients / slots
@@ -366,9 +393,6 @@ struct llama_server_context {
}
void initialize() {
// create slots
all_slots_are_idle = true;
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
@@ -531,111 +555,97 @@ struct llama_server_context {
}
// infill
if (data.count("input_prefix") != 0)
{
if (data.count("input_prefix") != 0) {
slot.params.input_prefix = data["input_prefix"];
}
else
{
} else {
slot.params.input_prefix = "";
}
if (data.count("input_suffix") != 0)
{
if (data.count("input_suffix") != 0) {
slot.params.input_suffix = data["input_suffix"];
}
else
{
} else {
slot.params.input_suffix = "";
}
if (data.count("prompt") != 0)
{
if (data.count("prompt") != 0) {
slot.prompt = data["prompt"];
}
else
{
} else {
slot.prompt = "";
}
// penalize user-provided tokens
{
slot.sparams.penalty_prompt_tokens.clear();
slot.sparams.use_penalty_prompt_tokens = false;
const auto &penalty_prompt = data.find("penalty_prompt");
if (penalty_prompt != data.end())
{
if (penalty_prompt->is_string())
{
const auto & penalty_prompt = data.find("penalty_prompt");
if (penalty_prompt != data.end()) {
if (penalty_prompt->is_string()) {
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false);
slot.sparams.penalty_prompt_tokens.swap(penalty_tokens);
if (slot.params.n_predict > 0)
{
slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
if (slot.params.n_predict > 0) {
slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
}
slot.sparams.use_penalty_prompt_tokens = true;
LOG_VERBOSE("penalty_prompt_tokens", {
{"slot_id", slot.id},
{"tokens", slot.sparams.penalty_prompt_tokens},
});
}
else if (penalty_prompt->is_array())
{
else if (penalty_prompt->is_array()) {
const auto n_tokens = penalty_prompt->size();
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
const int n_vocab = llama_n_vocab(model);
for (const auto &penalty_token : *penalty_prompt)
{
if (penalty_token.is_number_integer())
{
for (const auto & penalty_token : *penalty_prompt) {
if (penalty_token.is_number_integer()) {
const auto tok = penalty_token.get<llama_token>();
if (tok >= 0 && tok < n_vocab)
{
if (tok >= 0 && tok < n_vocab) {
slot.sparams.penalty_prompt_tokens.push_back(tok);
}
}
}
slot.sparams.use_penalty_prompt_tokens = true;
LOG_VERBOSE("penalty_prompt_tokens", {
{"slot_id", slot.id},
{"tokens", slot.sparams.penalty_prompt_tokens},
});
}
}
}
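
For reference, the two shapes of penalty_prompt this block accepts; a hedged sketch using nlohmann::json (prompt text and token ids are made up):

    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // string form: tokenized server-side, then used for repetition penalties
        const json req_a = {
            {"prompt",         "Once upon a time"},
            {"penalty_prompt", "Once upon a time"},
        };

        // array form: pre-tokenized ids, each validated against llama_n_vocab
        const json req_b = {
            {"prompt",         "Once upon a time"},
            {"penalty_prompt", {9038, 2501, 263}},  // hypothetical token ids
        };

        std::cout << req_a.dump(2) << "\n" << req_b.dump(2) << std::endl;
    }
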
slot.sparams.logit_bias.clear();
if (json_value(data, "ignore_eos", false))
{
if (json_value(data, "ignore_eos", false)) {
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}
const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array())
{
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model);
for (const auto &el : *logit_bias)
{
if (el.is_array() && el.size() == 2)
{
for (const auto & el : *logit_bias) {
if (el.is_array() && el.size() == 2) {
float bias;
if (el[1].is_number())
{
if (el[1].is_number()) {
bias = el[1].get<float>();
}
else if (el[1].is_boolean() && !el[1].get<bool>())
{
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
bias = -INFINITY;
}
else
{
} else {
continue;
}
if (el[0].is_number_integer())
{
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab)
{
if (tok >= 0 && tok < n_vocab) {
slot.sparams.logit_bias[tok] = bias;
}
}
else if (el[0].is_string())
{
} else if (el[0].is_string()) {
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks)
{
for (auto tok : toks) {
slot.sparams.logit_bias[tok] = bias;
}
}
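
Each logit_bias entry handled above is a two-element array: the first element is a token id or a string to tokenize, the second a number, or false as shorthand for banning the token. A sketch of a valid payload (ids and strings are illustrative):

    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        const json req = {
            {"prompt", "Hello"},
            {"logit_bias", {
                {15043, -0.5},   // bias a raw token id (hypothetical id)
                {"World", 2.0},  // bias every token the string tokenizes to
                {2, false},      // false => -INFINITY, i.e. ban the token
            }},
        };
        std::cout << req.dump(2) << std::endl;
    }
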
@@ -645,45 +655,35 @@ struct llama_server_context {
slot.params.antiprompt.clear();
const auto &stop = data.find("stop");
if (stop != data.end() && stop->is_array())
{
for (const auto &word : *stop)
{
if (!word.empty())
{
const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
for (const auto & word : *stop) {
if (!word.empty()) {
slot.params.antiprompt.push_back(word);
}
}
}
const auto &samplers_sequence = data.find("samplers");
if (samplers_sequence != data.end() && samplers_sequence->is_array())
{
const auto & samplers_sequence = data.find("samplers");
if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
std::vector<std::string> sampler_names;
for (const auto &sampler_name : *samplers_sequence)
{
if (sampler_name.is_string())
{
for (const auto & sampler_name : *samplers_sequence) {
if (sampler_name.is_string()) {
sampler_names.emplace_back(sampler_name);
}
}
slot.sparams.samplers_sequence = sampler_types_from_names(sampler_names, false);
}
else
{
} else {
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
}
if (slot.ctx_sampling != nullptr)
{
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
}
slot.ctx_sampling = llama_sampling_init(slot.sparams);
llama_set_rng_seed(ctx, slot.params.seed);
slot.command = LOAD_PROMPT;
all_slots_are_idle = false;
slot.command = LOAD_PROMPT;
LOG_INFO("slot is processing task", {
{"slot_id", slot.id},
@@ -694,12 +694,18 @@ struct llama_server_context {
}
void kv_cache_clear() {
LOG_VERBOSE("clearing KV cache", {});
// clear the entire KV cache
llama_kv_cache_clear(ctx);
clean_kv_cache = false;
}
void system_prompt_update() {
LOG_VERBOSE("system prompt update", {
{"system_prompt", system_prompt},
});
kv_cache_clear();
system_tokens.clear();
@@ -708,13 +714,11 @@ struct llama_server_context {
llama_batch_clear(batch);
for (int i = 0; i < (int)system_tokens.size(); ++i)
{
for (int i = 0; i < (int)system_tokens.size(); ++i) {
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
}
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
{
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) {
const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,
@@ -726,78 +730,42 @@ struct llama_server_context {
batch.logits + i,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch_view) != 0)
{
if (llama_decode(ctx, batch_view) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return;
}
}
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < params.n_parallel; ++i)
{
for (int32_t i = 1; i < params.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
}
}
LOG_TEE("system prompt updated\n");
system_need_update = false;
}
void system_prompt_notify() {
void system_prompt_process(const json & sys_props) {
system_prompt = sys_props.value("prompt", "");
name_user = sys_props.value("anti_prompt", "");
name_assistant = sys_props.value("assistant_name", "");
LOG_VERBOSE("system prompt process", {
{"system_prompt", system_prompt},
{"name_user", name_user},
{"name_assistant", name_assistant},
});
// release all slots
for (server_slot &slot : slots)
{
for (server_slot & slot : slots) {
slot.release();
}
system_need_update = true;
}
void system_prompt_process(const json &sys_props) {
system_prompt = sys_props.value("prompt", "");
name_user = sys_props.value("anti_prompt", "");
name_assistant = sys_props.value("assistant_name", "");
system_prompt_notify();
}
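
The keys read by system_prompt_process define the shape of the system_prompt object a task may carry; a minimal illustration (values are invented):

    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        const json sys_props = {
            {"prompt",         "You are a helpful assistant."},  // shared prefix cached for all slots
            {"anti_prompt",    "User:"},                         // stored as name_user
            {"assistant_name", "Assistant:"},                    // stored as name_assistant
        };
        std::cout << sys_props.value("prompt", "") << std::endl;
    }
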
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
const stop_type type, server_slot &slot)
{
size_t stop_pos = std::string::npos;
for (const std::string &word : slot.params.antiprompt)
{
size_t pos;
if (type == STOP_FULL)
{
const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
}
else
{
pos = find_partial_stop_string(word, text);
}
if (pos != std::string::npos &&
(stop_pos == std::string::npos || pos < stop_pos))
{
if (type == STOP_FULL)
{
slot.stopped_word = true;
slot.stopping_word = word;
slot.has_next_token = false;
}
stop_pos = pos;
}
}
return stop_pos;
}
bool process_token(completion_token_output &result, server_slot &slot) {
bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = llama_token_to_piece(ctx, result.tok);
slot.sampled = result.tok;
@@ -806,34 +774,26 @@ struct llama_server_context {
slot.generated_text += token_str;
slot.has_next_token = true;
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
{
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
// we can change penalty_prompt_tokens because it is always created from scratch on each request
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
}
// check if there is incomplete UTF-8 character at the end
bool incomplete = false;
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
{
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
unsigned char c = slot.generated_text[slot.generated_text.size() - i];
if ((c & 0xC0) == 0x80)
{
if ((c & 0xC0) == 0x80) {
// continuation byte: 10xxxxxx
continue;
}
if ((c & 0xE0) == 0xC0)
{
if ((c & 0xE0) == 0xC0) {
// 2-byte character: 110xxxxx ...
incomplete = i < 2;
}
else if ((c & 0xF0) == 0xE0)
{
} else if ((c & 0xF0) == 0xE0) {
// 3-byte character: 1110xxxx ...
incomplete = i < 3;
}
else if ((c & 0xF8) == 0xF0)
{
} else if ((c & 0xF8) == 0xF0) {
// 4-byte character: 11110xxx ...
incomplete = i < 4;
}
@@ -841,57 +801,58 @@ struct llama_server_context {
break;
}
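
The backwards scan above classifies the last few bytes by their UTF-8 lead bits; the same check as a self-contained helper, with the expected results in comments:

    #include <iostream>
    #include <string>

    // true if `s` ends in the middle of a multi-byte UTF-8 character,
    // mirroring the backwards scan in process_token
    static bool ends_with_incomplete_utf8(const std::string & s) {
        for (unsigned i = 1; i < 5 && i <= s.size(); ++i) {
            const unsigned char c = s[s.size() - i];
            if ((c & 0xC0) == 0x80) continue;     // continuation byte: 10xxxxxx, keep scanning
            if ((c & 0xE0) == 0xC0) return i < 2; // 2-byte lead: 110xxxxx
            if ((c & 0xF0) == 0xE0) return i < 3; // 3-byte lead: 1110xxxx
            if ((c & 0xF8) == 0xF0) return i < 4; // 4-byte lead: 11110xxx
            return false;                         // ASCII (or invalid lead): complete
        }
        return false;
    }

    int main() {
        std::cout << ends_with_incomplete_utf8("caf\xC3")     << "\n"; // 1: first byte of U+00E9
        std::cout << ends_with_incomplete_utf8("caf\xC3\xA9") << "\n"; // 0: complete "café"
    }
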
if (!incomplete)
{
if (!incomplete) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_FULL, slot);
if (stop_pos != std::string::npos)
{
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
is_stop_full = true;
slot.generated_text.erase(
slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end());
pos = std::min(slot.n_sent_text, slot.generated_text.size());
}
else
{
} else {
is_stop_full = false;
stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL, slot);
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL);
}
// check if there is any token to predict
if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0))
{
if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
// do not send the stop word in the response
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.n_sent_text += result.text_to_send.size();
// add the token to the slot queue and cache
}
slot.add_token_string(result);
if (slot.params.stream)
{
if (slot.params.stream) {
send_partial_response(slot, result);
}
}
if (incomplete)
{
if (incomplete) {
slot.has_next_token = true;
}
// check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params))
{
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
slot.stopped_limit = true;
slot.has_next_token = false;
LOG_VERBOSE("stopped by limit", {
{"slot_id", slot.id},
{"n_decoded", slot.n_decoded},
{"n_predict", slot.params.n_predict},
});
}
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
{
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) {
slot.stopped_eos = true;
slot.has_next_token = false;
LOG_VERBOSE("eos token found", {});
}
@@ -910,24 +871,25 @@ struct llama_server_context {
return slot.has_next_token; // continue
}
void send_error(task_server& task, const std::string &error)
{
void send_error(const task_server & task, const std::string & error) {
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
task_result res;
res.id = task.id;
res.multitask_id = task.multitask_id;
res.stop = false;
res.error = true;
res.result_json = { { "content", error } };
queue_results.send(res);
}
json get_formated_generation(server_slot &slot)
{
json get_formated_generation(const server_slot & slot) const {
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
std::vector<std::string> samplers_sequence;
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type));
}
@@ -968,38 +930,36 @@ struct llama_server_context {
};
}
void send_partial_response(server_slot &slot, completion_token_output tkn)
{
void send_partial_response(server_slot & slot, completion_token_output tkn) {
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = false;
res.result_json = json
{
res.result_json = json {
{"content", tkn.text_to_send},
{"stop", false},
{"slot_id", slot.id},
{"multimodal", false}
};
if (slot.sparams.n_probs > 0)
{
std::vector<completion_token_output> probs_output = {};
if (slot.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
if (probs_pos < probs_stop_pos)
{
probs_output = std::vector<completion_token_output>(slot.generated_token_probs.begin() + probs_pos, slot.generated_token_probs.begin() + probs_stop_pos);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
std::vector<completion_token_output> probs_output;
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin() + probs_pos,
slot.generated_token_probs.begin() + probs_stop_pos);
}
slot.n_sent_token_probs = probs_stop_pos;
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
}
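
The probs_pos/probs_stop_pos window above streams only the token probabilities that have not been sent yet; the same slice arithmetic in isolation (plain floats stand in for completion_token_output):

    #include <algorithm>
    #include <iostream>
    #include <vector>

    int main() {
        std::vector<float> generated_token_probs = {0.9f, 0.8f, 0.7f, 0.6f}; // stand-in entries
        size_t n_sent_token_probs = 2; // already streamed in earlier chunks
        const size_t n_to_send    = 3; // tokens in the current chunk

        const size_t probs_pos      = std::min(n_sent_token_probs, generated_token_probs.size());
        const size_t probs_stop_pos = std::min(n_sent_token_probs + n_to_send, generated_token_probs.size());

        std::vector<float> probs_output;
        if (probs_pos < probs_stop_pos) {
            probs_output.assign(generated_token_probs.begin() + probs_pos,
                                generated_token_probs.begin() + probs_stop_pos);
        }
        n_sent_token_probs = probs_stop_pos;

        for (const float p : probs_output) std::cout << p << ' '; // 0.7 0.6
        std::cout << '\n';
    }
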
if (slot.oaicompat)
{
if (slot.oaicompat) {
res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
res.result_json["model"] = slot.oaicompat_model;
}
@@ -1007,16 +967,13 @@ struct llama_server_context {
queue_results.send(res);
}
void send_final_response(server_slot &slot)
{
void send_final_response(const server_slot & slot) {
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = true;
res.result_json = json
{
res.result_json = json {
{"content", !slot.params.stream ? slot.generated_text : ""},
{"slot_id", slot.id},
{"stop", true},
@@ -1034,25 +991,24 @@ struct llama_server_context {
{"timings", slot.get_formated_timings()}
};
if (slot.sparams.n_probs > 0)
{
std::vector<completion_token_output> probs = {};
if (!slot.params.stream && slot.stopped_word)
{
if (slot.sparams.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size());
}
else
{
probs = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end() - stop_word_toks.size());
} else {
probs = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end());
}
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs);
}
if (slot.oaicompat)
{
if (slot.oaicompat) {
res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
res.result_json["model"] = slot.oaicompat_model;
}
@@ -1060,8 +1016,7 @@ struct llama_server_context {
queue_results.send(res);
}
void send_embedding(server_slot & slot, const llama_batch & batch)
{
void send_embedding(const server_slot & slot, const llama_batch & batch) {
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
@@ -1070,16 +1025,13 @@ struct llama_server_context {
const int n_embd = llama_n_embd(model);
if (!params.embedding)
{
if (!params.embedding) {
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
res.result_json = json
{
res.result_json = json {
{"embedding", std::vector<float>(n_embd, 0.0f)},
};
}
else
{
} else {
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
@@ -1090,25 +1042,25 @@ struct llama_server_context {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
res.result_json = json
{
res.result_json = json {
{"embedding", std::vector<float>(n_embd, 0.0f)},
};
continue;
}
}
res.result_json = json
{
res.result_json = json {
{"embedding", std::vector<float>(embd, embd + n_embd)},
};
}
}
queue_results.send(res);
}
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
{
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id) {
task_server task;
task.id = task_id;
task.target_id = 0;
@@ -1123,7 +1075,7 @@ struct llama_server_context {
// if there are numbers in the prompt array it will be treated as an array of tokens
if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
bool numbers = false;
for (const auto& e : task.data.at("prompt")) {
for (const auto & e : task.data.at("prompt")) {
if (e.is_number()) {
numbers = true;
break;
@@ -1141,10 +1093,6 @@ struct llama_server_context {
split_multiprompt_task(task_id, task);
}
} else {
// an empty prompt can make slot become buggy
if (task.data.contains("prompt") && task.data["prompt"].is_string() && task.data["prompt"].get<std::string>().empty()) {
task.data["prompt"] = " "; // add a space so that we have one token
}
queue_tasks.post(task);
}
}
@@ -1186,26 +1134,19 @@ struct llama_server_context {
}
}
void process_single_task(task_server& task)
{
void process_single_task(task_server & task) {
switch (task.type)
{
case TASK_TYPE_COMPLETION: {
server_slot * slot = get_slot(json_value(task.data, "slot_id", -1));
if (slot == nullptr)
{
if (slot == nullptr) {
// if no slot is available, we defer this task for processing later
LOG_VERBOSE("no slot is available", {{"task_id", task.id}});
queue_tasks.defer(task);
break;
}
if (task.data.contains("system_prompt"))
{
if (!all_slots_are_idle) {
send_error(task, "system prompt can only be updated when all slots are idle");
break;
}
if (task.data.contains("system_prompt")) {
system_prompt_process(task.data["system_prompt"]);
// reset cache_tokens for all slots
@@ -1232,10 +1173,8 @@ struct llama_server_context {
}
} break;
case TASK_TYPE_CANCEL: { // release slot linked with the task id
for (auto & slot : slots)
{
if (slot.task_id == task.target_id)
{
for (auto & slot : slots) {
if (slot.task_id == task.target_id) {
slot.release();
break;
}
@@ -1339,28 +1278,60 @@ struct llama_server_context {
llama_batch_clear(batch);
if (all_slots_are_idle)
// release slots
for (auto & slot : slots) {
if (slot.command == RELEASE) {
slot.state = IDLE;
slot.command = NONE;
slot.t_last_used = ggml_time_us();
LOG_INFO("slot released", {
{"slot_id", slot.id},
{"task_id", slot.task_id},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated}
});
queue_tasks.notify_slot_changed();
}
}
{
if (system_prompt.empty() && clean_kv_cache)
{
LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
bool all_slots_are_idle = true;
for (auto & slot : slots) {
if (slot.state != IDLE || slot.command != NONE) {
all_slots_are_idle = false;
break;
}
}
if (all_slots_are_idle) {
LOG_INFO("all slots are idle", {});
if (system_prompt.empty() && clean_kv_cache) {
kv_cache_clear();
}
return true;
}
}
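
This replaces the stale all_slots_are_idle member with a value recomputed from slot state on every pass, which is the "all slots idle logic" fix from the commit title. The same derivation in isolation (Slot is a hypothetical minimal type):

    #include <algorithm>
    #include <iostream>
    #include <vector>

    enum slot_state   { IDLE, PROCESSING };
    enum slot_command { NONE, LOAD_PROMPT, RELEASE };

    struct Slot {
        slot_state   state   = IDLE;
        slot_command command = NONE;
    };

    int main() {
        std::vector<Slot> slots(4);
        slots[2].command = LOAD_PROMPT; // a pending command counts as not idle

        const bool all_slots_are_idle = std::all_of(slots.begin(), slots.end(),
            [](const Slot & s) { return s.state == IDLE && s.command == NONE; });

        std::cout << (all_slots_are_idle ? "all idle" : "busy") << std::endl; // busy
    }
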
LOG_VERBOSE("posting NEXT_RESPONSE", {});
{
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);
for (server_slot & slot : slots)
{
if (slot.ga_n == 1)
{
if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
{
queue_tasks.post(task);
}
for (server_slot & slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) {
// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -1377,11 +1348,11 @@ struct llama_server_context {
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()}
});
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
{
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}
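
The context shift keeps the first n_keep tokens, discards n_discard tokens after them, and slides the rest down; the KV-cache seq_rm/seq_add calls are mirrored on cache_tokens. The compaction in isolation (n_discard is assumed to be n_left / 2 here, for illustration):

    #include <iostream>
    #include <vector>

    int main() {
        std::vector<int> cache_tokens = {0, 1, 2, 3, 4, 5, 6, 7};

        const int n_keep    = 2;                                  // tokens pinned at the start
        const int n_left    = (int) cache_tokens.size() - n_keep; // shiftable region
        const int n_discard = n_left / 2;                         // assumed split, for illustration

        // slide the kept tail down over the discarded range
        for (size_t i = n_keep + n_discard; i < cache_tokens.size(); i++) {
            cache_tokens[i - n_discard] = cache_tokens[i];
        }
        cache_tokens.resize(cache_tokens.size() - n_discard);

        for (const int t : cache_tokens) std::cout << t << ' ';   // 0 1 5 6 7
        std::cout << '\n';
    }
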
@@ -1395,32 +1366,8 @@ struct llama_server_context {
}
// decode any currently ongoing sequences
LOG_VERBOSE("decoding ongoing sequences", {});
for (auto & slot : slots)
{
// release the slot
if (slot.command == RELEASE)
{
slot.state = IDLE;
slot.command = NONE;
slot.t_last_used = ggml_time_us();
LOG_INFO("slot released", {
{"slot_id", slot.id},
{"task_id", slot.task_id},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated}
});
queue_tasks.notify_slot_changed();
continue;
}
if (slot.state == IDLE)
{
for (auto & slot : slots) {
if (slot.state == IDLE) {
continue;
}
@@ -1432,22 +1379,31 @@ struct llama_server_context {
// this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
slot.n_past += 1;
LOG_VERBOSE("slot decode token", {
{"slot_id", slot.id},
{"task_id", slot.task_id},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated}
});
}
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
// assign workload to the slots
if (params.cont_batching || batch.n_tokens == 0)
{
for (auto & slot : slots)
{
if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
// empty prompt passed -> release the slot and send empty response
// note: infill mode allows empty prompt
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
{
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill) {
slot.state = PROCESSING;
slot.command = NONE;
slot.release();
slot.print_timings();
send_final_response(slot);
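
Together with the request-time workaround removed earlier (patching an empty prompt to " "), this is the empty-prompt fix: a non-infill slot loaded with an empty prompt is released immediately with an empty final response instead of entering a buggy state. The prompt-presence test in isolation:

    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    // mirrors has_prompt: a token array counts, a non-empty string counts
    static bool has_prompt(const json & prompt) {
        return prompt.is_array() || (prompt.is_string() && !prompt.get<std::string>().empty());
    }

    int main() {
        std::cout << has_prompt(json(""))            << "\n"; // 0 -> release slot, empty response
        std::cout << has_prompt(json("Hello"))       << "\n"; // 1
        std::cout << has_prompt(json::array({1, 2})) << "\n"; // 1: pre-tokenized prompt
    }
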
@@ -1651,9 +1607,9 @@ struct llama_server_context {
}
}
if (batch.n_tokens == 0)
{
all_slots_are_idle = true;
if (batch.n_tokens == 0) {
LOG_VERBOSE("no tokens to decode", {});
return true;
}
@@ -1794,9 +1750,7 @@ struct llama_server_context {
}
};
static void server_print_usage(const char *argv0, const gpt_params &params,
const server_params &sparams)
{
static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) {
printf("usage: %s [options]\n", argv0);
printf("\n");
printf("options:\n");
@@ -1882,11 +1836,10 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("\n");
}
static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params, llama_server_context& llama)
{
static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params, llama_server_context & llama) {
gpt_params default_params;
server_params default_sparams;
std::string arg;
bool invalid_param = false;
@@ -2510,6 +2463,7 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
std::function<void(int)> shutdown_handler;
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
inline void signal_handler(int signal) {
if (is_terminating.test_and_set()) {
// in case it hangs, we can force terminate the server by hitting Ctrl+C twice
@@ -2520,8 +2474,7 @@ inline void signal_handler(int signal) {
shutdown_handler(signal);
}
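
The std::atomic_flag guard makes a second Ctrl+C a hard exit when the graceful shutdown triggered by the first one hangs; the pattern in isolation:

    #include <atomic>
    #include <cstdio>
    #include <cstdlib>

    static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

    static void on_signal(int sig) {
        if (is_terminating.test_and_set()) {
            // second signal: the first shutdown attempt is presumably stuck
            fprintf(stderr, "received second signal %d, force terminating\n", sig);
            exit(1);
        }
        fprintf(stderr, "received signal %d, shutting down\n", sig);
    }

    int main() {
        on_signal(2); // first: graceful path
        on_signal(2); // second: exit(1)
    }
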
int main(int argc, char **argv)
{
int main(int argc, char ** argv) {
#if SERVER_VERBOSE != 1
log_disable();
#endif
@@ -2699,10 +2652,10 @@ int main(int argc, char **argv)
};
std::stringstream prometheus;
for (const auto& el : all_metrics_def.items()) {
const auto& type = el.key();
const auto& metrics_def = el.value();
for (const auto& metric_def : metrics_def) {
for (const auto & el : all_metrics_def.items()) {
const auto & type = el.key();
const auto & metrics_def = el.value();
for (const auto & metric_def : metrics_def) {
std::string name = metric_def["name"];
std::string help = metric_def["help"];
auto value = json_value(metric_def, "value", 0);