speculative : add infill mode
ggml-ci
This commit is contained in:
parent
0eb4e12bee
commit
b83cae088c
3 changed files with 26 additions and 22 deletions
|
@ -11,7 +11,9 @@
|
|||
|
||||
struct common_speculative {
|
||||
struct llama_context * ctx;
|
||||
|
||||
struct common_sampler * smpl;
|
||||
struct common_sampler * smpl_infill;
|
||||
|
||||
llama_batch batch;
|
||||
llama_tokens prompt;
|
||||
|
@ -20,14 +22,26 @@ struct common_speculative {
|
|||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_dft) {
|
||||
auto * result = new common_speculative {
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt = */ {},
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .smpl_infill = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt = */ {},
|
||||
};
|
||||
|
||||
// TODO: optimize or pass from outside?
|
||||
#if 0
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
|
||||
params.top_k = 10;
|
||||
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
|
@ -41,28 +55,15 @@ struct common_speculative * common_speculative_init(
|
|||
COMMON_SAMPLER_TYPE_INFILL,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
#else
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
|
||||
params.top_k = 10;
|
||||
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
#endif
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec) {
|
||||
common_sampler_free(spec->smpl);
|
||||
common_sampler_free(spec->smpl_infill);
|
||||
|
||||
llama_batch_free(spec->batch);
|
||||
|
||||
|
@ -133,7 +134,7 @@ llama_tokens common_speculative_gen_draft(
|
|||
llama_token id_last) {
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx = spec->ctx;
|
||||
auto & smpl = spec->smpl;
|
||||
auto & smpl = params.infill ? spec->smpl_infill : spec->smpl;
|
||||
auto & prompt = spec->prompt;
|
||||
|
||||
int reuse_i = 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue