add eos_id_list to llama.cpp

This commit is contained in:
toyer 2024-06-24 12:27:02 +00:00
parent 4b65b648ce
commit 3a4d5790bf
13 changed files with 122 additions and 55 deletions

View file

@ -240,7 +240,11 @@ int64_t get_example_targets_batch(
ggml_set_f32(target_probs, 0.0f);
llama_token bos = llama_token_bos(llama_get_model(lctx));
llama_token eos = llama_token_eos(llama_get_model(lctx));
const int n_eos = llama_n_eos(llama_get_model(lctx));
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(llama_get_model(lctx), eos_ptr);
llama_token eos = eos_ptr[0];
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k);