diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 74d883410..86824b400 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -167,6 +167,9 @@ int main(int argc, char ** argv) { std::vector drafts(n_seq_dft); params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar + if (params.sparams.temp == 0) { + params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model + } for (int s = 0; s < n_seq_dft; ++s) { drafts[s].ctx_sampling = llama_sampling_init(params.sparams);