Improve token generation logic and limits

FSSRepo 2023-10-06 18:22:07 -04:00
parent c1ac53fbdb
commit a8435c3e32


@@ -80,36 +80,33 @@ struct llama_client_slot
int32_t n_decoded = 0;
int32_t i_batch = -1;
string prompt = "";
string sampled_token_str;
string generated_text = "";
int n_tokens_predicted = 0;
llama_token sampled;
std::vector<string> sampled_tokens;
std::vector<llama_token> tokens_prev;
slot_state state = IDLE;
slot_command command = NONE;
bool newToken = false;
float temperature = 0.1f;
void start(string prompt_, float temp_) {
prompt = prompt_;
command = LOAD_PROMPT;
temperature = temp_;
newToken = false;
LOG_TEE("slot %i is processing\n", id);
}
bool hasNewToken() {
if(newToken) {
newToken = false;
return true;
}
return false;
return sampled_tokens.size() > 0;
}
bool available() {
return state == IDLE && command == NONE;
}
void nofity() {
newToken = !newToken;
void addTokenString(string token) {
sampled_tokens.insert(sampled_tokens.begin(), token);
n_tokens_predicted++;
}
void release() {
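
The slot now buffers generated tokens in a small queue instead of toggling a boolean flag. A minimal, self-contained sketch of that pattern follows; the type and function names here are illustrative, not taken from the commit. The decode loop inserts new token strings at the front, and the streaming side consumes from the back, so tokens come out in generation order.

#include <cstdio>
#include <string>
#include <vector>

struct token_queue {
    std::vector<std::string> sampled_tokens;
    int n_tokens_predicted = 0;

    // producer side: called once per sampled token
    void add_token_string(const std::string & token) {
        sampled_tokens.insert(sampled_tokens.begin(), token); // newest token at the front
        n_tokens_predicted++;
    }

    // consumer side: true while there is anything left to stream
    bool has_new_token() const {
        return !sampled_tokens.empty();
    }

    // the oldest token sits at the back, so FIFO order is preserved
    std::string next_token() {
        std::string token = sampled_tokens.back();
        sampled_tokens.pop_back();
        return token;
    }
};

int main() {
    token_queue q;
    q.add_token_string("Hello");
    q.add_token_string(", world");
    while (q.has_new_token()) {
        printf("%s", q.next_token().c_str()); // prints "Hello, world"
    }
    printf("\n");
    return 0;
}

A std::deque with push_front/pop_back would avoid the O(n) front insertion, but the vector mirrors what the server code itself does.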
@@ -163,7 +160,7 @@ struct server_parallel_context {
slot.id = i;
slot.prompt = "default";
slot.state = IDLE;
slot.tokens_prev.resize(std::max(256, params.n_predict));
slot.tokens_prev.resize(params.n_predict);
std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
LOG_TEE(" - slot %i\n", slot.id);
slots.push_back(slot);
@@ -247,7 +244,6 @@ struct server_parallel_context {
if ((slot_id == -1 && slot.available()) || slot.id == slot_id)
{
slot.start(prompt, temperature);
LOG_TEE("slot %i is processing\n", slot.id);
return &slot; // return a pointer to slot (thread safe?)
}
}
@@ -302,6 +298,7 @@ struct server_parallel_context {
}
batch.n_tokens = 0;
int kv_cache_free = (n_ctx - n_tokens_system);
// decode any currently ongoing sequences
for (auto & slot : slots) {
@@ -311,13 +308,17 @@ struct server_parallel_context {
llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
slot.state = IDLE;
slot.command = NONE;
slot.n_prompt = 0;
slot.n_tokens_predicted = 0;
continue;
}
kv_cache_free -= slot.n_prompt;
// no decode wait until the token had been send to client
// improves performance and avoid decoherence?
if (slot.state == IDLE || slot.newToken) {
if (slot.state == IDLE) {
continue;
}
@@ -339,12 +340,13 @@ struct server_parallel_context {
if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING;
slot.command = NONE;
//LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());
std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, slot.prompt, false);
slot.n_tokens_predicted = 0;
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch.token [batch.n_tokens] = tokens_prompt[i];
@@ -362,11 +364,6 @@ struct server_parallel_context {
slot.n_prompt = tokens_prompt.size();
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
// insert new requests one-by-one
//if (cont_batching) {
// break;
//}
}
}
}
@@ -379,13 +376,6 @@ struct server_parallel_context {
int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
// experiment: process in powers of 2
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
// n_batch /= 2;
// i -= n_batch;
// continue;
//}
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
@@ -431,18 +421,17 @@ struct server_parallel_context {
slot.sampled = id;
size_t stop_pos =
findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
slot.sampled_token_str = token_str;
// notify new token
slot.nofity();
slot.addTokenString(token_str);
kv_cache_free -= slot.n_tokens_predicted;
if (slot.n_decoded > 2 &&
(id == llama_token_eos(ctx) ||
(params.n_predict > 0 &&
slot.n_decoded + slot.n_prompt >=
params.n_predict) ||
stop_pos != std::string::npos)) {
(id == llama_token_eos(ctx) ||
(slot.n_decoded + slot.n_prompt >=
params.n_predict) ||
stop_pos != std::string::npos)) {
//LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
slot.generated_text.clear();
slot.release();
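
The end-of-generation check above collapses to three cases. A self-contained restatement with illustrative names (the should_stop helper is not part of the commit): generation only stops after a couple of tokens have been decoded, and then on an EOS token, on reaching the n_predict budget, or on a matched stopping string.

#include <cstddef>
#include <string>

using llama_token = int; // stand-in for the real llama_token typedef

// Mirrors the check in the diff: require a few decoded tokens first, then
// stop on EOS, on hitting the n_predict budget, or on a stopping string.
static bool should_stop(llama_token id, llama_token id_eos,
                        int n_decoded, int n_prompt, int n_predict,
                        std::size_t stop_pos) {
    if (n_decoded <= 2) {
        return false;
    }
    return id == id_eos
        || n_decoded + n_prompt >= n_predict
        || stop_pos != std::string::npos;
}

Note that the old params.n_predict > 0 guard is gone, which fits the new argument parsing further down that forces n_predict to at least 128.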
@@ -450,6 +439,11 @@ struct server_parallel_context {
slot.i_batch = -1;
}
}
if(kv_cache_free < 0) {
LOG_TEE("\nError: kv cache is full, increase context size.");
return false;
}
return true;
}
};
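
The new kv_cache_free bookkeeping reads as a simple budget: the context window minus the system prompt, minus each active slot's prompt and predicted tokens. A simplified, self-contained sketch of that accounting follows; the slot_usage struct and kv_cache_fits helper are illustrative, not from the commit, and the real code spreads the subtractions across the decode loop.

#include <cstdio>
#include <vector>

struct slot_usage {
    bool processing;         // slot currently has a sequence in the KV cache
    int  n_prompt;           // prompt tokens held for this slot
    int  n_tokens_predicted; // generated tokens held for this slot
};

// Returns false when the combined usage exceeds the context window,
// matching the "kv cache is full" error path in the diff.
static bool kv_cache_fits(int n_ctx, int n_tokens_system,
                          const std::vector<slot_usage> & slots) {
    int kv_cache_free = n_ctx - n_tokens_system;
    for (const auto & s : slots) {
        if (s.processing) {
            kv_cache_free -= s.n_prompt + s.n_tokens_predicted;
        }
    }
    if (kv_cache_free < 0) {
        fprintf(stderr, "\nError: kv cache is full, increase context size.\n");
        return false;
    }
    return true;
}

int main() {
    std::vector<slot_usage> slots = {
        { true, 64, 32 }, // slot 0: prompt plus generated tokens
        { true, 48, 16 }, // slot 1
    };
    // 512-token context, 96-token system prompt: 64+32+48+16 = 160 tokens fit
    printf("fits: %d\n", kv_cache_fits(512, 96, slots) ? 1 : 0);
    return 0;
}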
@@ -759,6 +753,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_predict = std::stoi(argv[i]);
if(params.n_predict <= 128) { // this example don't support long prompts
params.n_predict = 128;
}
} else if (arg == "-r" || arg == "--reverse-prompt")
{
if (++i >= argc)
@@ -858,7 +855,8 @@ int main(int argc, char **argv)
}
if(slot->hasNewToken()) { // new token notification
stringstream ss;
json res_d = {{ "content", slot->sampled_token_str }};
json res_d = {{ "content", slot->sampled_tokens.back() }};
slot->sampled_tokens.pop_back();
ss << "data: " << res_d.dump() << "\n\n";
string result = ss.str();
if(!sink.write(result.c_str(), result.size())) {
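
The streaming handler now pulls the oldest buffered token with back()/pop_back(), pairing with the front insertion in addTokenString, and wraps it in a server-sent-events chunk. A small sketch of that framing, assuming the nlohmann::json header already used elsewhere in the example (the make_sse_chunk helper is illustrative):

#include <sstream>
#include <string>
#include "json.hpp" // nlohmann::json, as used by the llama.cpp server examples

using json = nlohmann::json;

// One generated token becomes one "data: {...}\n\n" chunk on the event stream.
static std::string make_sse_chunk(const std::string & token) {
    json res_d = {{ "content", token }};
    std::stringstream ss;
    ss << "data: " << res_d.dump() << "\n\n";
    return ss.str();
}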