From 01bef0290020936baf90e9aa2c65c831894a77ea Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Sun, 8 Oct 2023 11:59:07 -0600 Subject: [PATCH] Export function to fetch/create default sampler states Code formatting cleanups and add some comments Silence a warning about id not being used when logging is disabled --- common/sampling.cpp | 18 ++++++++++++------ common/sampling.h | 27 ++++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 5e8ad1db4..05751d91b 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -9,7 +9,9 @@ llama_sampling_state::~llama_sampling_state() { } } -llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar) { +llama_sampling_state llama_sampling_state_init( + const struct gpt_params & params, + llama_grammar * grammar) { llama_sampling_state result; result.params = params.sampling_params; @@ -17,8 +19,10 @@ llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, return result; } -// Creates the state if it doesn't exist, so this always return something. -static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling_state & state, const llama_seq_id seq) { +// Note: Creates the state if it doesn't exist, so this always returns something. 
+llama_sampler_sequence_state & llama_sampling_get_sequence_state( + llama_sampling_state & state, + const llama_seq_id seq) { const auto it = state.sequence_states.find(seq); if (it != state.sequence_states.end()) { return it->second; @@ -30,7 +34,9 @@ static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling return state.sequence_states.insert({seq, new_state}).first->second; } -bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq) { +bool llama_sampling_state_reset( + llama_sampling_state & state, + const llama_seq_id seq) { const auto it = state.sequence_states.find(seq); if (it == state.sequence_states.end()) return false; if (it->second.grammar != NULL) { @@ -109,7 +115,7 @@ llama_token llama_sample_token( } } - llama_sampler_sequence_state & seq_state = sampling_get_sequence_state(state, seq); + llama_sampler_sequence_state & seq_state = llama_sampling_get_sequence_state(state, seq); if (seq_state.grammar != NULL) { llama_sample_grammar(ctx, &cur_p, seq_state.grammar); @@ -141,6 +147,7 @@ llama_token llama_sample_token( for (int i = 0; i < n_top; i++) { const llama_token id = cur_p.data[i].id; + (void)id; // To avoid a warning that id is unused when logging is disabled. 
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); } } @@ -150,7 +157,6 @@ llama_token llama_sample_token( LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); } } - // printf("`%d`", candidates_p.size); if (seq_state.grammar != NULL) { llama_grammar_accept_token(ctx, seq_state.grammar, id); diff --git a/common/sampling.h b/common/sampling.h index 6c5f8749e..48702e2dd 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -44,19 +44,40 @@ typedef struct llama_sampler_sequence_state { typedef struct llama_sampling_state { ~llama_sampling_state(); + // parameters that will be used for sampling and when creating + // new llama_sampler_sequence_state instances llama_sampling_params params; + // map of sequence ids to sampler states std::unordered_map<llama_seq_id, llama_sampler_sequence_state> sequence_states; + // when non-NULL, new instances of llama_sampler_sequence_state + // will get a copy of the grammar here + // note: only the pointer is stored here, it is not a copy of + // the grammar and shouldn't be freed llama_grammar * grammar; } llama_sampling_state; #include "common.h" // Create a new sampling state instance. -llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar); +llama_sampling_state llama_sampling_state_init( + const struct gpt_params & params, + llama_grammar * grammar = NULL); -bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq = 0); +// Fetches the sampler state for the specified sequence id (defaults to 0). +// If the state for that sequence id doesn't already exist, it will be created with +// default values based on the parameters in the state argument. +llama_sampler_sequence_state & llama_sampling_get_sequence_state( + llama_sampling_state & state, + const llama_seq_id seq = 0); + +// Reset the sampler states for the supplied sequence id (defaults to 0). 
+// This is necessary to reuse a sequence id or free memory used by sequences +// that are no longer required. +bool llama_sampling_state_reset( + llama_sampling_state & state, + const llama_seq_id seq = 0); // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function @@ -73,7 +94,7 @@ bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id // - grammar: grammar to use for sampling, ignore if NULL // - last_tokens: needed for repetition penalty, ignore if empty // - idx: sample from llama_get_logits_ith(ctx, idx) -// - seq: sequence id to associate sampler state with (currently only used by mirostat) +// - seq: sequence id to associate sampler state with // // returns: // - token: sampled token