Check the full vocab for grammar only if necessary
This commit is contained in:
parent
33e171d1e9
commit
2e3b4f6237
4 changed files with 38 additions and 11 deletions
|
@ -100,10 +100,11 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
|||
}
|
||||
|
||||
llama_token llama_sampling_sample(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
const int idx) {
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
const int idx,
|
||||
bool is_resampling) { // Add a parameter to indicate if we are resampling
|
||||
const llama_sampling_params & params = ctx_sampling->params;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
|
@ -128,8 +129,12 @@ llama_token llama_sampling_sample(
|
|||
|
||||
llama_token id = 0;
|
||||
|
||||
// Get a pointer to the logits
|
||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||
|
||||
// Make a copy of the original logits before any modifications
|
||||
std::vector<float> original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
|
||||
|
||||
// apply params.logit_bias map
|
||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||
logits[it->first] += it->second;
|
||||
|
@ -165,7 +170,8 @@ llama_token llama_sampling_sample(
|
|||
}
|
||||
}
|
||||
|
||||
if (ctx_sampling->grammar != NULL) {
|
||||
// If we are in the resampling phase, apply grammar checks before sampling logic
|
||||
if (is_resampling && ctx_sampling->grammar != NULL) {
|
||||
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
||||
}
|
||||
|
||||
|
@ -212,6 +218,26 @@ llama_token llama_sampling_sample(
|
|||
}
|
||||
}
|
||||
|
||||
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
||||
// Create an array with a single token data element for the sampled id
|
||||
llama_token_data single_token_data = {id, logits[id], 0.0f};
|
||||
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
|
||||
|
||||
// Apply grammar constraints to the single token
|
||||
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
|
||||
|
||||
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
|
||||
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
||||
|
||||
// If the token is not valid according to the grammar, perform resampling
|
||||
if (!is_valid) {
|
||||
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
|
||||
|
||||
// Recursively call llama_sampling_sample to resample with the grammar checks applied first
|
||||
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
||||
}
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
|
|
|
@ -98,10 +98,11 @@ std::string llama_sampling_print(const llama_sampling_params & params);
|
|||
// - candidates: vector of candidate tokens
|
||||
//
|
||||
llama_token llama_sampling_sample(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
int idx = 0);
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
const int idx,
|
||||
bool is_resampling = false); // Add the new parameter with default value
|
||||
|
||||
void llama_sampling_accept(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
|
|
|
@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
|
||||
|
|
|
@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
|
|||
LOG("saved session to %s\n", path_session.c_str());
|
||||
}
|
||||
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue