BERT tokenizer fixes (#6498)

Key changes:
* BERT conversion: fix abuse of LlamaHfVocab, do not set BOS or EOS
* Nomic Embed conversion: pad vocab instead of slicing embedding tensor
* llama_tokenize: handle added special tokens like HF does
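The third bullet refers to the Hugging Face convention for added special tokens: they are matched verbatim in the input text and emitted as single ids, and only the text between them goes through the regular tokenizer. Below is a minimal, self-contained sketch of that rule; the function names and the dummy per-word fragment tokenizer are illustrative only, not the llama.cpp internals.

```cpp
// Illustrative sketch, not the actual llama.cpp implementation.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

using token_id = int;

// Stand-in for the model's regular tokenizer: one dummy id per whitespace-separated word.
static std::vector<token_id> tokenize_fragment(const std::string & text) {
    std::vector<token_id> out;
    for (size_t pos = 0; pos < text.size(); ) {
        size_t end = text.find(' ', pos);
        if (end == std::string::npos) end = text.size();
        if (end > pos) out.push_back(1000 + (token_id) (end - pos)); // dummy id
        pos = end + 1;
    }
    return out;
}

// HF-style rule: scan for added special tokens as literal substrings, emit each as
// exactly one id, and only run the plain text between them through the tokenizer.
static std::vector<token_id> tokenize_with_added_specials(
        const std::string & text, const std::map<std::string, token_id> & specials) {
    std::vector<token_id> out;
    size_t pos = 0;
    while (pos < text.size()) {
        size_t   best_pos = std::string::npos;
        size_t   best_len = 0;
        token_id best_id  = -1;
        for (const auto & [tok, id] : specials) {
            const size_t found = text.find(tok, pos);
            if (found != std::string::npos &&
                (found < best_pos || (found == best_pos && tok.size() > best_len))) {
                best_pos = found; best_len = tok.size(); best_id = id;
            }
        }
        if (best_pos == std::string::npos) break; // no more specials: the tail is handled below
        const auto head = tokenize_fragment(text.substr(pos, best_pos - pos));
        out.insert(out.end(), head.begin(), head.end());
        out.push_back(best_id); // the special token itself is never split
        pos = best_pos + best_len;
    }
    const auto tail = tokenize_fragment(text.substr(pos));
    out.insert(out.end(), tail.begin(), tail.end());
    return out;
}

int main() {
    const std::map<std::string, token_id> specials = { {"[CLS]", 101}, {"[SEP]", 102} };
    for (token_id id : tokenize_with_added_specials("[CLS] hello world [SEP]", specials)) {
        printf("%d ", id); // expected: 101, two word ids, 102
    }
    printf("\n");
}
```

Under this rule a BERT-style prompt such as "[CLS] hello [SEP]" keeps [CLS] and [SEP] as single ids instead of having them shredded into punctuation pieces, which is what the llama_tokenize change in this commit is about.
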
Jared Van Bortel 2024-04-09 13:44:08 -04:00 committed by GitHub
parent c4a3a4ff47
commit 1b67731e18
20 changed files with 221 additions and 194 deletions

examples/speculative/speculative.cpp

@@ -76,6 +76,28 @@ int main(int argc, char ** argv) {
     params.n_threads_batch = params.n_threads_batch_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
+    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG("vocab_type tgt: %d\n", vocab_type_tgt);
+
+    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    LOG("vocab_type dft: %d\n", vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
+        fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+        return 1;
+    }
+
+    if (
+        llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
+        return 1;
+    }
+
     {
         const int n_vocab_tgt = llama_n_vocab(model_tgt);
         const int n_vocab_dft = llama_n_vocab(model_dft);
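
The checks added above also read naturally as a single predicate. Here is a sketch under the assumption that only the llama.h calls already visible in this hunk are available; speculation_vocab_compatible is an invented name, and comparing the llama_vocab_type() results directly (it returns an enum, which the hunk stores in a bool before comparing) avoids any narrowing.

```cpp
// Sketch only: the compatibility test from the hunk above, folded into one helper.
// speculation_vocab_compatible is not part of llama.h -- it is just for illustration.
#include "llama.h"

static bool speculation_vocab_compatible(const llama_model * tgt, const llama_model * dft) {
    return llama_vocab_type   (tgt) == llama_vocab_type   (dft) &&  // same tokenizer family
           llama_add_bos_token(tgt) == llama_add_bos_token(dft) &&  // same BOS insertion behaviour
           llama_add_eos_token(tgt) == llama_add_eos_token(dft) &&  // same EOS insertion behaviour
           llama_token_bos    (tgt) == llama_token_bos    (dft) &&  // same BOS id
           llama_token_eos    (tgt) == llama_token_eos    (dft);    // same EOS id
}
```

With such a helper the two early returns in the example would collapse to a single if (!speculation_vocab_compatible(model_tgt, model_dft)) check.
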
@@ -105,20 +127,8 @@ int main(int argc, char ** argv) {
     // Tokenize the prompt
-    const bool add_bos_tgt = llama_should_add_bos_token(model_tgt);
-    LOG("add_bos tgt: %d\n", add_bos_tgt);
-
-    const bool add_bos_dft = llama_should_add_bos_token(model_dft);
-    LOG("add_bos dft: %d\n", add_bos_dft);
-
-    if (add_bos_tgt != add_bos_dft) {
-        fprintf(stderr, "%s: error: draft model add_bos must match target model to use speculation but ", __func__);
-        fprintf(stderr, "add_bos_dft = %d while add_bos_tgt = %d\n", add_bos_dft, add_bos_tgt);
-        return 1;
-    }
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx_tgt, params.prompt, add_bos_tgt, true);
+    inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
 
     const int max_context_size = llama_n_ctx(ctx_tgt);
     const int max_tokens_list_size = max_context_size - 4;
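
The second hunk shows the user-facing effect of the tokenizer change: instead of asking each model whether it wants a BOS token, the example now passes true for the add-special argument and lets the model's own metadata decide, with special-token parsing enabled as before. A minimal sketch of that call pattern follows, assuming this tree's common.h tokenize helper; tokenize_prompt is an invented wrapper, and the argument descriptions in the comments are my reading of the change, not quoted from the header.

```cpp
// Sketch: assumes building against this tree's common.h / llama.h.
#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

static std::vector<llama_token> tokenize_prompt(llama_context * ctx, const std::string & prompt) {
    // 3rd arg (add special):   BOS/EOS are added according to the model's metadata
    //                          rather than a caller-side flag
    // 4th arg (parse special): special tokens written in the prompt text ("[CLS]", "</s>", ...)
    //                          are emitted as single ids instead of being split
    return ::llama_tokenize(ctx, prompt, true, true);
}
```

This is roughly the equivalent of calling a Hugging Face tokenizer with add_special_tokens=True and relying on the tokenizer's own configuration for which special tokens to insert.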