Merge pull request #8 from SlyEcho/server_refactor
Change how the token buffers work.
commit 8478e59b08
1 changed file with 131 additions and 147 deletions
@@ -14,6 +14,12 @@ struct server_params
     bool verbose = false;
 };
 
+static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
+    return i;
+}
+
 struct llama_server_context
 {
     bool stream = false;
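The new common_part() helper returns the length of the longest shared prefix of two token vectors; the refactor uses it to decide how much of an already-evaluated buffer can be reused. A minimal, self-contained sketch of the idea (plain int stands in for llama_token, and the token ids are made up for illustration):

#include <cstdio>
#include <vector>

static size_t common_part(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
    return i;
}

int main() {
    std::vector<int> evaluated  = {1, 15043, 3186, 29991};       // tokens already in the KV cache
    std::vector<int> new_prompt = {1, 15043, 3186, 29892, 3186}; // incoming prompt
    size_t n_past = common_part(evaluated, new_prompt);
    // only the suffix after the shared prefix needs re-evaluation
    printf("reuse %zu tokens, evaluate %zu\n", n_past, new_prompt.size() - n_past);
    return 0; // prints: reuse 3 tokens, evaluate 2
}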
@@ -28,10 +34,7 @@ struct llama_server_context
 
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
-    std::vector<llama_token> processed_tokens;
-    std::vector<llama_token> embd_inp;
 
-    std::vector<llama_token> last_prompt_tokens;
     llama_context *ctx = nullptr;
     gpt_params params;
 
@@ -55,11 +58,10 @@ struct llama_server_context
         generated_text.reserve(params.n_ctx);
         stopping_word = "";
 
-        //processed_tokens.clear();
-        embd_inp.clear();
         n_remain = 0;
         n_past = 0;
         n_consumed = 0;
+        last_n_tokens.clear();
     }
 
     bool loadModel(const gpt_params &params_)
@@ -80,58 +82,61 @@ struct llama_server_context
     bool loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
-        if (prompt_tokens == last_prompt_tokens)
-        {
-            embd.clear();
+        if (params.n_keep < 0) {
+            params.n_keep = (int)prompt_tokens.size();
         }
+        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
+            prompt_tokens = new_tokens;
+        }
+
         // compare the evaluated prompt with the new prompt
-        for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
-            if (prompt_tokens[n_past] != processed_tokens[n_past]) {
-                break;
+        n_past = common_part(embd, prompt_tokens);
+        embd = prompt_tokens;
+        if (n_past == prompt_tokens.size()) {
+            // we have to evaluate at least 1 token to generate logits.
+            n_past--;
         }
-        }
-        processed_tokens.resize(n_past);
-        if (prompt_tokens.size() > n_past) {
-            embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());
-        }
-        last_prompt_tokens = prompt_tokens;
         has_next_token = true;
         return true;
     }
 
     void beginCompletion()
     {
-        if(n_remain == 0) {
         // number of tokens to keep when resetting context
-        if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
-        {
-            params.n_keep = (int)embd_inp.size();
-        }
-        }
         n_remain = params.n_predict;
         llama_set_rng_seed(ctx, params.seed);
     }
 
     llama_token nextToken() {
         llama_token result = -1;
-        if (embd.size() > 0)
-        {
-            if (n_past + embd.size() > (size_t)params.n_ctx)
-            {
+        if (embd.size() >= (size_t)params.n_ctx) {
             // Reset context
-            const int n_left = n_past - params.n_keep;
-            n_past = std::max(1, params.n_keep);
-            //processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
-            embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
+            embd = new_tokens;
+            n_past = params.n_keep;
         }
-        for (int i = 0; i < (int)embd.size(); i += params.n_batch)
+
+        while (n_past < embd.size())
         {
-            int n_eval = (int)embd.size() - i;
+            int n_eval = (int)embd.size() - n_past;
             if (n_eval > params.n_batch)
             {
                 n_eval = params.n_batch;
             }
-            if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
+            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
             {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 has_next_token = false;
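Both loadPrompt() and the context reset in nextToken() now share one truncation scheme: keep the first params.n_keep tokens and the most recent n_left = (n_ctx - n_keep)/2 tokens, dropping the middle. A standalone sketch of that arithmetic (truncate_tokens is a hypothetical helper written for illustration, not part of the patch):

#include <cassert>
#include <vector>

// Hypothetical helper mirroring the patch's truncation: a head of n_keep
// tokens plus the most recent half of the remaining context window.
std::vector<int> truncate_tokens(const std::vector<int> & tokens, int n_ctx, int n_keep) {
    if ((int)tokens.size() < n_ctx) return tokens;   // fits, nothing to do
    const int n_left = (n_ctx - n_keep) / 2;
    std::vector<int> out(tokens.begin(), tokens.begin() + n_keep); // keep head verbatim
    out.insert(out.end(), tokens.end() - n_left, tokens.end());    // keep newest tail
    return out;
}

int main() {
    std::vector<int> toks(600, 7);                        // 600 dummy token ids
    assert(truncate_tokens(toks, 512, 64).size() == 288); // 64 + (512 - 64)/2
    return 0;
}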
@@ -139,10 +144,7 @@ struct llama_server_context
             }
             n_past += n_eval;
         }
-        }
-        embd.clear();
-        if (embd_inp.size() <= n_consumed)
-        {
+
         // out of user input, sample next token
         const float temp = params.temp;
         const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
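The eval loop above now advances n_past directly through embd in chunks of at most params.n_batch, so only the unevaluated suffix is fed to the model. A simplified sketch of the chunking (eval_chunk is a stand-in for llama_eval, printing instead of evaluating):

#include <cstdio>
#include <vector>

// Stand-in for llama_eval(): report which chunk would be evaluated.
static bool eval_chunk(const int * /*tokens*/, int n_eval, int n_past) {
    printf("eval %d tokens at position %d\n", n_eval, n_past);
    return true;
}

int main() {
    std::vector<int> embd(10, 42);   // token buffer (dummy ids)
    size_t n_past = 3;               // prefix already in the KV cache
    const int n_batch = 4;
    while (n_past < embd.size()) {
        int n_eval = (int)embd.size() - (int)n_past;
        if (n_eval > n_batch) {
            n_eval = n_batch;        // cap each chunk at the batch size
        }
        if (!eval_chunk(&embd[n_past], n_eval, (int)n_past)) {
            return 1;                // evaluation failed
        }
        n_past += n_eval;            // advance past the evaluated chunk
    }
    return 0;                        // prints chunks at positions 3 and 7
}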
@@ -224,7 +226,6 @@ struct llama_server_context
             }
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(id);
-            processed_tokens.push_back(id);
             num_tokens_predicted++;
         }
 
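last_n_tokens acts as a fixed-size sliding window over recent output, used for the sampling penalties: each sampled id evicts the oldest entry and is appended at the back. A minimal sketch of that update, assuming the window has already been filled to its fixed length:

#include <vector>

// Fixed-size sliding-window update, as done for last_n_tokens:
// drop the oldest token, append the newest sampled id.
void push_token(std::vector<int> & window, int id) {
    window.erase(window.begin());
    window.push_back(id);
}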
@@ -233,23 +234,6 @@ struct llama_server_context
             result = id;
             // decrement remaining sampling budget
             --n_remain;
-        }
-        else
-        {
-            // some user input remains from prompt or interaction, forward it to processing
-            while (embd_inp.size() > n_consumed)
-            {
-                embd.push_back(embd_inp[n_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[n_consumed]);
-                processed_tokens.push_back(embd_inp[n_consumed]);
-                ++n_consumed;
-                if ((int)embd.size() >= params.n_batch)
-                {
-                    break;
-                }
-            }
-        }
 
         if (!embd.empty() && embd.back() == llama_token_eos()) {
             stopping_word = llama_token_to_str(ctx, embd.back());