diff --git a/common/sampling.cpp b/common/sampling.cpp
index f5ac66512..6344c29da 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -99,7 +99,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }
 
-llama_token llama_sampling_sample(
+static llama_token llama_sampling_sample_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
@@ -241,14 +241,22 @@ llama_token llama_sampling_sample(
             // Restore logits from the copy
             std::copy(original_logits.begin(), original_logits.end(), logits);
 
-            // Recursively call llama_sampling_sample to resample with the grammar checks applied first
-            return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
         }
     }
 
     return id;
 }
 
+llama_token llama_sampling_sample(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    // Call the implementation function with is_resampling set to false by default
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+}
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
diff --git a/common/sampling.h b/common/sampling.h
index 1c1297f64..4a8c522b6 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -92,21 +92,19 @@ std::string llama_sampling_print(const llama_sampling_params & params);
 // optional:
 // - ctx_cfg:       context to use for classifier-free guidance
 // - idx:           sample from llama_get_logits_ith(ctx, idx)
-// - is_resampling: determines whether or not this is a repeated sampling operation due to the ID not matching the grammar
 //
 // returns:
 // - token:      sampled token
 // - candidates: vector of candidate tokens
 //
 llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool is_resampling = false);
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        int idx = 0);
 
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         llama_token id,
-        bool apply_grammar);
+        bool apply_grammar);
\ No newline at end of file
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index c4a38e5e2..4a7827876 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index c67493dc6..c5cdfbf21 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
             LOG("saved session to %s\n", path_session.c_str());
         }
 
-        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
+        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
         llama_sampling_accept(ctx_sampling, ctx, id, true);
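
Caller-side usage, for context: a minimal sketch of the public API after this change, assuming already-initialized sampling, main, and guidance contexts. The helper name generate_one_token and the include line are illustrative assumptions, not part of the diff; only the llama_sampling_* calls come from the changed files. With is_resampling now internal to llama_sampling_sample_impl, callers pass just the contexts (idx defaults to 0), and grammar-driven resampling happens inside the implementation:

    #include "sampling.h"

    // Hypothetical helper; everything except the llama_sampling_* functions
    // is an assumption made for illustration.
    static llama_token generate_one_token(
            struct llama_sampling_context * ctx_sampling,
            struct llama_context * ctx_main,
            struct llama_context * ctx_guidance) {
        // No is_resampling argument anymore; idx defaults to 0
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, ctx_guidance);

        // apply_grammar = true: fold the accepted token into the grammar state
        llama_sampling_accept(ctx_sampling, ctx_main, id, true);

        return id;
    }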