improve long input truncation

and add more verbose logging
Henri Vasserman 2023-06-02 15:18:51 +03:00
parent 1bd52c8627
commit 3df0192804

@@ -47,6 +47,27 @@ size_t find_partial_stop_string(const std::string &stop, const std::string &text
     return std::string::npos;
 }
 
+static std::string debug_str(const std::string & s) {
+    std::string ret;
+    for (size_t i = 0; s[i]; i++) {
+        switch (s[i]) {
+            case '\n': ret += "\\n"; break;
+            case '"':  ret += "\\\""; break;
+            default:   ret += s[i]; break;
+        }
+    }
+    return ret;
+}
+
+template<class InputIt, class OutputIt>
+static std::string tokens_to_str(llama_context * ctx, InputIt begin, OutputIt end) {
+    std::string ret;
+    for (; begin != end; (void)++begin) {
+        ret += llama_token_to_str(ctx, *begin);
+    }
+    return ret;
+}
+
 struct llama_server_context
 {
     bool stream = false;
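
The debug_str helper added above escapes newlines and double quotes so prompts and tokens can be embedded in the JSON-like log blocks without breaking their layout. A minimal standalone check of that behaviour (a hypothetical test harness, not part of the commit):

#include <cstdio>
#include <string>

// copied from the hunk above: escape '\n' and '"' for one-line logging
static std::string debug_str(const std::string & s) {
    std::string ret;
    for (size_t i = 0; s[i]; i++) {
        switch (s[i]) {
            case '\n': ret += "\\n"; break;
            case '"':  ret += "\\\""; break;
            default:   ret += s[i]; break;
        }
    }
    return ret;
}

int main() {
    // prints: prompt: "User: \"hi\"\nAssistant:"
    printf("prompt: \"%s\"\n", debug_str("User: \"hi\"\nAssistant:").c_str());
}
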
@@ -117,8 +138,22 @@ struct llama_server_context
         if (prompt_tokens.size() >= (size_t)params.n_ctx) {
             const int n_left = (params.n_ctx - params.n_keep)/2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
+            const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
+            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
             std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+            if (verbose) {
+                fprintf(stderr,
+                    "input truncated: {\n"
+                    " n_ctx: %d,\n"
+                    " n_keep: %d,\n"
+                    " n_left: %d,\n"
+                    " new_tokens: \"%s\",\n"
+                    "}\n",
+                    params.n_ctx, params.n_keep, n_left,
+                    debug_str(tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())).c_str());
+            }
+
             prompt_tokens = new_tokens;
         } else {
             size_t ps = prompt_tokens.size();
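
The replaced insert() above changes how an over-long prompt is shortened: instead of keeping the n_keep prefix plus exactly the last n_left tokens, whole n_left-sized blocks are erased from the middle, so between n_left+1 and 2*n_left of the most recent tokens survive and the truncated prompt still fits in n_ctx. A standalone check of that arithmetic (a hypothetical harness, not part of the commit):

#include <cassert>
#include <cstdio>

int main() {
    const int n_ctx  = 2048;
    const int n_keep = 64;
    const int n_left = (n_ctx - n_keep) / 2;

    // for every oversized prompt length, the kept tokens are the n_keep head
    // plus a tail of n_left+1 .. 2*n_left tokens, which always fits in n_ctx
    for (int size = n_ctx; size < 8 * n_ctx; size++) {
        const int erased_blocks = (size - n_keep - n_left - 1) / n_left;
        const int kept = n_keep + (size - n_keep - erased_blocks * n_left);
        assert(kept <= n_ctx);
        assert(kept - n_keep > n_left);
    }
    printf("block-wise truncation always fits back into n_ctx\n");
    return 0;
}
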
@@ -133,6 +168,19 @@ struct llama_server_context
             // we have to evaluate at least 1 token to generate logits.
             n_past--;
         }
+
+        if (verbose) {
+            fprintf(stderr,
+                "prompt: {\n"
+                " n_past: %zu,\n"
+                " cached: \"%s\",\n"
+                " to_eval: \"%s\",\n"
+                "}\n",
+                n_past,
+                debug_str(tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)).c_str(),
+                debug_str(tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())).c_str());
+        }
+
         has_next_token = true;
     }
@ -154,6 +202,17 @@ struct llama_server_context
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end()); new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
embd = new_tokens; embd = new_tokens;
n_past = params.n_keep; n_past = params.n_keep;
if (verbose) {
fprintf(stderr,
"input truncated: {\n"
" n_ctx: %d,\n"
" n_keep: %d,\n"
" n_left: %d,\n"
" new_tokens: \"%s\",\n"
"}\n",
params.n_ctx, params.n_keep, n_left,
debug_str(tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())).c_str());
}
} }
while (n_past < embd.size()) while (n_past < embd.size())
@@ -339,8 +398,8 @@ struct llama_server_context
                 " num_tokens_predicted: %ld,\n"
                 " stopping_word: \"%s\",\n"
                 "}\n",
-                token, token_text.c_str(), has_next_token, n_remain, num_tokens_predicted,
-                stopping_word.c_str());
+                token, debug_str(llama_token_to_str(ctx, token)).c_str(), has_next_token, n_remain, num_tokens_predicted,
+                debug_str(stopping_word).c_str());
         }
 
         return token_text;
@@ -710,10 +769,10 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
         json tmp = format_generation_settings(llama);
         fprintf(stderr,
             "-------------------------\n"
-            "/completion parameters: %s\n"
-            "PROMPT[%s]\n",
+            "completion parameters: %s\n"
+            "full prompt: \"%s\"\n",
             tmp.dump(4, ' ', false, json::error_handler_t::replace).c_str(),
-            llama.params.prompt.c_str());
+            debug_str(llama.params.prompt).c_str());
     }
 
     return true;