speculative : add infill mode

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-26 11:14:17 +02:00
parent 0eb4e12bee
commit b83cae088c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 26 additions and 22 deletions

View file

@ -11,7 +11,9 @@
struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;
struct common_sampler * smpl_infill;
llama_batch batch;
llama_tokens prompt;
@ -20,14 +22,26 @@ struct common_speculative {
struct common_speculative * common_speculative_init(
struct llama_context * ctx_dft) {
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .smpl_infill = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
};
// TODO: optimize or pass from outside?
#if 0
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
{
common_params_sampling params;
params.no_perf = false;
@ -41,28 +55,15 @@ struct common_speculative * common_speculative_init(
COMMON_SAMPLER_TYPE_INFILL,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params);
}
#else
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
#endif
return result;
}
void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
common_sampler_free(spec->smpl_infill);
llama_batch_free(spec->batch);
@ -133,7 +134,7 @@ llama_tokens common_speculative_gen_draft(
llama_token id_last) {
auto & batch = spec->batch;
auto & ctx = spec->ctx;
auto & smpl = spec->smpl;
auto & smpl = params.infill ? spec->smpl_infill : spec->smpl;
auto & prompt = spec->prompt;
int reuse_i = 0;