Merge pull request #11 from SlyEcho/server_refactor

Server refactor
This commit is contained in:
Randall Fitzgerald 2023-06-01 08:10:55 -04:00 committed by GitHub
commit af711263ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -88,7 +88,6 @@ struct llama_server_context
n_remain = 0; n_remain = 0;
n_past = 0; n_past = 0;
n_consumed = 0; n_consumed = 0;
last_n_tokens.clear();
} }
bool loadModel(const gpt_params &params_) bool loadModel(const gpt_params &params_)
@@ -120,7 +119,12 @@ struct llama_server_context
const int n_left = (params.n_ctx - params.n_keep)/2; const int n_left = (params.n_ctx - params.n_keep)/2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end()); new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
prompt_tokens = new_tokens; prompt_tokens = new_tokens;
} else {
size_t ps = prompt_tokens.size();
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
} }
// compare the evaluated prompt with the new prompt // compare the evaluated prompt with the new prompt
@@ -251,10 +255,7 @@ struct llama_server_context
id = llama_sample_token(ctx, &candidates_p); id = llama_sample_token(ctx, &candidates_p);
} }
} }
if (!last_n_tokens.empty())
{
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
}
last_n_tokens.push_back(id); last_n_tokens.push_back(id);
num_tokens_predicted++; num_tokens_predicted++;
} }
@@ -654,6 +655,16 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
} else { } else {
llama.params.logit_bias.erase(llama_token_eos()); llama.params.logit_bias.erase(llama_token_eos());
} }
if (body["logit_bias"].is_array()) {
int n_vocab = llama_n_vocab(llama.ctx);
for (const auto &el : body["logit_bias"]) {
if (el.is_array() && el.size() == 2 && el[0].is_number_integer() && el[1].is_number_float()) {
llama_token tok = el[0].get<llama_token>();
if (tok < 0 || tok >= n_vocab) continue;
llama.params.logit_bias[tok] = el[1].get<float>();
}
}
}
if (!body["prompt"].is_null()) { if (!body["prompt"].is_null()) {
llama.params.prompt = body["prompt"].get<std::string>(); llama.params.prompt = body["prompt"].get<std::string>();
} else { } else {