Merge pull request #8 from SlyEcho/server_refactor

Change how the token buffers work.
This commit is contained in:
Randall Fitzgerald 2023-05-31 18:03:40 -04:00 committed by GitHub
commit 8478e59b08
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -14,6 +14,12 @@ struct server_params
bool verbose = false;
};
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
return i;
}
struct llama_server_context
{
bool stream = false;
@ -28,10 +34,7 @@ struct llama_server_context
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
std::vector<llama_token> processed_tokens;
std::vector<llama_token> embd_inp;
std::vector<llama_token> last_prompt_tokens;
llama_context *ctx = nullptr;
gpt_params params;
@ -55,11 +58,10 @@ struct llama_server_context
generated_text.reserve(params.n_ctx);
stopping_word = "";
//processed_tokens.clear();
embd_inp.clear();
n_remain = 0;
n_past = 0;
n_consumed = 0;
last_n_tokens.clear();
}
bool loadModel(const gpt_params &params_)
@ -80,176 +82,158 @@ struct llama_server_context
bool loadPrompt() {
params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
if (prompt_tokens == last_prompt_tokens)
{
embd.clear();
if (params.n_keep < 0) {
params.n_keep = (int)prompt_tokens.size();
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
if (prompt_tokens.size() >= (size_t)params.n_ctx) {
const int n_left = (params.n_ctx - params.n_keep)/2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
prompt_tokens = new_tokens;
}
// compare the evaluated prompt with the new prompt
for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
if (prompt_tokens[n_past] != processed_tokens[n_past]) {
break;
}
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == prompt_tokens.size()) {
// we have to evaluate at least 1 token to generate logits.
n_past--;
}
processed_tokens.resize(n_past);
if (prompt_tokens.size() > n_past) {
embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());
}
last_prompt_tokens = prompt_tokens;
has_next_token = true;
return true;
}
void beginCompletion()
{
if(n_remain == 0) {
// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
{
params.n_keep = (int)embd_inp.size();
}
}
// number of tokens to keep when resetting context
n_remain = params.n_predict;
llama_set_rng_seed(ctx, params.seed);
}
llama_token nextToken() {
llama_token result = -1;
if (embd.size() > 0)
{
if (n_past + embd.size() > (size_t)params.n_ctx)
{
// Reset context
const int n_left = n_past - params.n_keep;
n_past = std::max(1, params.n_keep);
//processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
}
for (int i = 0; i < (int)embd.size(); i += params.n_batch)
{
int n_eval = (int)embd.size() - i;
if (n_eval > params.n_batch)
{
n_eval = params.n_batch;
}
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
{
fprintf(stderr, "%s : failed to eval\n", __func__);
has_next_token = false;
return result;
}
n_past += n_eval;
}
if (embd.size() >= (size_t)params.n_ctx) {
// Reset context
const int n_left = (params.n_ctx - params.n_keep)/2;
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
embd = new_tokens;
n_past = params.n_keep;
}
embd.clear();
if (embd_inp.size() <= n_consumed)
while (n_past < embd.size())
{
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
int n_eval = (int)embd.size() - n_past;
if (n_eval > params.n_batch)
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
n_eval = params.n_batch;
}
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
{
fprintf(stderr, "%s : failed to eval\n", __func__);
has_next_token = false;
return result;
}
n_past += n_eval;
}
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
{
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
logits[llama_token_nl()] = nl_logit;
}
if (temp <= 0)
{
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
}
else
{
if (mirostat == 1)
{
logits[it->first] += it->second;
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
else if (mirostat == 2)
{
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
logits[llama_token_nl()] = nl_logit;
}
if (temp <= 0)
{
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
{
if (mirostat == 1)
{
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}
else if (mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
{
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
processed_tokens.push_back(id);
num_tokens_predicted++;
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
num_tokens_predicted++;
}
// add it to the context
embd.push_back(id);
result = id;
// decrement remaining sampling budget
--n_remain;
}
else
{
// some user input remains from prompt or interaction, forward it to processing
while (embd_inp.size() > n_consumed)
{
embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[n_consumed]);
processed_tokens.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int)embd.size() >= params.n_batch)
{
break;
}
}
}
// add it to the context
embd.push_back(id);
result = id;
// decrement remaining sampling budget
--n_remain;
if (!embd.empty() && embd.back() == llama_token_eos()) {
stopping_word = llama_token_to_str(ctx, embd.back());