server rewrite

Remove unnecessary things and radically rewrite server
Henri Vasserman 2023-05-28 02:42:18 +03:00
parent 1f40a789e6
commit 51e09944ce


@@ -15,19 +15,16 @@ struct llama_server_context
bool has_next_token = false;
std::string generated_text = "";
int32_t num_tokens_predicted = 0;
int32_t n_past = 0;
int32_t n_consumed = 0;
int32_t n_session_consumed = 0;
int32_t n_remain = 0;
size_t num_tokens_predicted = 0;
size_t n_past = 0;
size_t n_consumed = 0;
size_t n_session_consumed = 0;
size_t n_remain = 0;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
std::vector<llama_token> processed_tokens;
std::vector<llama_token> llama_token_newline;
std::vector<llama_token> embd_inp;
std::vector<std::vector<llama_token>> no_show_words;
std::vector<llama_token> tokens_predicted;
std::vector<llama_token> last_prompt_tokens;
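
A side note on the type change above: moving the counters from int32_t to size_t lets them be compared against std::vector sizes without the (int) casts the old code needed further down. A minimal sketch of the difference (function and variable names are illustrative, not part of the server):

    #include <cstddef>
    #include <vector>

    // With size_t counters, comparisons against container sizes need no cast.
    bool prompt_consumed(const std::vector<int> &embd_inp, size_t n_consumed) {
        return embd_inp.size() <= n_consumed; // old code: (int)embd_inp.size() <= n_consumed
    }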
@@ -37,9 +34,14 @@ struct llama_server_context
void rewind() {
as_loop = false;
params.antiprompt.clear();
no_show_words.clear();
num_tokens_predicted = 0;
generated_text = "";
//processed_tokens.clear();
embd_inp.clear();
n_remain = 0;
n_past = 0;
n_consumed = 0;
}
bool loadModel(gpt_params params_)
@@ -51,8 +53,7 @@ struct llama_server_context
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return false;
}
// determine newline token
llama_token_newline = ::llama_tokenize(ctx, "\n", false);
last_n_tokens.resize(params.n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
@@ -62,53 +63,14 @@ struct llama_server_context
params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
// compare the evaluated prompt with the new prompt
int new_prompt_len = 0;
if (last_prompt_tokens == prompt_tokens)
{
//fprintf(stdout, "Context matched.\n");
processed_tokens = last_prompt_tokens;
embd_inp = last_prompt_tokens;
n_past = processed_tokens.size();
n_consumed = last_prompt_tokens.size() - 2;
has_next_token = true;
return true;
}
else
{
if (!processed_tokens.empty() && !embd_inp.empty())
{
//fprintf(stdout, "Resetting context.\n");
processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
n_consumed = 0;
n_past = 0;
for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
if (prompt_tokens[n_past] != processed_tokens[n_past]) {
break;
}
}
for (size_t i = 0; i < prompt_tokens.size(); i++) {
if (i < processed_tokens.size() &&
processed_tokens[i] == prompt_tokens[i])
{
continue;
}
else
{
embd_inp.push_back(prompt_tokens[i]);
if(new_prompt_len == 0) {
if(int32_t(i) - 1 < n_past) {
processed_tokens.erase(processed_tokens.begin() + i, processed_tokens.end());
}
// Evaluate the new prompt fragment starting from the last processed token.
n_past = processed_tokens.size();
}
new_prompt_len ++;
}
}
if(n_past > 0 && params.interactive) {
n_remain -= new_prompt_len;
}
if ((int)embd_inp.size() > params.n_ctx - 4)
{
return false;
processed_tokens.resize(n_past);
if (prompt_tokens.size() > n_past) {
embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());
}
last_prompt_tokens = prompt_tokens;
has_next_token = true;
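
The replacement logic above drops the old token-by-token bookkeeping: the idea is to keep whatever prefix of processed_tokens still matches the newly tokenized prompt, resize the cache to that prefix, and queue only the remaining suffix for evaluation (the prefix length feeding n_past is computed just above the fragment shown here). A standalone sketch of that idea, with names that are assumptions rather than the server's API:

    #include <cstddef>
    #include <vector>

    // Reuse the longest common prefix of the cached tokens and the new prompt;
    // only the suffix after that prefix has to be evaluated again.
    size_t reuse_prefix(std::vector<int> &processed, std::vector<int> &pending,
                        const std::vector<int> &prompt) {
        size_t n_past = 0;
        while (n_past < processed.size() && n_past < prompt.size() &&
               processed[n_past] == prompt[n_past]) {
            n_past++;
        }
        processed.resize(n_past);                               // drop stale tokens
        pending.assign(prompt.begin() + n_past, prompt.end());  // new fragment to evaluate
        return n_past;                                          // tokens already in the cache
    }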
@@ -131,7 +93,7 @@ struct llama_server_context
llama_token result = -1;
if (embd.size() > 0)
{
if (n_past + (int)embd.size() > params.n_ctx)
if (n_past + embd.size() > (size_t)params.n_ctx)
{
// Reset context
const int n_left = n_past - params.n_keep;
@@ -156,7 +118,7 @@ struct llama_server_context
}
}
embd.clear();
if ((int)embd_inp.size() <= n_consumed && has_next_token)
if (embd_inp.size() <= n_consumed)
{
// out of user input, sample next token
const float temp = params.temp;
@@ -243,18 +205,6 @@ struct llama_server_context
num_tokens_predicted++;
}
// replace end of text token with newline token when in interactive mode
if (id == llama_token_eos() && params.interactive)
{
id = llama_token_newline.front();
if (params.antiprompt.size() != 0)
{
// tokenize and inject first reverse prompt
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
}
}
// add it to the context
embd.push_back(id);
result = id;
@@ -264,7 +214,7 @@ struct llama_server_context
else
{
// some user input remains from prompt or interaction, forward it to processing
while ((int)embd_inp.size() > n_consumed)
while (embd_inp.size() > n_consumed)
{
embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
@@ -277,41 +227,11 @@ struct llama_server_context
}
}
}
if (params.interactive && (int)embd_inp.size() <= n_consumed)
{
// check for reverse prompt
if (params.antiprompt.size())
{
std::string last_output;
for (auto id : last_n_tokens)
{
last_output += llama_token_to_str(ctx, id);
}
has_next_token = true;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string &antiprompt : params.antiprompt)
{
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos)
{
has_next_token = false;
return result;
}
}
}
if (n_past > 0)
{
has_next_token = true;
}
}
if (!embd.empty() && embd.back() == llama_token_eos()) {
has_next_token = false;
}
if (params.interactive && n_remain <= 0 && params.n_predict != -1)
{
n_remain = params.n_predict;
}
has_next_token = n_remain != 0;
return result;
}
@@ -322,58 +242,23 @@ struct llama_server_context
if (token == -1) {
return "";
}
tokens_predicted.clear();
tokens_predicted.push_back(token);
// Avoid adding the no-show words to the response
for (std::vector<llama_token> word_tokens : no_show_words)
{
size_t match_token = 1;
if (tokens_predicted.front() == word_tokens.front())
{
bool execute_matching = true;
if (tokens_predicted.size() > 1) { // if previous tokens have already been tested
for (size_t i = 1; i < word_tokens.size(); i++)
{
if (i >= tokens_predicted.size()) {
match_token = i;
break;
}
if (tokens_predicted[i] == word_tokens[i])
{
continue;
}
else
{
execute_matching = false;
break;
}
}
}
while (execute_matching) {
if (match_token == word_tokens.size()) {
return "";
}
token = nextToken();
tokens_predicted.push_back(token);
if (token == word_tokens[match_token])
{ // the token follows the sequence
match_token++;
}
else if (match_token < word_tokens.size())
{ // the whole word sequence was not completed
break;
}
}
}
}
if(as_loop) {
generated_text = "";
}
for (llama_token tkn : tokens_predicted)
{
generated_text += llama_token_to_str(ctx, tkn);
std::string token_text = llama_token_to_str(ctx, token);
generated_text += token_text;
for (std::string word : params.antiprompt) {
size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
if (i != std::string::npos) {
generated_text.erase(generated_text.begin() + i, generated_text.begin() + i + word.size());
has_next_token = false;
break;
}
}
return generated_text;
}
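
The new stop-word handling above works on strings instead of the removed no_show_words token sequences: the decoded token text is appended to generated_text, and only the tail that could contain a newly completed antiprompt is searched; a hit is erased from the output and ends generation. A self-contained sketch of that check (the helper and its signature are assumptions):

    #include <string>
    #include <vector>

    // Append the new token text, then look for a stop word in the tail of the
    // output that the new text could have completed. Returns true when a stop
    // word was found (and removed), i.e. when generation should end.
    bool append_and_check_stop(std::string &generated_text, const std::string &token_text,
                               const std::vector<std::string> &antiprompt) {
        generated_text += token_text;
        for (const std::string &word : antiprompt) {
            size_t window = word.size() + token_text.size();
            size_t start  = generated_text.size() > window ? generated_text.size() - window : 0;
            size_t i = generated_text.find(word, start);
            if (i != std::string::npos) {
                generated_text.erase(i, word.size());
                return true; // caller would set has_next_token = false here
            }
        }
        return false;
    }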
@@ -616,10 +501,6 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
{
llama.as_loop = body["as_loop"].get<bool>();
}
if (!body["interactive"].is_null())
{
llama.params.interactive = body["interactive"].get<bool>();
}
if (!body["prompt"].is_null())
{
llama.params.prompt = body["prompt"].get<std::string>();
@@ -635,20 +516,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
}
if (!body["stop"].is_null())
{
std::vector<std::string> stop_words = body["stop"].get<std::vector<std::string>>();
for (std::string stop_word : stop_words)
{
llama.params.antiprompt.push_back(stop_word);
llama.no_show_words.push_back(::llama_tokenize(llama.ctx, stop_word, false));
}
}
if (!body["exclude"].is_null())
{
std::vector<std::string> no_show_words = body["exclude"].get<std::vector<std::string>>();
for (std::string no_show : no_show_words)
{
llama.no_show_words.push_back(::llama_tokenize(llama.ctx, no_show, false));
}
llama.params.antiprompt = body["stop"].get<std::vector<std::string>>();
}
return true;
}
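
With these options gone, a completion request simply carries its stop strings in "stop"; "interactive" and "exclude" are no longer read. A hypothetical client-side construction of such a request body, assuming the nlohmann json.hpp header the server example builds against, could look like:

    #include "json.hpp"
    using json = nlohmann::json;

    // Hypothetical request body for parse_options_completion after this change:
    // stop strings go in "stop", and "interactive"/"exclude" are ignored.
    json body = {
        {"prompt",  "Building a website can be done in 10 simple steps:"},
        {"as_loop", true},
        {"stop",    json::array({"User:"})}
    };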