diff --git a/include/llama.h b/include/llama.h index 405af912c..fdebdf26a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1056,6 +1056,9 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); + // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed + LLAMA_API void llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); + // available samplers: LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index fd1b7f919..58f76ef26 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -356,6 +356,17 @@ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chai return p->samplers[i]; } +void llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { + auto * p = (llama_sampler_chain *) chain->ctx; + + if (i < 0 || i >= (int32_t) p->samplers.size()) { + return; + } + + auto * result = p->samplers[i]; + p->samplers.erase(p->samplers.begin() + i); +} + int llama_sampler_chain_n(const struct llama_sampler * chain) { const auto * p = (const llama_sampler_chain *) chain->ctx;