Check the full vocab for grammar only if necessary

kalomaze 2023-12-03 04:23:14 -06:00
parent 33e171d1e9
commit 2e3b4f6237
4 changed files with 38 additions and 11 deletions

common/sampling.cpp

@@ -103,7 +103,8 @@ llama_token llama_sampling_sample(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
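The new flag is the hook for the whole optimization: the first pass samples without touching the grammar, and only a rejected token triggers a second, grammar-constrained pass. A minimal standalone sketch of that control flow, using hypothetical stand-ins (toy_sampler and its members are assumptions for illustration, not the llama.cpp API):

    #include <functional>

    // Hypothetical stand-ins for the sampling context; llama.cpp passes
    // llama_sampling_context / llama_context instead.
    struct toy_sampler {
        std::function<int()>     sample;          // regular sampling pipeline
        std::function<void()>    apply_grammar;   // constrain the full vocab
        std::function<bool(int)> grammar_accepts; // validate one token id
    };

    int toy_sample(toy_sampler & s, bool is_resampling = false) {
        if (is_resampling) {
            s.apply_grammar();       // expensive full-vocab pass, retry only
        }
        const int id = s.sample();
        if (!is_resampling && !s.grammar_accepts(id)) {
            // Retry exactly once: the recursive call passes true, so it
            // applies the grammar up front and cannot recurse again.
            return toy_sample(s, true);
        }
        return id;
    }

The recursion is bounded at depth one: the retry applies the grammar before sampling, so the token it produces always passes validation.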
@@ -128,8 +129,12 @@ llama_token llama_sampling_sample(
     llama_token id = 0;
 
+    // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
+    // Make a copy of the original logits before any modifications
+    std::vector<float> original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
         logits[it->first] += it->second;
     }
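Snapshotting the raw logits matters because logit_bias (and any later in-place tweaks) mutates the buffer returned by llama_get_logits_ith directly, and a retry needs an unmodified starting point. A sketch of the save-and-restore idea; where exactly the copy is written back is an assumption here, since this hunk only shows the copy being taken:

    #include <algorithm>
    #include <vector>

    // `logits` points into the model's output buffer, as returned by
    // llama_get_logits_ith(); n_vocab is the vocabulary size.
    void resample_with_clean_logits(float * logits, int n_vocab) {
        // Save before any in-place modification (logit bias, etc.).
        std::vector<float> original_logits(logits, logits + n_vocab);

        // ... first sampling pass mutates `logits` in place ...

        // Before the retry, write the untouched values back so the second
        // pass does not stack biases on top of already-modified logits.
        std::copy(original_logits.begin(), original_logits.end(), logits);
    }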
@@ -165,7 +170,8 @@ llama_token llama_sampling_sample(
         }
     }
 
-    if (ctx_sampling->grammar != NULL) {
+    // If we are in the resampling phase, apply grammar checks before sampling logic
+    if (is_resampling && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
@@ -212,6 +218,26 @@ llama_token llama_sampling_sample(
         }
     }
 
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
+        // Create an array with a single token data element for the sampled id
+        llama_token_data single_token_data = {id, logits[id], 0.0f};
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+        // Apply grammar constraints to the single token
+        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+
+        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+
+        // If the token is not valid according to the grammar, perform resampling
+        if (!is_valid) {
+            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+
+            // Recursively call llama_sampling_sample to resample with the grammar checks applied first
+            return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true);  // Pass true for is_resampling
+        }
+    }
+
     return id;
 }
 
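The cheap path above hinges on llama_sample_grammar accepting a candidate array of any length: wrapping just the sampled id in a one-element llama_token_data_array costs one grammar check instead of one per vocabulary entry, and rejection is signalled by the logit being forced to -INFINITY. The same idiom isolated into a self-contained toy, where the stand-in types and the even-ids-only "grammar" are purely illustrative:

    #include <cmath>
    #include <cstddef>

    // Local stand-ins; llama.cpp defines the real llama_token_data{,_array}.
    struct token_data       { int id; float logit; float p; };
    struct token_data_array { token_data * data; size_t size; bool sorted; };

    // Toy grammar in the style of llama_sample_grammar: rejected candidates
    // get their logit forced to -INFINITY. Rule here: only even ids pass.
    static void apply_toy_grammar(token_data_array & arr) {
        for (size_t i = 0; i < arr.size; ++i) {
            if (arr.data[i].id % 2 != 0) {
                arr.data[i].logit = -INFINITY;
            }
        }
    }

    static bool token_satisfies_grammar(int id, float logit) {
        token_data       single = { id, logit, 0.0f };
        token_data_array arr    = { &single, 1, false };
        apply_toy_grammar(arr);
        return arr.data[0].logit != -INFINITY; // -INFINITY means rejected
    }

Validating a single token is constant work per sampled token, versus running the grammar across all n_vocab candidates every time; the full-vocab pass is now paid only when a retry actually happens.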

common/sampling.h

@@ -101,7 +101,8 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
-        int idx = 0);
+        const int idx,
+        bool is_resampling = false);  // Add the new parameter with default value
 
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
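Defaulting is_resampling to false keeps the new parameter invisible to existing callers, but note the same hunk also drops the `= 0` default from idx, which is why the two call sites below now spell out both arguments. A toy illustration of that trade-off (names shortened, not the real declarations):

    // old: idx had a default, so callers could omit it entirely
    static int sample_v1(int idx = 0) { return idx; }

    // new: idx is required; only the trailing is_resampling is defaulted
    static int sample_v2(int idx, bool is_resampling = false) {
        (void) is_resampling;
        return idx;
    }

    void call_sites() {
        sample_v1();         // compiled before the change
        // sample_v2();      // would not compile: idx no longer has a default
        sample_v2(0);        // fine: is_resampling defaults to false
        sample_v2(0, false); // what the updated examples now write explicitly
    }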

examples/infill/infill.cpp

@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);

examples/main/main.cpp

@@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
                 LOG("saved session to %s\n", path_session.c_str());
             }
 
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);