From 2e3b4f62378082680f47b58080f771439a31075f Mon Sep 17 00:00:00 2001
From: kalomaze <66376113+kalomaze@users.noreply.github.com>
Date: Sun, 3 Dec 2023 04:23:14 -0600
Subject: [PATCH] Check the full vocab for grammar only if necessary

---
 common/sampling.cpp        | 36 +++++++++++++++++++++++++++++++-----
 common/sampling.h          |  9 +++++----
 examples/infill/infill.cpp |  2 +-
 examples/main/main.cpp     |  2 +-
 4 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 1317024c2..78092611b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -100,10 +100,11 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx) {
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx,
+        bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -128,7 +129,11 @@ llama_token llama_sampling_sample(
 
     llama_token id = 0;
 
+    // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
+    // Make a copy of the original logits before any modifications
+    std::vector<float> original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -165,7 +170,8 @@
         }
     }
 
-    if (ctx_sampling->grammar != NULL) {
+    // If we are in the resampling phase, apply grammar checks before sampling logic
+    if (is_resampling && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
@@ -212,6 +218,26 @@
         }
     }
 
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
+        // Create an array with a single token data element for the sampled id
+        llama_token_data single_token_data = {id, logits[id], 0.0f};
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+        // Apply grammar constraints to the single token
+        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+
+        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+
+        // If the token is not valid according to the grammar, perform resampling
+        if (!is_valid) {
+            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+
+            // Recursively call llama_sampling_sample to resample with the grammar checks applied first
+            return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true);  // Pass true for is_resampling
+        }
+    }
+
     return id;
 }
 
diff --git a/common/sampling.h b/common/sampling.h
index 7c9b8dcf2..5c387fb6f 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -98,10 +98,11 @@ std::string llama_sampling_print(const llama_sampling_params & params);
 //  - candidates: vector of candidate tokens
 //
 llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        int idx = 0);
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx,
+        bool is_resampling = false);  // Add the new parameter with default value
 
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 4a7827876..c4a38e5e2 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index c5cdfbf21..c67493dc6 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
             LOG("saved session to %s\n", path_session.c_str());
         }
 
-        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
 
         llama_sampling_accept(ctx_sampling, ctx, id, true);
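
Note: the optimization this patch implements, reduced to a standalone C++ sketch.
This is not llama.cpp code; "Candidate", "sample_from", and "grammar_accepts" are
hypothetical stand-ins for llama_token_data, the sampler chain, and
llama_sample_grammar. It only illustrates why the lazy check saves work: in the
common case the grammar is consulted once, for the single sampled token, and the
full-vocabulary grammar pass runs only when that token is rejected.

    #include <cstdio>
    #include <functional>
    #include <vector>

    // Toy candidate: stands in for llama_token_data (id + logit).
    struct Candidate { int id; float logit; };

    // Greedy pick keeps the sketch deterministic; a real sampler chain goes here.
    static int sample_from(const std::vector<Candidate> & cands) {
        int best = cands.front().id;
        float best_logit = cands.front().logit;
        for (const auto & c : cands) {
            if (c.logit > best_logit) { best = c.id; best_logit = c.logit; }
        }
        return best;
    }

    static int sample(const std::vector<Candidate> & cands,
                      const std::function<bool(int)> & grammar_accepts) {
        // Fast path: sample first, then grammar-check only the chosen token.
        const int id = sample_from(cands);
        if (grammar_accepts(id)) {
            return id; // common case: 1 grammar check instead of |vocab| checks
        }
        // Slow path: drop every grammar-invalid candidate and resample -- this is
        // what the patch does by recursing with is_resampling == true.
        std::vector<Candidate> filtered;
        for (const auto & c : cands) {
            if (grammar_accepts(c.id)) {
                filtered.push_back(c);
            }
        }
        if (filtered.empty()) {
            return -1; // no candidate satisfies the grammar
        }
        return sample_from(filtered);
    }

    int main() {
        const std::vector<Candidate> cands = { {0, 1.5f}, {1, 2.0f}, {2, 0.5f} };
        const auto even_only = [](int t) { return t % 2 == 0; }; // toy "grammar"
        printf("sampled: %d\n", sample(cands, even_only)); // rejects 1, resamples to 0
        return 0;
    }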