add eos_id_list to llama.cpp

This commit is contained in:
toyer 2024-06-24 12:27:02 +00:00
parent 4b65b648ce
commit 3a4d5790bf
13 changed files with 122 additions and 55 deletions

View file

@ -88,12 +88,21 @@ int main(int argc, char ** argv) {
fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return 1;
}
const int n_eos_tgt = llama_n_eos(model_tgt);
std::vector<int32_t> eos_tokens_tgt(n_eos_tgt, 0);
int32_t* eos_ptr_tgt = eos_tokens_tgt.data();
llama_token_eos(model_tgt, eos_ptr_tgt);
const int n_eos_dft = llama_n_eos(model_dft);
std::vector<int32_t> eos_tokens_dft(n_eos_dft, 0);
int32_t* eos_ptr_dft = eos_tokens_dft.data();
llama_token_eos(model_dft, eos_ptr_dft);
if (
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
eos_tokens_tgt != eos_tokens_dft
) {
fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
return 1;