Added llama_sampling_rollback API

This commit is contained in:
l3utterfly 2024-01-31 01:02:34 +09:00
parent 8f8ddfcfad
commit 70074f6f10
2 changed files with 14 additions and 0 deletions

View file

@ -323,3 +323,13 @@ void llama_sampling_accept(
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
} }
} }
// Drop the last `rollback_num` entries from the sampling history
// (`ctx_sampling->prev`).
//
//   ctx_sampling  sampling context whose `prev` container is trimmed
//   rollback_num  number of trailing entries to remove; values larger than
//                 the current history length are clamped to that length,
//                 and values <= 0 leave the history untouched
//
// Only `prev` is modified by this function.
void llama_sampling_rollback(
        struct llama_sampling_context * ctx_sampling,
        int rollback_num) {
    // A negative count would be converted to a huge unsigned value by the
    // size comparison below and end up wiping the entire history, so reject
    // non-positive requests up front.
    if (rollback_num <= 0) {
        return;
    }

    const size_t n_prev = ctx_sampling->prev.size();

    // rollback_num is known to be positive here, so the cast is lossless.
    if ((size_t) rollback_num > n_prev) {
        rollback_num = (int) n_prev;
    }

    ctx_sampling->prev.erase(ctx_sampling->prev.end() - rollback_num, ctx_sampling->prev.end());
}

View file

@ -117,3 +117,7 @@ void llama_sampling_accept(
struct llama_context * ctx_main, struct llama_context * ctx_main,
llama_token id, llama_token id,
bool apply_grammar); bool apply_grammar);
// Remove trailing entries from the sampling history `ctx_sampling->prev`.
// Up to `rollback_num` entries are erased from the end; if `rollback_num`
// exceeds the number of stored entries, all of them are removed.
// NOTE(review): callers are expected to pass a non-negative count.
void llama_sampling_rollback(
struct llama_sampling_context * ctx_sampling,
int rollback_num);