run : fix BOS being added to each message

Porting the fix from simple-chat.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
This commit is contained in:
Eric Curtin 2025-01-19 17:52:20 +00:00
parent b9daaffe02
commit 0cc8e0224e

View file

@ -729,11 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
// Function to tokenize the prompt // Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) { std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data,
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); const bool is_first) {
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
prompt_tokens.resize(n_prompt_tokens); prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(),
true) < 0) { llama_get_kv_cache_used_cells(llama_data.context.get()) == 0, true) < 0) {
printe("failed to tokenize the prompt\n"); printe("failed to tokenize the prompt\n");
return -1; return -1;
} }
@ -774,11 +775,11 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st
} }
// helper function to evaluate a prompt and generate a response // helper function to evaluate a prompt and generate a response
static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response, const bool is_first) {
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
if (tokenize_prompt(vocab, prompt, tokens) < 0) { if (tokenize_prompt(vocab, prompt, tokens, llama_data, is_first) < 0) {
return 1; return 1;
} }
@ -852,13 +853,13 @@ static int read_user_input(std::string & user_input) {
// Function to generate a response based on the prompt // Function to generate a response based on the prompt
static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response, static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response,
const bool stdout_a_terminal) { const bool stdout_a_terminal, const int prev_len) {
// Set response color // Set response color
if (stdout_a_terminal) { if (stdout_a_terminal) {
printf("\033[33m"); printf("\033[33m");
} }
if (generate(llama_data, prompt, response)) { if (generate(llama_data, prompt, response, prev_len == 0)) {
printe("failed to generate response\n"); printe("failed to generate response\n");
return 1; return 1;
} }
@ -948,7 +949,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
std::string response; std::string response;
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) { if (generate_response(llama_data, prompt, response, stdout_a_terminal, prev_len)) {
return 1; return 1;
} }