From 046a469d11a3e86391c4fd2c0423271d68f22d78 Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Sat, 18 Nov 2023 09:37:31 -0700
Subject: [PATCH] Fix(ish?) prompt tokenizing

Automatically clear completed sequences out of the KV cache
---
 examples/simple-inference/simple-inference.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/simple-inference/simple-inference.cpp b/examples/simple-inference/simple-inference.cpp
index 961792214..022b69baa 100644
--- a/examples/simple-inference/simple-inference.cpp
+++ b/examples/simple-inference/simple-inference.cpp
@@ -305,12 +305,12 @@ bool gen_ctx::init_model() {
 }
 
 bool gen_ctx::init_prompt() {
-    const bool add_bos = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_should_add_bos_token(model);
     LOG("add_bos: %d\n", add_bos);
 
     if (!params.prompt.empty()) {
         LOG("tokenize the prompt\n");
-        prompt_tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+        prompt_tokens = ::llama_tokenize(ctx, params.prompt, add_bos, true);
     }
 
     LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
@@ -578,6 +578,7 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
         sctx.chunks.back().tokens.push_back(sctx.last_sampled);
         if (sctx.last_sampled == llama_token_eos(model) || sctx.n_remain == 0) {
             sctx.state = SEQ_DONE;
+            llama_kv_cache_seq_rm(ctx, sctx.seq_id, -1, -1);
             sctx.batch_idx = -1;
             // printf(" [end of text]\n");
             // break;
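
A minimal sketch of how the tokenization and KV-cache calls above fit together, assuming an already-initialized llama.cpp model/context pair and the ::llama_tokenize helper from common.h; the wrapper function names below are illustrative placeholders, not part of the example:

    #include <string>
    #include <vector>
    #include "llama.h"
    #include "common.h"

    // Tokenize a prompt, letting the model's metadata decide whether to
    // prepend BOS, and allow special tokens in the prompt text to be parsed.
    static std::vector<llama_token> tokenize_prompt(llama_context * ctx,
                                                    const llama_model * model,
                                                    const std::string & prompt) {
        const bool add_bos = llama_should_add_bos_token(model);
        return ::llama_tokenize(ctx, prompt, add_bos, /*special=*/true);
    }

    // When a sequence finishes (EOS sampled or its token budget is exhausted),
    // drop all of its cells from the KV cache so the space can be reused by
    // the remaining sequences in the batch.
    static void finish_sequence(llama_context * ctx, llama_seq_id seq_id) {
        // p0 == -1 and p1 == -1 select the sequence's entire position range.
        llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
    }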