Tolerate small differences when checking dft vs tgt vocab

This commit is contained in:
KerfuffleV2 2023-10-27 08:15:00 -06:00
parent a0897651a2
commit 41f5d2acdf

View file

@ -8,6 +8,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft { struct seq_draft {
bool active = false; bool active = false;
bool drafting = false; bool drafting = false;
@ -65,20 +68,27 @@ int main(int argc, char ** argv) {
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
{ {
int n_vocab_tgt = llama_n_vocab(model_tgt); const int n_vocab_tgt = llama_n_vocab(model_tgt);
if (n_vocab_tgt != llama_n_vocab(model_dft)) { const int n_vocab_dft = llama_n_vocab(model_dft);
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__); const int vocab_diff = n_vocab_tgt > n_vocab_dft
fprintf(stderr, "target vocab size %d does not match draft vocab size %d\n", ? n_vocab_tgt - n_vocab_dft
n_vocab_tgt, llama_n_vocab(model_dft)); : n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1; return 1;
} }
for (int i = 0; i < n_vocab_tgt; i++) { for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i); const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i); const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) { if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__); fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
fprintf(stderr, "token %d content differs\n", i); fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
llama_token_to_piece(ctx_tgt, i).c_str(),
llama_token_to_piece(ctx_dft, i).c_str());
return 1; return 1;
} }
} }