sample from residual distribution on draft accept failure

This commit is contained in:
Minsoo Cheong 2024-02-22 13:50:30 +09:00
parent c1bad4a549
commit a9335a5c2a

View file

@ -298,12 +298,12 @@ int main(int argc, char ** argv) {
if (!accept) {
// all drafted tokens were rejected
// sample from the target model
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
token_id = llama_sample_token(ctx_tgt, &dist_tgt);
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id);
}
} else {
// greedy verification