separated sampling reset and sampling reset grammer only

This commit is contained in:
l3utterfly 2024-01-31 01:03:31 +09:00
parent 70074f6f10
commit d33f030896
2 changed files with 9 additions and 2 deletions

View file

@ -37,10 +37,10 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
delete ctx; delete ctx;
} }
void llama_sampling_reset(llama_sampling_context * ctx) { void llama_sampling_reset_grammar(struct llama_sampling_context * ctx) {
if (ctx->grammar != NULL) { if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar); llama_grammar_free(ctx->grammar);
ctx->grammar = NULL; ctx->grammar = nullptr;
} }
if (!ctx->parsed_grammar.rules.empty()) { if (!ctx->parsed_grammar.rules.empty()) {
@ -50,6 +50,10 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
grammar_rules.data(), grammar_rules.data(),
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
} }
}
void llama_sampling_reset(llama_sampling_context * ctx) {
llama_sampling_reset_grammar(ctx);
std::fill(ctx->prev.begin(), ctx->prev.end(), 0); std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear(); ctx->cur.clear();

View file

@ -69,6 +69,9 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
void llama_sampling_free(struct llama_sampling_context * ctx); void llama_sampling_free(struct llama_sampling_context * ctx);
// Reset the sampler grammar without resetting the context
void llama_sampling_reset_grammar(struct llama_sampling_context * ctx);
// Reset the sampler context // Reset the sampler context
// - clear prev tokens // - clear prev tokens
// - reset grammar // - reset grammar