commit 0b2ec84246
Author: coezbek (committed by GitHub)
Date:   2023-10-18 06:03:08 +00:00


@@ -436,31 +436,34 @@ struct llama_server_context
         }
 
         params.n_keep = std::min(n_ctx - 4, params.n_keep);
 
-        // if input prompt is too big, truncate like normal
+        // if input prompt is too big, we will truncate it in the same way embd is truncated when it becomes too big while generating tokens
         if (num_prompt_tokens >= (size_t)n_ctx)
         {
-            const int n_left = (n_ctx - params.n_keep) / 2;
+            const int n_left = n_ctx - params.n_keep;
+            // Keep n_keep tokens of start of prompt (at most n_ctx - 4)
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+            // Use half the left-over space in the context for the prompt
+            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
+
             LOG_VERBOSE("input truncated", {
                 {"n_ctx", n_ctx},
                 {"n_keep", params.n_keep},
                 {"n_left", n_left},
                 {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                {"num_prompt_tokens", new_tokens.size()}
             });
 
             truncated = true;
             prompt_tokens = new_tokens;
+            num_prompt_tokens = prompt_tokens.size();
         }
-        else
-        {
-            const size_t ps = num_prompt_tokens;
-            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
-        }
+
+        // Initialize last_n_tokens
+        const size_t ps = num_prompt_tokens;
+        std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+        std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
 
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
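
To see the new truncation rule in isolation, here is a minimal, self-contained sketch. The helper name `truncate_prompt` and the use of plain `int` tokens instead of `llama_token` are illustrative assumptions, not code from server.cpp; the logic mirrors the added lines above and assumes `n_keep` has already been clamped to `n_ctx - 4`:

```cpp
// Standalone sketch of the new prompt-truncation rule (hypothetical names,
// not part of server.cpp).
#include <cassert>
#include <cstdio>
#include <vector>

// Keep the first n_keep tokens, then fill half of the remaining context
// space (n_left / 2) with the most recent tokens, dropping the middle.
static std::vector<int> truncate_prompt(const std::vector<int> & tokens, int n_ctx, int n_keep) {
    if ((int) tokens.size() < n_ctx) {
        return tokens; // prompt fits, nothing to truncate
    }
    const int n_left = n_ctx - n_keep;
    std::vector<int> out(tokens.begin(), tokens.begin() + n_keep);
    out.insert(out.end(), tokens.end() - n_left / 2, tokens.end());
    return out;
}

int main() {
    std::vector<int> prompt(100);
    for (int i = 0; i < 100; i++) prompt[i] = i;

    // n_ctx = 16, n_keep = 4: keep tokens 0..3 plus the last (16 - 4) / 2 = 6 tokens
    std::vector<int> out = truncate_prompt(prompt, 16, 4);
    assert(out.size() == 10);
    for (int t : out) printf("%d ", t); // 0 1 2 3 94 95 96 97 98 99
    printf("\n");
    return 0;
}
```

Reserving only half of `n_left` for the prompt tail leaves the other half of the context free for generated tokens, matching the policy the generation loop applies when `embd` overflows, as the updated comment notes. The old code's block-erasure arithmetic (`erased_blocks`) and the unconditional `std::copy(prompt_tokens.end() - n_ctx, ...)` are gone, and `num_prompt_tokens` is now kept in sync with the truncated prompt so that `last_n_tokens` is initialized from the correct size.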