diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index beb335dfa..71d93cb33 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -14,6 +14,12 @@ struct server_params
     bool verbose = false;
 };
 
+static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
+    return i;
+}
+
 struct llama_server_context
 {
     bool stream = false;
@@ -28,10 +34,7 @@ struct llama_server_context
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
-    std::vector<llama_token> processed_tokens;
-    std::vector<llama_token> embd_inp;
 
-    std::vector<llama_token> last_prompt_tokens;
     llama_context *ctx = nullptr;
     gpt_params params;
 
@@ -55,11 +58,10 @@ struct llama_server_context
         generated_text.reserve(params.n_ctx);
         stopping_word = "";
 
-        //processed_tokens.clear();
-        embd_inp.clear();
         n_remain = 0;
         n_past = 0;
         n_consumed = 0;
+        last_n_tokens.clear();
     }
 
     bool loadModel(const gpt_params &params_)
@@ -80,176 +82,158 @@ struct llama_server_context
     bool loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
-        if (prompt_tokens == last_prompt_tokens)
-        {
-            embd.clear();
+
+        if (params.n_keep < 0) {
+            params.n_keep = (int)prompt_tokens.size();
         }
+        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
+            prompt_tokens = new_tokens;
+        }
+
         // compare the evaluated prompt with the new prompt
-        for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
-            if (prompt_tokens[n_past] != processed_tokens[n_past]) {
-                break;
-            }
+        n_past = common_part(embd, prompt_tokens);
+        embd = prompt_tokens;
+        if (n_past == prompt_tokens.size()) {
+            // we have to evaluate at least 1 token to generate logits.
+            n_past--;
         }
-        processed_tokens.resize(n_past);
-        if (prompt_tokens.size() > n_past) {
-            embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());
-        }
-        last_prompt_tokens = prompt_tokens;
         has_next_token = true;
         return true;
     }
 
     void beginCompletion() {
-        if(n_remain == 0) {
-            // number of tokens to keep when resetting context
-            if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
-            {
-                params.n_keep = (int)embd_inp.size();
-            }
-        }
+        // number of tokens to keep when resetting context
+
+        n_remain = params.n_predict;
         llama_set_rng_seed(ctx, params.seed);
     }
 
     llama_token nextToken() {
         llama_token result = -1;
-        if (embd.size() > 0)
-        {
-            if (n_past + embd.size() > (size_t)params.n_ctx)
-            {
-                // Reset context
-                const int n_left = n_past - params.n_keep;
-                n_past = std::max(1, params.n_keep);
-                //processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
-                embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
-            }
-            for (int i = 0; i < (int)embd.size(); i += params.n_batch)
-            {
-                int n_eval = (int)embd.size() - i;
-                if (n_eval > params.n_batch)
-                {
-                    n_eval = params.n_batch;
-                }
-                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
-                {
-                    fprintf(stderr, "%s : failed to eval\n", __func__);
-                    has_next_token = false;
-                    return result;
-                }
-                n_past += n_eval;
-            }
+
+        if (embd.size() >= (size_t)params.n_ctx) {
+            // Reset context
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+
+            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
+            embd = new_tokens;
+            n_past = params.n_keep;
         }
-        embd.clear();
-        if (embd_inp.size() <= n_consumed)
+
+        while (n_past < embd.size())
         {
-            // out of user input, sample next token
-            const float temp = params.temp;
-            const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
-            const float top_p = params.top_p;
-            const float tfs_z = params.tfs_z;
-            const float typical_p = params.typical_p;
-            const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
-            const float repeat_penalty = params.repeat_penalty;
-            const float alpha_presence = params.presence_penalty;
-            const float alpha_frequency = params.frequency_penalty;
-            const int mirostat = params.mirostat;
-            const float mirostat_tau = params.mirostat_tau;
-            const float mirostat_eta = params.mirostat_eta;
-            const bool penalize_nl = params.penalize_nl;
-            llama_token id = 0;
+            int n_eval = (int)embd.size() - n_past;
+            if (n_eval > params.n_batch)
             {
-                auto logits = llama_get_logits(ctx);
-                auto n_vocab = llama_n_vocab(ctx);
+                n_eval = params.n_batch;
+            }
+            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
+            {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                has_next_token = false;
+                return result;
+            }
+            n_past += n_eval;
+        }
 
-                // Apply params.logit_bias map
-                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
+        // out of user input, sample next token
+        const float temp = params.temp;
+        const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+        const float top_p = params.top_p;
+        const float tfs_z = params.tfs_z;
+        const float typical_p = params.typical_p;
+        const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
+        const float repeat_penalty = params.repeat_penalty;
+        const float alpha_presence = params.presence_penalty;
+        const float alpha_frequency = params.frequency_penalty;
+        const int mirostat = params.mirostat;
+        const float mirostat_tau = params.mirostat_tau;
+        const float mirostat_eta = params.mirostat_eta;
+        const bool penalize_nl = params.penalize_nl;
+        llama_token id = 0;
+        {
+            auto logits = llama_get_logits(ctx);
+            auto n_vocab = llama_n_vocab(ctx);
+
+            // Apply params.logit_bias map
+            for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
+            {
+                logits[it->first] += it->second;
+            }
+
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++)
+            {
+                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            }
+
+            llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
+
+            // Apply penalties
+            float nl_logit = logits[llama_token_nl()];
+            auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
+            llama_sample_repetition_penalty(ctx, &candidates_p,
+                last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+            llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+                last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+            if (!penalize_nl)
+            {
+                logits[llama_token_nl()] = nl_logit;
+            }
+
+            if (temp <= 0)
+            {
+                // Greedy sampling
+                id = llama_sample_token_greedy(ctx, &candidates_p);
+            }
+            else
+            {
+                if (mirostat == 1)
                 {
-                    logits[it->first] += it->second;
+                    static float mirostat_mu = 2.0f * mirostat_tau;
+                    const int mirostat_m = 100;
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
                 }
-
-                std::vector<llama_token_data> candidates;
-                candidates.reserve(n_vocab);
-                for (llama_token token_id = 0; token_id < n_vocab; token_id++)
+                else if (mirostat == 2)
                 {
-                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                }
-
-                llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
-
-                // Apply penalties
-                float nl_logit = logits[llama_token_nl()];
-                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
-                llama_sample_repetition_penalty(ctx, &candidates_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, repeat_penalty);
-                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, alpha_frequency, alpha_presence);
-                if (!penalize_nl)
-                {
-                    logits[llama_token_nl()] = nl_logit;
-                }
-
-                if (temp <= 0)
-                {
-                    // Greedy sampling
-                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                    static float mirostat_mu = 2.0f * mirostat_tau;
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 }
                 else
                 {
-                    if (mirostat == 1)
-                    {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        const int mirostat_m = 100;
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-                    }
-                    else if (mirostat == 2)
-                    {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-                    }
-                    else
-                    {
-                        // Temperature sampling
-                        llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                        llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                        llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-                        llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token(ctx, &candidates_p);
-                    }
+                    // Temperature sampling
+                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
+                    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
+                    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token(ctx, &candidates_p);
                 }
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
-                processed_tokens.push_back(id);
-                num_tokens_predicted++;
             }
+            last_n_tokens.erase(last_n_tokens.begin());
+            last_n_tokens.push_back(id);
+            num_tokens_predicted++;
+        }
 
-            // add it to the context
-            embd.push_back(id);
-            result = id;
-            // decrement remaining sampling budget
-            --n_remain;
-        }
-        else
-        {
-            // some user input remains from prompt or interaction, forward it to processing
-            while (embd_inp.size() > n_consumed)
-            {
-                embd.push_back(embd_inp[n_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[n_consumed]);
-                processed_tokens.push_back(embd_inp[n_consumed]);
-                ++n_consumed;
-                if ((int)embd.size() >= params.n_batch)
-                {
-                    break;
-                }
-            }
-        }
+        // add it to the context
+        embd.push_back(id);
+        result = id;
+        // decrement remaining sampling budget
+        --n_remain;
 
         if (!embd.empty() && embd.back() == llama_token_eos()) {
             stopping_word = llama_token_to_str(ctx, embd.back());
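
Note on the core idea of this change: the rewrite drops processed_tokens/embd_inp/last_prompt_tokens and instead keeps the evaluated tokens directly in embd. On each new request, common_part() finds the longest shared prefix between the cached tokens and the new prompt, so only the differing tail is re-evaluated. A minimal standalone sketch of that idea (plain ints stand in for llama_token; the values are made up, and this is not part of the patch):

#include <cstdio>
#include <vector>

// same helper as in the patch, with int standing in for llama_token
static size_t common_part(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
    return i;
}

int main() {
    std::vector<int> cached = {1, 500, 42, 7};  // tokens already in the KV cache
    std::vector<int> prompt = {1, 500, 42, 9};  // new request shares a 3-token prefix

    size_t n_past = common_part(cached, prompt);
    if (n_past == prompt.size()) {
        // identical prompt: back up one token so at least one token
        // is evaluated and fresh logits are produced
        n_past--;
    }
    printf("reuse %zu tokens, re-evaluate %zu\n", n_past, prompt.size() - n_past);
    return 0;  // prints: reuse 3 tokens, re-evaluate 1
}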
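
The same truncation rule now appears twice: loadPrompt() applies it to an oversized prompt, and nextToken() applies it when generation fills the context. In both cases the first n_keep tokens stay pinned and only the most recent n_left = (n_ctx - n_keep)/2 tokens are carried over; the middle is dropped. A toy illustration (the n_ctx and n_keep values here are arbitrary):

#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 8;  // toy context window
    const int n_keep = 2;  // tokens pinned at the front (e.g. a system prompt)

    std::vector<int> tokens = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};  // 10 tokens > n_ctx
    if ((int)tokens.size() >= n_ctx) {
        const int n_left = (n_ctx - n_keep)/2;  // 3: half the non-pinned window
        std::vector<int> new_tokens(tokens.begin(), tokens.begin() + n_keep);
        new_tokens.insert(new_tokens.end(), tokens.end() - n_left, tokens.end());
        tokens = new_tokens;  // middle tokens 2..6 are dropped
    }
    for (int t : tokens) {
        printf("%d ", t);  // prints: 0 1 7 8 9
    }
    printf("\n");
    return 0;
}

Keeping only half of the non-pinned window means roughly half of the context is free again after a reset, so the next reset does not follow immediately.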
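
Evaluation is also simplified: instead of staging tokens through embd_inp, nextToken() walks the unevaluated tail of embd in chunks of at most n_batch. A sketch of that loop, with a hypothetical eval_chunk() standing in for llama_eval():

#include <cstdio>
#include <vector>

// stand-in for llama_eval(); returns non-zero on failure like the real call
static int eval_chunk(const int * tokens, int n_eval, int n_past) {
    (void)tokens;
    printf("eval %d tokens at position %d\n", n_eval, n_past);
    return 0;
}

int main() {
    const int n_batch = 4;
    std::vector<int> embd(10, 42);  // 10 pending tokens
    int n_past = 3;                 // 3 already evaluated (cached prefix)

    while (n_past < (int)embd.size()) {
        int n_eval = (int)embd.size() - n_past;
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
        if (eval_chunk(&embd[n_past], n_eval, n_past)) {
            return 1;  // abort on failure, as the server does
        }
        n_past += n_eval;
    }
    return 0;  // prints: eval 4 ... at position 3, then eval 3 ... at position 7
}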