diff --git a/examples/main/chat.hpp b/examples/main/chat.hpp index 801e7acdc..e17cbfa02 100644 --- a/examples/main/chat.hpp +++ b/examples/main/chat.hpp @@ -10,6 +10,7 @@ struct llama_cli_chat { struct llama_context * ctx; const struct llama_model * model; + const struct llama_vocab * vocab; struct common_sampler * smpl; struct common_params params; @@ -26,6 +27,7 @@ struct llama_cli_chat { struct llama_context * ctx, struct common_sampler * smpl) : ctx(ctx), smpl(smpl), params(params) { model = llama_get_model(ctx); + vocab = llama_model_get_vocab(model); batch = llama_batch_init(params.n_batch, 0, 1); } @@ -130,7 +132,7 @@ struct llama_cli_chat { new_token_id = common_sampler_sample(smpl, ctx, -1); // is it an end of generation? - if (llama_token_is_eog(model, new_token_id)) { + if (llama_vocab_is_eog(vocab, new_token_id)) { break; }