Tolerate small differences when checking dft vs tgt vocab

This commit is contained in:
KerfuffleV2 2023-10-27 08:15:00 -06:00
parent a0897651a2
commit 41f5d2acdf

View file

@ -8,6 +8,9 @@
#include <string>
#include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
bool active = false;
bool drafting = false;
@ -65,20 +68,27 @@ int main(int argc, char ** argv) {
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
{
int n_vocab_tgt = llama_n_vocab(model_tgt);
if (n_vocab_tgt != llama_n_vocab(model_dft)) {
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
fprintf(stderr, "target vocab size %d does not match draft vocab size %d\n",
n_vocab_tgt, llama_n_vocab(model_dft));
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1;
}
for (int i = 0; i < n_vocab_tgt; i++) {
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
fprintf(stderr, "token %d content differs\n", i);
fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
llama_token_to_piece(ctx_tgt, i).c_str(),
llama_token_to_piece(ctx_dft, i).c_str());
return 1;
}
}