diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 416fc4c0f..f921b7845 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -8,6 +8,9 @@ #include #include +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 +#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 + struct seq_draft { bool active = false; bool drafting = false; @@ -65,20 +68,27 @@ int main(int argc, char ** argv) { std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); { - int n_vocab_tgt = llama_n_vocab(model_tgt); - if (n_vocab_tgt != llama_n_vocab(model_dft)) { - fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__); - fprintf(stderr, "target vocab size %d does not match draft vocab size %d\n", - n_vocab_tgt, llama_n_vocab(model_dft)); + const int n_vocab_tgt = llama_n_vocab(model_tgt); + const int n_vocab_dft = llama_n_vocab(model_dft); + const int vocab_diff = n_vocab_tgt > n_vocab_dft + ? n_vocab_tgt - n_vocab_dft + : n_vocab_dft - n_vocab_tgt; + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__); + fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return 1; } - for (int i = 0; i < n_vocab_tgt; i++) { + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { const char * token_text_tgt = llama_token_get_text(model_tgt, i); const char * token_text_dft = llama_token_get_text(model_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__); - fprintf(stderr, "token %d content differs\n", i); + fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i, + llama_token_to_piece(ctx_tgt, i).c_str(), + llama_token_to_piece(ctx_dft, i).c_str()); return 1; } }