diff --git a/common/sampling.cpp b/common/sampling.cpp index e8675a8c0..be6409316 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -323,3 +323,13 @@ void llama_sampling_accept( llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); } } + +void llama_sampling_rollback( + struct llama_sampling_context * ctx_sampling, + int rollback_num) { + if(rollback_num > ctx_sampling->prev.size()) { + rollback_num = ctx_sampling->prev.size(); + } + + ctx_sampling->prev.erase(ctx_sampling->prev.end() - rollback_num, ctx_sampling->prev.end()); +} \ No newline at end of file diff --git a/common/sampling.h b/common/sampling.h index 88899c094..f1df25890 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -117,3 +117,7 @@ void llama_sampling_accept( struct llama_context * ctx_main, llama_token id, bool apply_grammar); + +void llama_sampling_rollback( + struct llama_sampling_context * ctx_sampling, + int rollback_num);