Split sampling into the helper function (?)

And also revert the changes made to the header
This commit is contained in:
kalomaze 2023-12-07 14:58:48 -06:00
parent f5f9d9620b
commit 115a9218eb
4 changed files with 18 additions and 12 deletions

View file

@ -99,7 +99,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
return std::string(result); return std::string(result);
} }
llama_token llama_sampling_sample( static llama_token llama_sampling_sample_impl(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
@ -241,14 +241,22 @@ llama_token llama_sampling_sample(
// Restore logits from the copy // Restore logits from the copy
std::copy(original_logits.begin(), original_logits.end(), logits); std::copy(original_logits.begin(), original_logits.end(), logits);
// Recursively call llama_sampling_sample to resample with the grammar checks applied first return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
} }
} }
return id; return id;
} }
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
// Call the implementation function with is_resampling set to false by default
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}
void llama_sampling_accept( void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,

View file

@ -92,18 +92,16 @@ std::string llama_sampling_print(const llama_sampling_params & params);
// optional: // optional:
// - ctx_cfg: context to use for classifier-free guidance // - ctx_cfg: context to use for classifier-free guidance
// - idx: sample from llama_get_logits_ith(ctx, idx) // - idx: sample from llama_get_logits_ith(ctx, idx)
// - is_resampling: determines whether or not this is a repeated sampling operation due to the ID not matching the grammar
// //
// returns: // returns:
// - token: sampled token // - token: sampled token
// - candidates: vector of candidate tokens // - candidates: vector of candidate tokens
// //
llama_token llama_sampling_sample( llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
const int idx, int idx = 0);
bool is_resampling = false);
void llama_sampling_accept( void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,

View file

@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false); const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
llama_sampling_accept(ctx_sampling, ctx, id, true); llama_sampling_accept(ctx_sampling, ctx, id, true);

View file

@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str()); LOG("saved session to %s\n", path_session.c_str());
} }
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false); const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
llama_sampling_accept(ctx_sampling, ctx, id, true); llama_sampling_accept(ctx_sampling, ctx, id, true);