common : fix mirostat state when using multiple sequences (#3543)

* Fix mirostat state when using multiple sequences

* Fix mirostat by completely refactoring sampling!

* Try to fix zig build.

* Export function to fetch/create default sampler states

Code formatting cleanups and add some comments

Silence a warning about id not being used when logging is disabled

* Apply some renaming suggestions.

Fix comments that were out of sync with the pull.

* Use more consistant naming convention for sampling contexts
This commit is contained in:
Kerfuffle 2023-10-11 13:35:46 -06:00 committed by GitHub
parent 8c70a5ff25
commit 70c29da118
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 495 additions and 334 deletions

View file

@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
const auto t_dec_start = ggml_time_us();
while (true) {
@ -134,7 +136,7 @@ int main(int argc, char ** argv) {
while (true) {
// sample from the target model
llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
// remember which tokens were sampled - used for repetition penalties during sampling
last_tokens.erase(last_tokens.begin());
@ -211,7 +213,13 @@ int main(int argc, char ** argv) {
if (grammar_dft) {
llama_grammar_free(grammar_dft);
}
grammar_dft = llama_grammar_copy(grammar_tgt);
// Note: Hardcoded to sequence id 0, if this ever supports parallel generation
// that will need to change.
auto it = ctx_sampling.sequence_contexts.find(0);
GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
// This is necessary because each sequence id in sequence_contexts
// uses a copy of the original grammar.
grammar_dft = llama_grammar_copy(it->second.grammar);
LOG("copied target grammar to draft grammar\n");
}