diff --git a/common/sampling.cpp b/common/sampling.cpp index be6409316..e9dfeb1e7 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -37,10 +37,10 @@ void llama_sampling_free(struct llama_sampling_context * ctx) { delete ctx; } -void llama_sampling_reset(llama_sampling_context * ctx) { +void llama_sampling_reset_grammar(struct llama_sampling_context * ctx) { if (ctx->grammar != NULL) { llama_grammar_free(ctx->grammar); - ctx->grammar = NULL; + ctx->grammar = nullptr; } if (!ctx->parsed_grammar.rules.empty()) { @@ -50,6 +50,10 @@ void llama_sampling_reset(llama_sampling_context * ctx) { grammar_rules.data(), grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); } +} + +void llama_sampling_reset(llama_sampling_context * ctx) { + llama_sampling_reset_grammar(ctx); std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); diff --git a/common/sampling.h b/common/sampling.h index f1df25890..f31b2d900 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -69,6 +69,9 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ void llama_sampling_free(struct llama_sampling_context * ctx); +// Reset the sampler grammar without resetting the context +void llama_sampling_reset_grammar(struct llama_sampling_context * ctx); + // Reset the sampler context // - clear prev tokens // - reset grammar