examples : fix add_special conditions (#11311)

This commit is contained in:
Georgi Gerganov 2025-01-20 16:36:08 +02:00 committed by GitHub
parent 90d987b105
commit 9f7add1cde
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 7 deletions

View file

@ -729,10 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
// Function to tokenize the prompt // Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) { std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
prompt_tokens.resize(n_prompt_tokens); prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
true) < 0) { true) < 0) {
printe("failed to tokenize the prompt\n"); printe("failed to tokenize the prompt\n");
return -1; return -1;
@ -778,7 +780,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
if (tokenize_prompt(vocab, prompt, tokens) < 0) { if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
return 1; return 1;
} }

View file

@ -95,13 +95,15 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
// helper function to evaluate a prompt and generate a response // helper function to evaluate a prompt and generate a response
auto generate = [&](const std::string & prompt, bool is_first) { auto generate = [&](const std::string & prompt) {
std::string response; std::string response;
const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
// tokenize the prompt // tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens); std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n"); GGML_ABORT("failed to tokenize the prompt\n");
} }
@ -180,7 +182,7 @@ int main(int argc, char ** argv) {
// generate a response // generate a response
printf("\033[33m"); printf("\033[33m");
std::string response = generate(prompt, prev_len == 0); std::string response = generate(prompt);
printf("\n\033[0m"); printf("\n\033[0m");
// add the response to the messages // add the response to the messages