Merge pull request #15 from SlyEcho/server_refactor

Improve long input truncation and add more verbose logging
Randall Fitzgerald 2023-06-02 08:47:54 -04:00 committed by GitHub
commit 28cc0cdc50

@@ -47,6 +47,27 @@ size_t find_partial_stop_string(const std::string &stop, const std::string &text
    return std::string::npos;
}

static std::string debug_str(const std::string & s) {
    std::string ret;
    for (size_t i = 0; s[i]; i++) {
        switch (s[i]) {
            case '\n': ret += "\\n"; break;
            case '"': ret += "\\\""; break;
            default: ret += s[i]; break;
        }
    }
    return ret;
}

template<class InputIt, class OutputIt>
static std::string tokens_to_str(llama_context * ctx, InputIt begin, OutputIt end) {
    std::string ret;
    for (; begin != end; (void)++begin) {
        ret += llama_token_to_str(ctx, *begin);
    }
    return ret;
}

struct llama_server_context
{
    bool stream = false;
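
The debug_str and tokens_to_str helpers above exist only to make the new verbose logs readable. Below is a minimal standalone sketch of the same escaping idea, using plain strings instead of llama tokens so it builds without llama.h; escape_for_log is a hypothetical stand-in for debug_str:

#include <cstdio>
#include <string>
#include <vector>

// Escape newlines and double quotes so a token string can be embedded
// in a quoted log field (same idea as debug_str in the diff above).
static std::string escape_for_log(const std::string & s) {
    std::string ret;
    for (char c : s) {
        if (c == '\n')      ret += "\\n";
        else if (c == '"')  ret += "\\\"";
        else                ret += c;
    }
    return ret;
}

int main() {
    // stand-in for detokenized pieces; the real server would call llama_token_to_str
    std::vector<std::string> pieces = { "Hello", ",\n", "\"world\"" };
    std::string joined;
    for (const auto & p : pieces) joined += p;
    fprintf(stderr, "to_eval: \"%s\"\n", escape_for_log(joined).c_str());
    // prints: to_eval: "Hello,\n\"world\""
    return 0;
}
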
@@ -117,8 +138,22 @@ struct llama_server_context
        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
            const int n_left = (params.n_ctx - params.n_keep)/2;
            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
            const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
            if (verbose) {
                fprintf(stderr,
                        "input truncated: {\n"
                        " n_ctx: %d,\n"
                        " n_keep: %d,\n"
                        " n_left: %d,\n"
                        " new_tokens: \"%s\",\n"
                        "}\n",
                        params.n_ctx, params.n_keep, n_left,
                        debug_str(tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())).c_str());
            }
            prompt_tokens = new_tokens;
        } else {
            size_t ps = prompt_tokens.size();
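
The erased_blocks arithmetic above drops whole n_left-sized blocks from the middle of an oversized prompt while keeping the first n_keep tokens and a block-aligned tail. Here is a small standalone sketch with made-up values; n_ctx, n_keep, and the token ids are illustrative, not the server defaults:

#include <cstdio>
#include <vector>

int main() {
    // illustrative numbers only
    const int n_ctx  = 16;
    const int n_keep = 4;
    std::vector<int> prompt_tokens(30);
    for (int i = 0; i < (int)prompt_tokens.size(); i++) prompt_tokens[i] = i;

    if (prompt_tokens.size() >= (size_t)n_ctx) {
        const int n_left = (n_ctx - n_keep) / 2; // 6
        // number of whole n_left-sized blocks that can be dropped from the middle
        const int erased_blocks = (prompt_tokens.size() - n_keep - n_left - 1) / n_left; // (30-4-6-1)/6 = 3
        std::vector<int> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
        new_tokens.insert(new_tokens.end(),
                          prompt_tokens.begin() + n_keep + erased_blocks * n_left,
                          prompt_tokens.end());
        // keeps tokens 0..3 plus 22..29: 4 + 8 = 12 tokens, which fits in n_ctx
        printf("kept %zu of %zu tokens\n", new_tokens.size(), prompt_tokens.size());
    }
    return 0;
}
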
@@ -133,6 +168,19 @@ struct llama_server_context
                // we have to evaluate at least 1 token to generate logits.
                n_past--;
            }

            if (verbose) {
                fprintf(stderr,
                        "prompt: {\n"
                        " n_past: %zu,\n"
                        " cached: \"%s\",\n"
                        " to_eval: \"%s\",\n"
                        "}\n",
                        n_past,
                        debug_str(tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)).c_str(),
                        debug_str(tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())).c_str());
            }

            has_next_token = true;
        }
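
The cached/to_eval fields in the log above split embd at n_past, the number of leading tokens whose evaluation can be reused from the previous request. The following rough standalone sketch shows that split; common_prefix_len is a hypothetical helper, and the real reuse logic lives elsewhere in server.cpp and also forces at least one token to be re-evaluated, as the comment above notes:

#include <cstdio>
#include <vector>

// Hypothetical helper: length of the shared prefix of two token sequences.
static size_t common_prefix_len(const std::vector<int> & a, const std::vector<int> & b) {
    size_t n = 0;
    while (n < a.size() && n < b.size() && a[n] == b[n]) n++;
    return n;
}

int main() {
    // illustrative token ids only
    std::vector<int> previous = { 1, 10, 11, 12, 13 };     // evaluated for the last request
    std::vector<int> embd     = { 1, 10, 11, 12, 42, 43 }; // new prompt
    size_t n_past = common_prefix_len(previous, embd);     // 4 tokens can be reused

    // same split as the verbose "prompt: {...}" log above
    printf("cached: %zu tokens, to_eval: %zu tokens\n", n_past, embd.size() - n_past);
    return 0;
}
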
@@ -154,6 +202,17 @@ struct llama_server_context
            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
            embd = new_tokens;
            n_past = params.n_keep;
            if (verbose) {
                fprintf(stderr,
                        "input truncated: {\n"
                        " n_ctx: %d,\n"
                        " n_keep: %d,\n"
                        " n_left: %d,\n"
                        " new_tokens: \"%s\",\n"
                        "}\n",
                        params.n_ctx, params.n_keep, n_left,
                        debug_str(tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())).c_str());
            }
        }

        while (n_past < embd.size())
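
This second truncation runs during generation, when the accumulated embd itself overflows the context: it keeps the n_keep prefix, carries over a tail of n_left tokens, and rewinds n_past so only the carried-over part is re-evaluated. A standalone sketch with illustrative numbers follows; the real n_left is computed in lines outside this hunk, so the formula used here is an assumption:

#include <cstdio>
#include <vector>

int main() {
    // illustrative values only
    const int n_ctx  = 8;
    const int n_keep = 2;
    const int n_left = (n_ctx - n_keep) / 2; // assumed; the real value comes from code outside the hunk
    std::vector<int> embd = { 0, 1, 2, 3, 4, 5, 6, 7 }; // context window is full
    size_t n_past = embd.size();

    if (embd.size() >= (size_t)n_ctx) {
        // keep the n_keep prefix, carry over the last n_left tokens,
        // and rewind n_past so evaluation resumes right after the prefix
        std::vector<int> new_tokens(embd.begin(), embd.begin() + n_keep);
        new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
        embd = new_tokens;
        n_past = n_keep;
    }
    printf("embd now has %zu tokens, n_past = %zu\n", embd.size(), n_past);
    return 0;
}
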
@@ -339,8 +398,8 @@ struct llama_server_context
" num_tokens_predicted: %ld,\n"
" stopping_word: \"%s\",\n"
"}\n",
token, token_text.c_str(), has_next_token, n_remain, num_tokens_predicted,
stopping_word.c_str());
token, debug_str(llama_token_to_str(ctx, token)).c_str(), has_next_token, n_remain, num_tokens_predicted,
debug_str(stopping_word).c_str());
}
return token_text;
@@ -710,10 +769,10 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
        json tmp = format_generation_settings(llama);
        fprintf(stderr,
                "-------------------------\n"
                "/completion parameters: %s\n"
                "PROMPT[%s]\n",
                "completion parameters: %s\n"
                "full prompt: \"%s\"\n",
                tmp.dump(4, ' ', false, json::error_handler_t::replace).c_str(),
                llama.params.prompt.c_str());
                debug_str(llama.params.prompt).c_str());
    }

    return true;
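
The tmp.dump(...) call above passes nlohmann::json's error_handler_t::replace, presumably so that a prompt containing invalid UTF-8 is logged with replacement characters instead of throwing. A small sketch of that behaviour; the lone "\xC3" byte is just an example of a broken sequence:

#include <cstdio>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json j;
    j["prompt"] = std::string("hello \xC3"); // a lone 0xC3 byte is not valid UTF-8
    // with error_handler_t::replace, dump() substitutes U+FFFD instead of throwing
    std::string s = j.dump(4, ' ', false, json::error_handler_t::replace);
    printf("%s\n", s.c_str());
    return 0;
}
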