speculative: Ensure draft and target model vocab matches
This commit is contained in:
parent
34b2a5e1ee
commit
a0897651a2
1 changed files with 22 additions and 1 deletions
|
@ -64,6 +64,26 @@ int main(int argc, char ** argv) {
|
||||||
params.n_gpu_layers = params.n_gpu_layers_draft;
|
params.n_gpu_layers = params.n_gpu_layers_draft;
|
||||||
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
|
{
|
||||||
|
int n_vocab_tgt = llama_n_vocab(model_tgt);
|
||||||
|
if (n_vocab_tgt != llama_n_vocab(model_dft)) {
|
||||||
|
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
|
||||||
|
fprintf(stderr, "target vocab size %d does not match draft vocab size %d\n",
|
||||||
|
n_vocab_tgt, llama_n_vocab(model_dft));
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_vocab_tgt; i++) {
|
||||||
|
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
|
||||||
|
const char * token_text_dft = llama_token_get_text(model_dft, i);
|
||||||
|
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||||
|
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
|
||||||
|
fprintf(stderr, "token %d content differs\n", i);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<llama_token> inp;
|
std::vector<llama_token> inp;
|
||||||
inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
|
inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
|
||||||
|
@ -227,6 +247,7 @@ int main(int argc, char ** argv) {
|
||||||
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
|
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||||
|
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||||
llama_decode (ctx_dft, batch_dft);
|
llama_decode (ctx_dft, batch_dft);
|
||||||
|
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
@ -370,7 +391,7 @@ int main(int argc, char ** argv) {
|
||||||
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
|
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
||||||
llama_decode(ctx_tgt, batch_tgt);
|
llama_decode(ctx_tgt, batch_tgt);
|
||||||
++n_past_tgt;
|
++n_past_tgt;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue