Merge pull request #8 from SlyEcho/server_refactor

Change how the token buffers work.
Randall Fitzgerald 2023-05-31 18:03:40 -04:00 committed by GitHub
commit 8478e59b08
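
In short, the patch below drops the separate processed_tokens, embd_inp, and last_prompt_tokens buffers and lets embd alone hold every token that has been evaluated. When a new request comes in, the new common_part() helper finds the longest shared prefix between embd and the freshly tokenized prompt, and only the differing tail has to be re-evaluated. A minimal standalone sketch of that idea, not the server code itself (llama_token is aliased to int here purely for illustration, and the token values are made up):

#include <cstdio>
#include <vector>

using llama_token = int; // stand-in for the real typedef

// length of the shared prefix of two token sequences (mirrors common_part in the diff)
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
    return i;
}

int main() {
    std::vector<llama_token> evaluated  = {1, 15, 27, 8, 4};   // tokens already evaluated (embd)
    std::vector<llama_token> new_prompt = {1, 15, 27, 99, 3};  // tokens of the incoming request

    size_t n_past = common_part(evaluated, new_prompt);        // 3: the cached prefix is reused
    printf("reuse %zu tokens, re-evaluate %zu\n", n_past, new_prompt.size() - n_past);
    return 0;
}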


@@ -14,6 +14,12 @@ struct server_params
     bool verbose = false;
 };
 
+static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
+    return i;
+}
+
 struct llama_server_context
 {
     bool stream = false;
@@ -28,10 +34,7 @@ struct llama_server_context
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
-    std::vector<llama_token> processed_tokens;
-    std::vector<llama_token> embd_inp;
-    std::vector<llama_token> last_prompt_tokens;
 
     llama_context *ctx = nullptr;
     gpt_params params;
@@ -55,11 +58,10 @@ struct llama_server_context
         generated_text.reserve(params.n_ctx);
         stopping_word = "";
-        //processed_tokens.clear();
-        embd_inp.clear();
         n_remain = 0;
         n_past = 0;
         n_consumed = 0;
+        last_n_tokens.clear();
     }
 
     bool loadModel(const gpt_params &params_)
@@ -80,58 +82,61 @@ struct llama_server_context
     bool loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
-        if (prompt_tokens == last_prompt_tokens)
-        {
-            embd.clear();
-        }
+
+        if (params.n_keep < 0) {
+            params.n_keep = (int)prompt_tokens.size();
+        }
+        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
+            prompt_tokens = new_tokens;
+        }
+
         // compare the evaluated prompt with the new prompt
-        for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
-            if (prompt_tokens[n_past] != processed_tokens[n_past]) {
-                break;
-            }
-        }
-        processed_tokens.resize(n_past);
-        if (prompt_tokens.size() > n_past) {
-            embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());
-        }
-        last_prompt_tokens = prompt_tokens;
+        n_past = common_part(embd, prompt_tokens);
+        embd = prompt_tokens;
+        if (n_past == prompt_tokens.size()) {
+            // we have to evaluate at least 1 token to generate logits.
+            n_past--;
+        }
         has_next_token = true;
         return true;
     }
 
     void beginCompletion()
     {
-        if(n_remain == 0) {
-            // number of tokens to keep when resetting context
-            if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
-            {
-                params.n_keep = (int)embd_inp.size();
-            }
-        }
+        // number of tokens to keep when resetting context
         n_remain = params.n_predict;
         llama_set_rng_seed(ctx, params.seed);
     }
 
     llama_token nextToken() {
         llama_token result = -1;
-        if (embd.size() > 0)
-        {
-            if (n_past + embd.size() > (size_t)params.n_ctx)
-            {
-                // Reset context
-                const int n_left = n_past - params.n_keep;
-                n_past = std::max(1, params.n_keep);
-                //processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
-                embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
-            }
-            for (int i = 0; i < (int)embd.size(); i += params.n_batch)
-            {
-                int n_eval = (int)embd.size() - i;
-                if (n_eval > params.n_batch)
-                {
-                    n_eval = params.n_batch;
-                }
-                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
-                {
-                    fprintf(stderr, "%s : failed to eval\n", __func__);
-                    has_next_token = false;
+
+        if (embd.size() >= (size_t)params.n_ctx) {
+            // Reset context
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+
+            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
+            embd = new_tokens;
+            n_past = params.n_keep;
+        }
+
+        while (n_past < embd.size())
+        {
+            int n_eval = (int)embd.size() - n_past;
+            if (n_eval > params.n_batch)
+            {
+                n_eval = params.n_batch;
+            }
+            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
+            {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                has_next_token = false;
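
Both loadPrompt() above and the reset branch of nextToken() apply the same truncation rule when a token buffer would overflow the context window: keep the first n_keep tokens, append the last (n_ctx - n_keep)/2 tokens, and drop the middle. A standalone sketch of that arithmetic, using a hypothetical truncate() helper and made-up sizes rather than the server's own code:

#include <cstdio>
#include <vector>

using llama_token = int; // stand-in for the real typedef

// Keep the first n_keep tokens, then the last (n_ctx - n_keep)/2 tokens, as in the diff.
static std::vector<llama_token> truncate(const std::vector<llama_token> & tokens, int n_ctx, int n_keep) {
    if (tokens.size() < (size_t)n_ctx) {
        return tokens; // fits, nothing to do
    }
    const int n_left = (n_ctx - n_keep) / 2;
    std::vector<llama_token> out(tokens.begin(), tokens.begin() + n_keep);
    out.insert(out.end(), tokens.end() - n_left, tokens.end());
    return out;
}

int main() {
    std::vector<llama_token> tokens(600);
    for (int i = 0; i < 600; i++) tokens[i] = i;

    // with n_ctx = 512 and n_keep = 64 the result is 64 + 224 = 288 tokens
    std::vector<llama_token> kept = truncate(tokens, 512, 64);
    printf("%zu tokens kept: head ends at %d, tail starts at %d\n",
           kept.size(), kept[63], kept[64]);
    return 0;
}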
@@ -139,10 +144,7 @@ struct llama_server_context
            }
            n_past += n_eval;
        }
-        }
-        embd.clear();
-        if (embd_inp.size() <= n_consumed)
-        {
+
        // out of user input, sample next token
        const float temp = params.temp;
        const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
@@ -224,7 +226,6 @@ struct llama_server_context
            }
            last_n_tokens.erase(last_n_tokens.begin());
            last_n_tokens.push_back(id);
-            processed_tokens.push_back(id);
            num_tokens_predicted++;
        }
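
For reference, the last_n_tokens buffer that the removed processed_tokens line sat next to is a fixed-size sliding window: each new token evicts the oldest one (elsewhere in this file the window feeds the repetition-penalty sampling). A tiny sketch of that pattern, with an invented window size and token values:

#include <cstdio>
#include <vector>

using llama_token = int; // stand-in for the real typedef

int main() {
    // last_n_tokens as a fixed-size window: pop the oldest token, push the newest
    std::vector<llama_token> last_n_tokens(4, 0); // window of 4, initially padding

    const llama_token incoming[] = {11, 22, 33, 44, 55};
    for (llama_token id : incoming) {
        last_n_tokens.erase(last_n_tokens.begin());
        last_n_tokens.push_back(id);
    }

    for (llama_token t : last_n_tokens) printf("%d ", t); // prints: 22 33 44 55
    printf("\n");
    return 0;
}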
@@ -233,23 +234,6 @@ struct llama_server_context
            result = id;
            // decrement remaining sampling budget
            --n_remain;
-        }
-        else
-        {
-            // some user input remains from prompt or interaction, forward it to processing
-            while (embd_inp.size() > n_consumed)
-            {
-                embd.push_back(embd_inp[n_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[n_consumed]);
-                processed_tokens.push_back(embd_inp[n_consumed]);
-                ++n_consumed;
-                if ((int)embd.size() >= params.n_batch)
-                {
-                    break;
-                }
-            }
-        }
        if (!embd.empty() && embd.back() == llama_token_eos()) {
            stopping_word = llama_token_to_str(ctx, embd.back());