speculative: Ensure draft and target model vocab matches
This commit is contained in:
parent
34b2a5e1ee
commit
a0897651a2
1 changed files with 22 additions and 1 deletions
|
@ -64,6 +64,26 @@ int main(int argc, char ** argv) {
|
|||
params.n_gpu_layers = params.n_gpu_layers_draft;
|
||||
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
||||
|
||||
{
|
||||
int n_vocab_tgt = llama_n_vocab(model_tgt);
|
||||
if (n_vocab_tgt != llama_n_vocab(model_dft)) {
|
||||
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
|
||||
fprintf(stderr, "target vocab size %d does not match draft vocab size %d\n",
|
||||
n_vocab_tgt, llama_n_vocab(model_dft));
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_vocab_tgt; i++) {
|
||||
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
|
||||
const char * token_text_dft = llama_token_get_text(model_dft, i);
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
|
||||
fprintf(stderr, "token %d content differs\n", i);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<llama_token> inp;
|
||||
inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
|
||||
|
@ -227,6 +247,7 @@ int main(int argc, char ** argv) {
|
|||
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||
llama_decode (ctx_dft, batch_dft);
|
||||
|
||||
++n_past_dft;
|
||||
|
@ -370,7 +391,7 @@ int main(int argc, char ** argv) {
|
|||
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
|
||||
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
||||
llama_decode(ctx_tgt, batch_tgt);
|
||||
++n_past_tgt;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue