Tolerate small differences when checking dft vs tgt vocab
This commit is contained in:
parent
a0897651a2
commit
41f5d2acdf
1 changed files with 17 additions and 7 deletions
|
@ -8,6 +8,9 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
struct seq_draft {
|
||||
bool active = false;
|
||||
bool drafting = false;
|
||||
|
@ -65,20 +68,27 @@ int main(int argc, char ** argv) {
|
|||
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
||||
|
||||
{
|
||||
int n_vocab_tgt = llama_n_vocab(model_tgt);
|
||||
if (n_vocab_tgt != llama_n_vocab(model_dft)) {
|
||||
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
|
||||
fprintf(stderr, "target vocab size %d does not match draft vocab size %d\n",
|
||||
n_vocab_tgt, llama_n_vocab(model_dft));
|
||||
const int n_vocab_tgt = llama_n_vocab(model_tgt);
|
||||
const int n_vocab_dft = llama_n_vocab(model_dft);
|
||||
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
||||
? n_vocab_tgt - n_vocab_dft
|
||||
: n_vocab_dft - n_vocab_tgt;
|
||||
|
||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||
fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
|
||||
fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_vocab_tgt; i++) {
|
||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
|
||||
const char * token_text_dft = llama_token_get_text(model_dft, i);
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
|
||||
fprintf(stderr, "token %d content differs\n", i);
|
||||
fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
|
||||
llama_token_to_piece(ctx_tgt, i).c_str(),
|
||||
llama_token_to_piece(ctx_dft, i).c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue