improved token gen logic and limits

FSSRepo 2023-10-06 18:22:07 -04:00
parent c1ac53fbdb
commit a8435c3e32

@@ -80,36 +80,33 @@ struct llama_client_slot
     int32_t n_decoded = 0;
     int32_t i_batch   = -1;
     string prompt = "";
-    string sampled_token_str;
     string generated_text = "";
+    int n_tokens_predicted = 0;
     llama_token sampled;
+    std::vector<string> sampled_tokens;
     std::vector<llama_token> tokens_prev;
     slot_state state = IDLE;
     slot_command command = NONE;
-    bool newToken = false;
     float temperature = 0.1f;
 
     void start(string prompt_, float temp_) {
         prompt = prompt_;
         command = LOAD_PROMPT;
         temperature = temp_;
-        newToken = false;
+        LOG_TEE("slot %i is processing\n", id);
     }
 
     bool hasNewToken() {
-        if(newToken) {
-            newToken = false;
-            return true;
-        }
-        return false;
+        return sampled_tokens.size() > 0;
     }
 
     bool available() {
         return state == IDLE && command == NONE;
     }
 
-    void nofity() {
-        newToken = !newToken;
+    void addTokenString(string token) {
+        sampled_tokens.insert(sampled_tokens.begin(), token);
+        n_tokens_predicted++;
     }
 
     void release() {
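
Note: taken on its own, the new token plumbing in llama_client_slot is a small FIFO. addTokenString() pushes the newest token at the front of sampled_tokens and the HTTP side drains from the back, so tokens leave in generation order. A minimal standalone sketch of that behaviour follows; the ClientSlot type and the main() driver are invented for illustration and are not the server's llama_client_slot.

    // FIFO sketch: producer inserts at the front, consumer pops from the back,
    // so tokens are delivered in the order they were generated.
    #include <cstdio>
    #include <string>
    #include <vector>

    struct ClientSlot {                          // illustration only
        std::vector<std::string> sampled_tokens;
        int n_tokens_predicted = 0;

        bool hasNewToken() { return sampled_tokens.size() > 0; }

        void addTokenString(const std::string & token) {
            sampled_tokens.insert(sampled_tokens.begin(), token);  // newest at the front
            n_tokens_predicted++;
        }
    };

    int main() {
        ClientSlot slot;
        slot.addTokenString("Hello");
        slot.addTokenString(" world");

        while (slot.hasNewToken()) {
            printf("%s", slot.sampled_tokens.back().c_str());      // oldest token first
            slot.sampled_tokens.pop_back();
        }
        printf("\n");                                              // prints "Hello world"
        return 0;
    }
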
@@ -163,7 +160,7 @@ struct server_parallel_context {
             slot.id = i;
             slot.prompt = "default";
             slot.state = IDLE;
-            slot.tokens_prev.resize(std::max(256, params.n_predict));
+            slot.tokens_prev.resize(params.n_predict);
             std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
             LOG_TEE(" - slot %i\n", slot.id);
             slots.push_back(slot);
@@ -247,7 +244,6 @@ struct server_parallel_context {
             if ((slot_id == -1 && slot.available()) || slot.id == slot_id)
             {
                 slot.start(prompt, temperature);
-                LOG_TEE("slot %i is processing\n", slot.id);
                 return &slot; // return a pointer to slot (thread safe?)
             }
         }
@@ -302,6 +298,7 @@ struct server_parallel_context {
         }
 
         batch.n_tokens = 0;
+        int kv_cache_free = (n_ctx - n_tokens_system);
 
         // decode any currently ongoing sequences
         for (auto & slot : slots) {
@@ -311,13 +308,17 @@ struct server_parallel_context {
                 llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
                 slot.state = IDLE;
                 slot.command = NONE;
+                slot.n_prompt = 0;
+                slot.n_tokens_predicted = 0;
                 continue;
             }
 
+            kv_cache_free -= slot.n_prompt;
+
             // no decode wait until the token had been send to client
             // improves performance and avoid decoherence?
-            if (slot.state == IDLE || slot.newToken) {
+            if (slot.state == IDLE) {
                 continue;
             }
@@ -339,12 +340,13 @@ struct server_parallel_context {
                 if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
                     slot.state = PROCESSING;
                     slot.command = NONE;
-                    //LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());
                     std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
 
                     // do not prepend BOS because we have a system prompt!
                     std::vector<llama_token> tokens_prompt;
                     tokens_prompt = ::llama_tokenize(ctx, slot.prompt, false);
+                    slot.n_tokens_predicted = 0;
 
                     for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                         batch.token [batch.n_tokens] = tokens_prompt[i];
@@ -362,11 +364,6 @@ struct server_parallel_context {
                     slot.n_prompt = tokens_prompt.size();
                     slot.n_decoded = 0;
                     slot.i_batch   = batch.n_tokens - 1;
-
-                    // insert new requests one-by-one
-                    //if (cont_batching) {
-                    //    break;
-                    //}
                 }
             }
         }
@@ -379,13 +376,6 @@ struct server_parallel_context {
         int32_t n_batch = params.n_batch;
 
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            // experiment: process in powers of 2
-            //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
-            //    n_batch /= 2;
-            //    i -= n_batch;
-            //    continue;
-            //}
-
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
 
             llama_batch batch_view = {
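
Note: with the powers-of-two experiment gone, the decode loop above simply walks the queued batch in windows of at most n_batch tokens. A small sketch of that chunking arithmetic with made-up sizes (512-token windows over a 1300-token batch) follows.

    // Chunking sketch: n_tokens_total stands in for batch.n_tokens; sizes are invented.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t n_batch        = 512;    // params.n_batch
        const int32_t n_tokens_total = 1300;   // stands in for batch.n_tokens

        for (int32_t i = 0; i < n_tokens_total; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);
            // prints windows of 512, 512 and 276 tokens
            printf("batch view: offset=%d, n_tokens=%d\n", (int) i, (int) n_tokens);
        }
        return 0;
    }
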
@@ -431,18 +421,17 @@ struct server_parallel_context {
                 slot.sampled = id;
 
                 size_t stop_pos =
                     findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
 
-                slot.sampled_token_str = token_str;
-                // notify new token
-                slot.nofity();
+                slot.addTokenString(token_str);
+
+                kv_cache_free -= slot.n_tokens_predicted;
 
                 if (slot.n_decoded > 2 &&
                         (id == llama_token_eos(ctx) ||
-                        (params.n_predict > 0 &&
-                        slot.n_decoded + slot.n_prompt >=
-                        params.n_predict) ||
+                        (slot.n_decoded + slot.n_prompt >=
+                        params.n_predict) ||
                         stop_pos != std::string::npos)) {
                     //LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
                     slot.generated_text.clear();
                     slot.release();
@@ -450,6 +439,11 @@ struct server_parallel_context {
                     slot.i_batch = -1;
                 }
             }
+
+        if(kv_cache_free < 0) {
+            LOG_TEE("\nError: kv cache is full, increase context size.");
+            return false;
+        }
         return true;
     }
 };
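
Note: kv_cache_free acts as a per-iteration budget. It starts at n_ctx - n_tokens_system, every active slot charges what it occupies in the KV cache, and a negative balance triggers the "kv cache is full" error added above. A rough sketch of that accounting with invented numbers follows; the SlotUsage helper is hypothetical, and the prompt and predicted-token charges are folded into a single subtraction for clarity.

    // Budget sketch: 2048-token context, 128-token system prompt, three busy slots.
    #include <cstdio>

    struct SlotUsage {               // hypothetical helper, not part of the server code
        int n_prompt;
        int n_tokens_predicted;
    };

    int main() {
        const int n_ctx           = 2048;
        const int n_tokens_system = 128;

        const SlotUsage slots[] = { {256, 100}, {512, 300}, {64, 20} };

        int kv_cache_free = n_ctx - n_tokens_system;              // 1920
        for (const auto & s : slots) {
            kv_cache_free -= s.n_prompt + s.n_tokens_predicted;   // charge each slot
        }

        if (kv_cache_free < 0) {
            printf("Error: kv cache is full, increase context size.\n");
        } else {
            printf("kv cache headroom: %d tokens\n", kv_cache_free);  // 668
        }
        return 0;
    }
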
@@ -759,6 +753,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
+            if(params.n_predict <= 128) { // this example don't support long prompts
+                params.n_predict = 128;
+            }
         } else if (arg == "-r" || arg == "--reverse-prompt")
         {
             if (++i >= argc)
@@ -858,7 +855,8 @@ int main(int argc, char **argv)
                 }
                 if(slot->hasNewToken()) { // new token notification
                     stringstream ss;
-                    json res_d = {{ "content", slot->sampled_token_str }};
+                    json res_d = {{ "content", slot->sampled_tokens.back() }};
+                    slot->sampled_tokens.pop_back();
                     ss << "data: " << res_d.dump() << "\n\n";
                     string result = ss.str();
                     if(!sink.write(result.c_str(), result.size())) {
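
Note: each drained token becomes one server-sent-events data frame, as in the handler above. A self-contained sketch of that framing follows; it uses nlohmann::json as the server does, and send_to_client is a hypothetical stand-in for the httplib sink.write call.

    // SSE framing sketch: one "data: {...}\n\n" frame per generated token.
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    static bool send_to_client(const std::string & frame) {  // placeholder transport
        std::cout << frame;
        return true;
    }

    int main() {
        // back() holds the oldest token, matching the slot's sampled_tokens order
        std::vector<std::string> sampled_tokens = { " world", "Hello" };

        while (!sampled_tokens.empty()) {
            json res_d = {{ "content", sampled_tokens.back() }};
            sampled_tokens.pop_back();

            std::stringstream ss;
            ss << "data: " << res_d.dump() << "\n\n";
            if (!send_to_client(ss.str())) {
                break;                                       // client disconnected
            }
        }
        return 0;
    }
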