From 4a2e5e0dc51d6d35d372e7f666046c1317642d46 Mon Sep 17 00:00:00 2001 From: "Gilad S." <7817232+giladgd@users.noreply.github.com> Date: Thu, 12 Sep 2024 01:46:29 +0300 Subject: [PATCH] feat: remove a sampler from a chain --- include/llama.h | 3 +++ src/llama-sampling.cpp | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/include/llama.h b/include/llama.h index 405af912c..fdebdf26a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1056,6 +1056,9 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); + // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed + LLAMA_API void llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); + // available samplers: LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index fd1b7f919..58f76ef26 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -356,6 +356,17 @@ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chai return p->samplers[i]; } +void llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { + auto * p = (llama_sampler_chain *) chain->ctx; + + if (i < 0 || i >= (int32_t) p->samplers.size()) { + return; + } + + auto * result = p->samplers[i]; + p->samplers.erase(p->samplers.begin() + i); +} + int llama_sampler_chain_n(const struct llama_sampler * chain) { const auto * p = (const llama_sampler_chain *) chain->ctx;