diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 73bec0b98..2b0f907a8 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1127,11 +1127,16 @@ static void llama_sampler_k_shift_free(struct llama_sampler * smpl) { delete (llama_sampler_k_shift *) smpl->ctx; } +static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_k_shift *) smpl->ctx; + ctx->k_set = false; +} + static struct llama_sampler_i llama_sampler_k_shift_i = { /* .name = */ llama_sampler_k_shift_name, /* .accept = */ nullptr, /* .apply = */ llama_sampler_k_shift_apply, - /* .reset = */ nullptr, + /* .reset = */ llama_sampler_k_shift_reset, /* .clone = */ llama_sampler_k_shift_clone, /* .free = */ llama_sampler_k_shift_free, };