speculative : refactor and add a simpler example (#10362)

* speculative : refactor and add a simpler example

ggml-ci

* speculative : clean-up and add comments and TODOs [no ci]

* speculative : manage context in common_speculative

ggml-ci

* speculative : simplify

ggml-ci

* speculative : simplify (cont)

ggml-ci

* speculative : add --draft-min CLI arg

* speculative : minor fixup

* make : build fixes

* speculative : do not redraft previous drafts

ggml-ci

* speculative : fix the draft sampling

ggml-ci

* speculative : fix compile warning

* common : refactor args

ggml-ci

* common : change defaults [no ci]

* common : final touches

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-25 09:58:41 +02:00 committed by GitHub
parent cce5a90075
commit d9d54e498d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 1028 additions and 326 deletions

View file

@ -99,7 +99,7 @@ struct ring_buffer {
};
struct common_sampler {
common_sampler_params params;
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain;
@ -125,7 +125,7 @@ struct common_sampler {
}
};
std::string common_sampler_params::print() const {
std::string common_params_sampling::print() const {
char result[1024];
snprintf(result, sizeof(result),
@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
return std::string(result);
}
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
lparams.no_perf = params.no_perf;
@ -320,6 +320,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return cur_p.data[cur_p.selected].id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
result.reserve(idxs.size());
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
if (draft[i] != id) {
break;
}
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
}
return result;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}