Split sampling into the helper function (?)
And also revert the changes made to the header
This commit is contained in:
parent
f5f9d9620b
commit
115a9218eb
4 changed files with 18 additions and 12 deletions
|
@ -99,7 +99,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
||||||
return std::string(result);
|
return std::string(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_sampling_sample(
|
static llama_token llama_sampling_sample_impl(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
|
@ -241,14 +241,22 @@ llama_token llama_sampling_sample(
|
||||||
// Restore logits from the copy
|
// Restore logits from the copy
|
||||||
std::copy(original_logits.begin(), original_logits.end(), logits);
|
std::copy(original_logits.begin(), original_logits.end(), logits);
|
||||||
|
|
||||||
// Recursively call llama_sampling_sample to resample with the grammar checks applied first
|
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
||||||
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampling_sample(
|
||||||
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
struct llama_context * ctx_cfg,
|
||||||
|
const int idx) {
|
||||||
|
// Call the implementation function with is_resampling set to false by default
|
||||||
|
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
|
|
|
@ -92,7 +92,6 @@ std::string llama_sampling_print(const llama_sampling_params & params);
|
||||||
// optional:
|
// optional:
|
||||||
// - ctx_cfg: context to use for classifier-free guidance
|
// - ctx_cfg: context to use for classifier-free guidance
|
||||||
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
||||||
// - is_resampling: determines whether or not this is a repeated sampling operation due to the ID not matching the grammar
|
|
||||||
//
|
//
|
||||||
// returns:
|
// returns:
|
||||||
// - token: sampled token
|
// - token: sampled token
|
||||||
|
@ -102,8 +101,7 @@ llama_token llama_sampling_sample(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
const int idx,
|
int idx = 0);
|
||||||
bool is_resampling = false);
|
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
|
|
@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
|
||||||
LOG("saved session to %s\n", path_session.c_str());
|
LOG("saved session to %s\n", path_session.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue