improved token gen logic and limits

parent c1ac53fbdb
commit a8435c3e32

1 changed file with 34 additions and 36 deletions
@@ -80,36 +80,33 @@ struct llama_client_slot
     int32_t n_decoded = 0;
     int32_t i_batch = -1;
     string prompt = "";
-    string sampled_token_str;
     string generated_text = "";
+    int n_tokens_predicted = 0;
     llama_token sampled;
+    std::vector<string> sampled_tokens;
     std::vector<llama_token> tokens_prev;
     slot_state state = IDLE;
     slot_command command = NONE;
-    bool newToken = false;
     float temperature = 0.1f;

     void start(string prompt_, float temp_) {
         prompt = prompt_;
         command = LOAD_PROMPT;
         temperature = temp_;
-        newToken = false;
+        LOG_TEE("slot %i is processing\n", id);
     }

     bool hasNewToken() {
-        if(newToken) {
-            newToken = false;
-            return true;
-        }
-        return false;
+        return sampled_tokens.size() > 0;
     }

     bool available() {
         return state == IDLE && command == NONE;
     }

-    void nofity() {
-        newToken = !newToken;
+    void addTokenString(string token) {
+        sampled_tokens.insert(sampled_tokens.begin(), token);
+        n_tokens_predicted++;
     }

     void release() {
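This hunk replaces the old `newToken` flip-flag with `sampled_tokens`, a vector used as a FIFO queue: `addTokenString` inserts at the front, so the oldest token sits at the back, where the consumer in `main` (last hunk below) reads and pops it. A minimal standalone sketch of that discipline, with illustrative names and only the standard library assumed:

// Sketch of the queue discipline: producer inserts at the front,
// consumer pops from the back, so tokens leave in generation order.
#include <cassert>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> sampled_tokens;

    // producer side (addTokenString): each new token goes in at the front
    for (const char * t : { "Hello", ",", " world" }) {
        sampled_tokens.insert(sampled_tokens.begin(), t);
    }

    // consumer side (the streaming loop): back() holds the oldest token
    assert(sampled_tokens.back() == std::string("Hello"));
    sampled_tokens.pop_back();
    assert(sampled_tokens.back() == std::string(","));
    return 0;
}

A `std::deque<string>` with `push_front()`/`pop_back()` would give the same FIFO behavior without the O(n) element shuffle that `insert(begin(), ...)` costs on every token; the vector form above mirrors the committed code.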
@@ -163,7 +160,7 @@ struct server_parallel_context {
             slot.id = i;
             slot.prompt = "default";
             slot.state = IDLE;
-            slot.tokens_prev.resize(std::max(256, params.n_predict));
+            slot.tokens_prev.resize(params.n_predict);
             std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
             LOG_TEE(" - slot %i\n", slot.id);
             slots.push_back(slot);
@@ -247,7 +244,6 @@ struct server_parallel_context {
            if ((slot_id == -1 && slot.available()) || slot.id == slot_id)
            {
                slot.start(prompt, temperature);
-               LOG_TEE("slot %i is processing\n", slot.id);
                return &slot; // return a pointer to slot (thread safe?)
            }
        }
@@ -302,6 +298,7 @@ struct server_parallel_context {
        }

        batch.n_tokens = 0;
+       int kv_cache_free = (n_ctx - n_tokens_system);

        // decode any currently ongoing sequences
        for (auto & slot : slots) {
@@ -311,13 +308,17 @@ struct server_parallel_context {
                llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
                slot.state = IDLE;
                slot.command = NONE;
+               slot.n_prompt = 0;
+               slot.n_tokens_predicted = 0;
                continue;
            }

+           kv_cache_free -= slot.n_prompt;

            // no decode wait until the token had been send to client
            // improves performance and avoid decoherence?

-           if (slot.state == IDLE || slot.newToken) {
+           if (slot.state == IDLE) {
                continue;
            }

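The new `kv_cache_free` counter starts each update at `n_ctx - n_tokens_system` and is debited with every active slot's prompt length here and with its predicted-token count further down; a negative balance means the KV cache cannot hold the active sequences, and the update returns false (see the check added before `return true` below). A hedged standalone sketch of that accounting, where `SlotUsage` and `kv_cache_fits` are hypothetical names, not from the commit:

// Sketch of the KV-cache budget: subtract what each slot holds from the
// context budget left after the shared system prompt.
#include <vector>

struct SlotUsage {
    int n_prompt;            // prompt tokens held in the KV cache
    int n_tokens_predicted;  // generated tokens held in the KV cache
};

bool kv_cache_fits(int n_ctx, int n_tokens_system,
                   const std::vector<SlotUsage> & slots) {
    int kv_cache_free = n_ctx - n_tokens_system;  // budget after system prompt
    for (const auto & s : slots) {
        kv_cache_free -= s.n_prompt + s.n_tokens_predicted;
    }
    return kv_cache_free >= 0;  // negative means the cache would overflow
}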
@@ -339,12 +340,13 @@ struct server_parallel_context {
            if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
                slot.state = PROCESSING;
                slot.command = NONE;
-               //LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());
                std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);

                // do not prepend BOS because we have a system prompt!
                std::vector<llama_token> tokens_prompt;
                tokens_prompt = ::llama_tokenize(ctx, slot.prompt, false);
+               slot.n_tokens_predicted = 0;

                for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                    batch.token [batch.n_tokens] = tokens_prompt[i];
@@ -362,11 +364,6 @@ struct server_parallel_context {
                slot.n_prompt = tokens_prompt.size();
                slot.n_decoded = 0;
                slot.i_batch = batch.n_tokens - 1;
-
-               // insert new requests one-by-one
-               //if (cont_batching) {
-               //    break;
-               //}
            }
        }
    }
@@ -379,13 +376,6 @@ struct server_parallel_context {
        int32_t n_batch = params.n_batch;

        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-           // experiment: process in powers of 2
-           //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
-           //    n_batch /= 2;
-           //    i -= n_batch;
-           //    continue;
-           //}
-
            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

            llama_batch batch_view = {
@@ -431,18 +421,17 @@ struct server_parallel_context {
                slot.sampled = id;

                size_t stop_pos =
                    findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);

-               slot.sampled_token_str = token_str;
-               // notify new token
-               slot.nofity();
+               slot.addTokenString(token_str);
+
+               kv_cache_free -= slot.n_tokens_predicted;

                if (slot.n_decoded > 2 &&
                    (id == llama_token_eos(ctx) ||
-                   (params.n_predict > 0 &&
-                   slot.n_decoded + slot.n_prompt >=
+                   (slot.n_decoded + slot.n_prompt >=
                    params.n_predict) ||
                    stop_pos != std::string::npos)) {
                    //LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
                    slot.generated_text.clear();
                    slot.release();
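The rewritten stop condition drops the old `params.n_predict > 0` guard: generation now ends when EOS is sampled, when prompt plus generated tokens reach `params.n_predict`, or when a stop string is found, but never before three tokens have been decoded. The same test expressed as a hypothetical free function, where `eos_token` stands in for `llama_token_eos(ctx)` and `stop_pos` for the `findStoppingStrings` result:

// Sketch of the stop predicate used in the condition above.
#include <cstddef>
#include <string>

bool should_stop(int n_decoded, int n_prompt, int n_predict,
                 int token, int eos_token, size_t stop_pos) {
    return n_decoded > 2 &&                       // never stop on the first tokens
           (token == eos_token ||                 // end-of-sequence sampled
            n_decoded + n_prompt >= n_predict ||  // length budget exhausted
            stop_pos != std::string::npos);       // stop string matched
}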
@@ -450,6 +439,11 @@ struct server_parallel_context {
                slot.i_batch = -1;
            }
        }

+       if(kv_cache_free < 0) {
+           LOG_TEE("\nError: kv cache is full, increase context size.");
+           return false;
+       }
        return true;
    }
};
@@ -759,6 +753,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            break;
        }
        params.n_predict = std::stoi(argv[i]);
+       if(params.n_predict <= 128) { // this example don't support long prompts
+           params.n_predict = 128;
+       }
    } else if (arg == "-r" || arg == "--reverse-prompt")
    {
        if (++i >= argc)
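Note that this guard raises any `--n-predict` value of 128 or less up to 128 while letting larger values pass through unchanged, i.e. it enforces a floor rather than a cap. An equivalent formulation, as a hypothetical helper:

// Hypothetical equivalent of the guard added above: a floor of 128.
#include <algorithm>

int clamp_n_predict(int n_predict) {
    return std::max(128, n_predict);
}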
@@ -858,7 +855,8 @@ int main(int argc, char **argv)
            }
            if(slot->hasNewToken()) { // new token notification
                stringstream ss;
-               json res_d = {{ "content", slot->sampled_token_str }};
+               json res_d = {{ "content", slot->sampled_tokens.back() }};
+               slot->sampled_tokens.pop_back();
                ss << "data: " << res_d.dump() << "\n\n";
                string result = ss.str();
                if(!sink.write(result.c_str(), result.size())) {
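On the consumer side, the oldest token is read from the back of `sampled_tokens`, popped, and wrapped in a server-sent-events frame: a `data: ` prefix, the JSON payload, and a blank line that terminates the event. A minimal sketch of just the framing (the `json` type and the streaming `sink` come from the example's HTTP/JSON dependencies and are omitted here):

// Build one SSE event carrying an already-serialized JSON payload.
#include <sstream>
#include <string>

std::string sse_frame(const std::string & json_payload) {
    std::stringstream ss;
    ss << "data: " << json_payload << "\n\n";  // blank line ends the SSE event
    return ss.str();
}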