diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index cd463b405..038fd16c8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -47,6 +47,27 @@ size_t find_partial_stop_string(const std::string &stop, const std::string &text
     return std::string::npos;
 }
 
+static std::string debug_str(const std::string & s) {
+    std::string ret;
+    for (size_t i = 0; s[i]; i++) {
+        switch (s[i]) {
+            case '\n': ret += "\\n"; break;
+            case '"':  ret += "\\\""; break;
+            default:   ret += s[i]; break;
+        }
+    }
+    return ret;
+}
+
+template <typename InputIt, typename OutputIt>
+static std::string tokens_to_str(llama_context * ctx, InputIt begin, OutputIt end) {
+    std::string ret;
+    for (; begin != end; (void)++begin) {
+        ret += llama_token_to_str(ctx, *begin);
+    }
+    return ret;
+}
+
 struct llama_server_context
 {
     bool stream = false;
@@ -117,8 +138,22 @@ struct llama_server_context
         if (prompt_tokens.size() >= (size_t)params.n_ctx) {
             const int n_left = (params.n_ctx - params.n_keep)/2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
+            const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
+            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
             std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+            if (verbose) {
+                fprintf(stderr,
+                        "input truncated: {\n"
+                        "    n_ctx: %d,\n"
+                        "    n_keep: %d,\n"
+                        "    n_left: %d,\n"
+                        "    new_tokens: \"%s\",\n"
+                        "}\n",
+                        params.n_ctx, params.n_keep, n_left,
+                        debug_str(tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())).c_str());
+            }
+
             prompt_tokens = new_tokens;
         } else {
             size_t ps = prompt_tokens.size();
@@ -133,6 +168,19 @@ struct llama_server_context
             // we have to evaluate at least 1 token to generate logits.
             n_past--;
         }
+
+        if (verbose) {
+            fprintf(stderr,
+                    "prompt: {\n"
+                    "    n_past: %zu,\n"
+                    "    cached: \"%s\",\n"
+                    "    to_eval: \"%s\",\n"
+                    "}\n",
+                    n_past,
+                    debug_str(tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)).c_str(),
+                    debug_str(tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())).c_str());
+        }
+
         has_next_token = true;
     }
 
@@ -154,6 +202,17 @@ struct llama_server_context
             new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
             embd = new_tokens;
             n_past = params.n_keep;
+            if (verbose) {
+                fprintf(stderr,
+                        "input truncated: {\n"
+                        "    n_ctx: %d,\n"
+                        "    n_keep: %d,\n"
+                        "    n_left: %d,\n"
+                        "    new_tokens: \"%s\",\n"
+                        "}\n",
+                        params.n_ctx, params.n_keep, n_left,
+                        debug_str(tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())).c_str());
+            }
         }
 
         while (n_past < embd.size())
@@ -339,8 +398,8 @@ struct llama_server_context
                    "    num_tokens_predicted: %ld,\n"
                    "    stopping_word: \"%s\",\n"
                    "}\n",
-                   token, token_text.c_str(), has_next_token, n_remain, num_tokens_predicted,
-                   stopping_word.c_str());
+                   token, debug_str(llama_token_to_str(ctx, token)).c_str(), has_next_token, n_remain, num_tokens_predicted,
+                   debug_str(stopping_word).c_str());
         }
 
         return token_text;
@@ -710,10 +769,10 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
         json tmp = format_generation_settings(llama);
         fprintf(stderr,
                 "-------------------------\n"
-                "/completion parameters: %s\n"
-                "PROMPT[%s]\n",
+                "completion parameters: %s\n"
+                "full prompt: \"%s\"\n",
                 tmp.dump(4, ' ', false, json::error_handler_t::replace).c_str(),
-                llama.params.prompt.c_str());
+                debug_str(llama.params.prompt).c_str());
     }
 
     return true;
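For reference, a minimal standalone sketch of the block-aligned truncation that the new `erased_blocks` computation performs. The numbers (`n_ctx = 16`, `n_keep = 4`) and the plain-`int` "tokens" are hypothetical stand-ins; this only illustrates the arithmetic and is not part of the patch:

```cpp
// Sketch of the block-aligned prompt truncation (hypothetical values;
// plain ints stand in for llama_token).
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 16;                    // hypothetical context size
    const int n_keep = 4;                     // prefix pinned at the start
    const int n_left = (n_ctx - n_keep) / 2;  // truncation block size = 6

    std::vector<int> prompt(30);              // a prompt longer than n_ctx
    for (int i = 0; i < (int)prompt.size(); i++) prompt[i] = i;

    // Drop whole n_left-sized blocks after the pinned prefix, exactly as the
    // patched code does, so the cut always lands on a block boundary.
    const int erased_blocks = ((int)prompt.size() - n_keep - n_left - 1) / n_left;
    std::vector<int> truncated(prompt.begin(), prompt.begin() + n_keep);
    truncated.insert(truncated.end(),
                     prompt.begin() + n_keep + erased_blocks * n_left,
                     prompt.end());

    // Between n_left + 1 and 2 * n_left suffix tokens survive, so the result
    // never exceeds n_keep + 2 * n_left = n_ctx.
    assert((int)truncated.size() <= n_ctx);
    printf("kept %zu of %zu tokens (erased %d block(s) of %d)\n",
           truncated.size(), prompt.size(), erased_blocks, n_left);
    // -> kept 12 of 30 tokens (erased 3 block(s) of 6)
}
```

The old code kept only the last `n_left` tokens, i.e. at most `n_keep + n_left` of the context; the block-aligned variant keeps everything that still fits, and because the cut point only moves in `n_left`-sized steps, presumably slightly-longer follow-up requests truncate to the same prefix more often.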
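And a tiny usage sketch for the new `debug_str()` escaping helper (hypothetical input, same caveat): it keeps embedded newlines and quotes in prompts from breaking the single-line `key: "value"` log format used above.

```cpp
// Usage sketch for the debug_str() helper added by the patch
// (hypothetical input; not part of the patch).
#include <cstdio>
#include <string>

static std::string debug_str(const std::string & s) {
    std::string ret;
    for (size_t i = 0; s[i]; i++) {           // s[i] is '\0' at i == s.size()
        switch (s[i]) {
            case '\n': ret += "\\n"; break;   // escape newlines
            case '"':  ret += "\\\""; break;  // escape double quotes
            default:   ret += s[i]; break;
        }
    }
    return ret;
}

int main() {
    const std::string prompt = "Say \"hi\".\nThanks!";
    fprintf(stderr, "full prompt: \"%s\"\n", debug_str(prompt).c_str());
    // -> full prompt: "Say \"hi\".\nThanks!"   (one line, quotes escaped)
}
```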