Check the full vocab for grammar only if necessary
This commit is contained in:
parent
33e171d1e9
commit
2e3b4f6237
4 changed files with 38 additions and 11 deletions
|
@ -103,7 +103,8 @@ llama_token llama_sampling_sample(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
const int idx) {
|
const int idx,
|
||||||
|
bool is_resampling) { // Add a parameter to indicate if we are resampling
|
||||||
const llama_sampling_params & params = ctx_sampling->params;
|
const llama_sampling_params & params = ctx_sampling->params;
|
||||||
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
@ -128,8 +129,12 @@ llama_token llama_sampling_sample(
|
||||||
|
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
|
|
||||||
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
|
// Make a copy of the original logits before any modifications
|
||||||
|
std::vector<float> original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
|
||||||
|
|
||||||
// apply params.logit_bias map
|
// apply params.logit_bias map
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
logits[it->first] += it->second;
|
logits[it->first] += it->second;
|
||||||
|
@ -165,7 +170,8 @@ llama_token llama_sampling_sample(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx_sampling->grammar != NULL) {
|
// If we are in the resampling phase, apply grammar checks before sampling logic
|
||||||
|
if (is_resampling && ctx_sampling->grammar != NULL) {
|
||||||
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -212,6 +218,26 @@ llama_token llama_sampling_sample(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
||||||
|
// Create an array with a single token data element for the sampled id
|
||||||
|
llama_token_data single_token_data = {id, logits[id], 0.0f};
|
||||||
|
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
|
||||||
|
|
||||||
|
// Apply grammar constraints to the single token
|
||||||
|
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
|
||||||
|
|
||||||
|
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
|
||||||
|
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
||||||
|
|
||||||
|
// If the token is not valid according to the grammar, perform resampling
|
||||||
|
if (!is_valid) {
|
||||||
|
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
|
||||||
|
|
||||||
|
// Recursively call llama_sampling_sample to resample with the grammar checks applied first
|
||||||
|
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,8 @@ llama_token llama_sampling_sample(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
int idx = 0);
|
const int idx,
|
||||||
|
bool is_resampling = false); // Add the new parameter with default value
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
|
|
@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
|
||||||
LOG("saved session to %s\n", path_session.c_str());
|
LOG("saved session to %s\n", path_session.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue