diff --git a/common/sampling.cpp b/common/sampling.cpp
index 05751d91b..78cb03885 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,7 +1,7 @@
 #include "sampling.h"
 
-llama_sampling_state::~llama_sampling_state() {
-    for (auto & it : sequence_states) {
+llama_sampling_context::~llama_sampling_context() {
+    for (auto & it : sequence_contexts) {
         if (it.second.grammar != NULL) {
             llama_grammar_free(it.second.grammar);
             it.second.grammar = NULL;
@@ -9,48 +9,48 @@ llama_sampling_state::~llama_sampling_state() {
     }
 }
 
-llama_sampling_state llama_sampling_state_init(
+llama_sampling_context llama_sampling_context_init(
         const struct gpt_params & params,
                   llama_grammar * grammar) {
-    llama_sampling_state result;
+    llama_sampling_context result;
 
     result.params = params.sampling_params;
     result.grammar = grammar;
     return result;
 }
 
-// Note: Creates the state if it doesn't exist, so this always return something.
-llama_sampler_sequence_state & llama_sampling_get_sequence_state(
-        llama_sampling_state & state,
-        const llama_seq_id seq) {
-    const auto it = state.sequence_states.find(seq);
-    if (it != state.sequence_states.end()) {
+// Note: Creates the context if it doesn't exist, so this always returns something.
+llama_sampler_sequence_context & llama_sampling_get_sequence_context(
+        llama_sampling_context & sampling_ctx,
+        const llama_seq_id seq) {
+    const auto it = sampling_ctx.sequence_contexts.find(seq);
+    if (it != sampling_ctx.sequence_contexts.end()) {
         return it->second;
     }
-    llama_sampler_sequence_state new_state = {
-        2.0f * state.params.mirostat_tau,
-        state.grammar != NULL ? llama_grammar_copy(state.grammar) : NULL,
+    llama_sampler_sequence_context new_ctx = {
+        2.0f * sampling_ctx.params.mirostat_tau,
+        sampling_ctx.grammar != NULL ? llama_grammar_copy(sampling_ctx.grammar) : NULL,
     };
-    return state.sequence_states.insert({seq, new_state}).first->second;
+    return sampling_ctx.sequence_contexts.insert({seq, new_ctx}).first->second;
 }
 
-bool llama_sampling_state_reset(
-        llama_sampling_state & state,
-        const llama_seq_id seq) {
-    const auto it = state.sequence_states.find(seq);
-    if (it == state.sequence_states.end()) return false;
+bool llama_sampling_context_reset(
+        llama_sampling_context & sampling_ctx,
+        const llama_seq_id seq) {
+    const auto it = sampling_ctx.sequence_contexts.find(seq);
+    if (it == sampling_ctx.sequence_contexts.end()) return false;
     if (it->second.grammar != NULL) {
         llama_grammar_free(it->second.grammar);
         it->second.grammar = NULL;
     }
-    state.sequence_states.erase(it);
+    sampling_ctx.sequence_contexts.erase(it);
     return true;
 }
 
-llama_token llama_sample_token(
+llama_token llama_sampling_sample(
         struct llama_context * ctx,
         struct llama_context * ctx_guidance,
-        struct llama_sampling_state & state,
+        struct llama_sampling_context & sampling_ctx,
         const std::vector<llama_token> & last_tokens,
         std::vector<llama_token_data> & candidates,
         const int idx,
@@ -58,7 +58,7 @@ llama_token llama_sample_token(
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-    const llama_sampling_params & params = state.params;
+    const llama_sampling_params & params = sampling_ctx.params;
     const float   temp  = params.temp;
     const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
     const float   top_p = params.top_p;
@@ -115,10 +115,10 @@ llama_token llama_sample_token(
         }
     }
 
-    llama_sampler_sequence_state & seq_state = llama_sampling_get_sequence_state(state, seq);
+    llama_sampler_sequence_context & seq_ctx = llama_sampling_get_sequence_context(sampling_ctx, seq);
 
-    if (seq_state.grammar != NULL) {
-        llama_sample_grammar(ctx, &cur_p, seq_state.grammar);
+    if (seq_ctx.grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, seq_ctx.grammar);
     }
 
     if (temp <= 0) {
@@ -128,10 +128,10 @@ llama_token llama_sample_token(
         if (mirostat == 1) {
             const int mirostat_m = 100;
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &seq_state.mirostat_mu);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &seq_ctx.mirostat_mu);
         } else if (mirostat == 2) {
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &seq_state.mirostat_mu);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &seq_ctx.mirostat_mu);
         } else {
             // Temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
@@ -158,8 +158,8 @@ llama_token llama_sample_token(
         }
     }
 
-    if (seq_state.grammar != NULL) {
-        llama_grammar_accept_token(ctx, seq_state.grammar, id);
+    if (seq_ctx.grammar != NULL) {
+        llama_grammar_accept_token(ctx, seq_ctx.grammar, id);
     }
 
     return id;
diff --git a/common/sampling.h b/common/sampling.h
index 48702e2dd..96891640f 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -34,64 +34,62 @@ typedef struct llama_sampling_params {
 
 } llama_sampling_params;
 
-// per-sequence sampler state
-typedef struct llama_sampler_sequence_state {
+// per-sequence sampler context
+typedef struct llama_sampler_sequence_context {
     float mirostat_mu; // mirostat sampler state
     llama_grammar * grammar;
-} llama_sampler_sequence_state;
+} llama_sampler_sequence_context;
 
-// general sampler state
-typedef struct llama_sampling_state {
-    ~llama_sampling_state();
+// general sampler context
+typedef struct llama_sampling_context {
+    ~llama_sampling_context();
 
     // parameters that will be used for sampling and when creating
-    // new llama_sampler_sequence_state instances
+    // new llama_sampler_sequence_context instances
    llama_sampling_params params;
 
-    // map of sequence ids to sampler states
-    std::unordered_map<llama_seq_id, llama_sampler_sequence_state> sequence_states;
+    // map of sequence ids to sampler contexts
+    std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
 
-    // when non-NULL, new instances of llama_sampler_sequence_state
+    // when non-NULL, new instances of llama_sampler_sequence_context
     // will get a copy of the grammar here
     // note: only the pointer is stored here, it is not a copy of
     //       the grammar and shouldn't be freed
     llama_grammar * grammar;
-} llama_sampling_state;
+} llama_sampling_context;
 
 #include "common.h"
 
-// Create a new sampling state instance.
-llama_sampling_state llama_sampling_state_init(
+// Create a new sampling context instance.
+llama_sampling_context llama_sampling_context_init(
         const struct gpt_params & params,
                   llama_grammar * grammar = NULL);
 
-// Fetches the sampler state for the specified sequence id (defaults to 0).
-// If the state for that sequence id doesn't already exist, it will be created with
-// default values based on the parameters in the state argument.
-llama_sampler_sequence_state & llama_sampling_get_sequence_state(
-        llama_sampling_state & state,
-        const llama_seq_id seq = 0);
+// Fetches the sampler context for the specified sequence id (defaults to 0).
+// If the context for that sequence id doesn't already exist, it will be created with
+// default values based on the parameters in the sampling_ctx argument.
+llama_sampler_sequence_context & llama_sampling_get_sequence_context(
+        llama_sampling_context & sampling_ctx,
+        const llama_seq_id seq = 0);
 
-// Reset the sampler states for the supplied sequence id (defaults to 0).
+// Reset the sampler context for the supplied sequence id (defaults to 0).
 // This is necessary to reuse a sequence id or free memory used by sequences
 // that are no longer required.
-bool llama_sampling_state_reset(
-        llama_sampling_state & state,
-        const llama_seq_id seq = 0);
+bool llama_sampling_context_reset(
+        llama_sampling_context & sampling_ctx,
+        const llama_seq_id seq = 0);
 
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
-// Note: When using multiple sequences, it is the caller's responsibility to delete
-// the item in params.sampler_state when a sequence ends and samplers that rely on
-// state are being used.
+// Note: When using multiple sequences, it is the caller's responsibility to call
+// llama_sampling_context_reset when a sequence ends
 //
 // required:
-//  - ctx:    context to use for sampling
-//  - params: sampling parameters
+//  - ctx:          context to use for sampling
+//  - sampling_ctx: sampling-specific context
 //
 // optional:
 //  - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
-//  - grammar:      grammar to use for sampling, ignore if NULL
 //  - last_tokens:  needed for repetition penalty, ignore if empty
 //  - idx:          sample from llama_get_logits_ith(ctx, idx)
 //  - seq:          sequence id to associate sampler state with
@@ -100,10 +98,10 @@ bool llama_sampling_state_reset(
 //  - token:      sampled token
 //  - candidates: vector of candidate tokens
 //
-llama_token llama_sample_token(
+llama_token llama_sampling_sample(
         struct llama_context * ctx,
         struct llama_context * ctx_guidance,
-        struct llama_sampling_state & state,
+        struct llama_sampling_context & sampling_ctx,
         const std::vector<llama_token> & last_tokens,
         std::vector<llama_token_data> & candidates,
         const int idx = 0,
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index cde822678..cd9e6b142 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -423,7 +423,7 @@ int main(int argc, char ** argv) {
 
     const int n_vocab = llama_n_vocab(model);
 
-    llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar);
+    llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar);
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -542,7 +542,7 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
-            const llama_token id = llama_sample_token(ctx, ctx_guidance, sampling_state, last_tokens, candidates);
+            const llama_token id = llama_sampling_sample(ctx, ctx_guidance, sampling_context, last_tokens, candidates);
 
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 89abfd55b..eaa99459a 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -470,7 +470,7 @@ int main(int argc, char ** argv) {
 
     const int n_vocab = llama_n_vocab(model);
 
-    llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar);
+    llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar);
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -627,7 +627,7 @@ int main(int argc, char ** argv) {
                 LOG("saved session to %s\n", path_session.c_str());
             }
 
-            const llama_token id = llama_sample_token(ctx, ctx_guidance, sampling_state, last_tokens, candidates);
+            const llama_token id = llama_sampling_sample(ctx, ctx_guidance, sampling_context, last_tokens, candidates);
 
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index cb4c4ba96..50025a71c 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -125,7 +125,7 @@ int main(int argc, char ** argv) {
     params.logits_all = true;
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
-    llama_sampling_state sampling_state = llama_sampling_state_init(params, NULL);
+    llama_sampling_context sampling_context = llama_sampling_context_init(params, NULL);
 
     // load the prompts from an external file if there are any
     if (params.prompt.empty()) {
@@ -341,7 +341,7 @@ int main(int argc, char ** argv) {
            //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
            //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sample_token(ctx, NULL, sampling_state, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
+            const llama_token id = llama_sampling_sample(ctx, NULL, sampling_context, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -386,7 +386,7 @@ int main(int argc, char ** argv) {
                 n_total_prompt += client.n_prompt;
                 n_total_gen    += client.n_decoded;
 
-                llama_sampling_state_reset(sampling_state, client.seq_id);
+                llama_sampling_context_reset(sampling_context, client.seq_id);
 
                 client.seq_id = -1;
             }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 08bdaa189..50e344a78 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -200,7 +200,7 @@ struct llama_server_context
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     gpt_params params;
-    llama_sampling_state sampling_state;
+    llama_sampling_context sampling_context;
     int n_ctx;
 
     grammar_parser::parse_state parsed_grammar;
@@ -255,7 +255,7 @@ struct llama_server_context
         if (grammar != nullptr) {
             llama_grammar_free(grammar);
             grammar = nullptr;
-            sampling_state = llama_sampling_state_init(params, NULL);
+            sampling_context = llama_sampling_context_init(params, NULL);
         }
     }
 
@@ -341,7 +341,7 @@ struct llama_server_context
             grammar = llama_grammar_init(
                 grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
         }
-        sampling_state = llama_sampling_state_init(params, grammar);
+        sampling_context = llama_sampling_context_init(params, grammar);
         return true;
     }
 
@@ -542,7 +542,7 @@ struct llama_server_context
         std::vector<llama_token_data> candidates;
         candidates.reserve(llama_n_vocab(model));
 
-        result.tok = llama_sample_token(ctx, NULL, sampling_state, last_n_tokens, candidates);
+        result.tok = llama_sampling_sample(ctx, NULL, sampling_context, last_n_tokens, candidates);
 
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
@@ -1210,7 +1210,7 @@ static void parse_options_completion(const json &body, llama_server_context &llama)
         }
     }
 
-    llama.sampling_state = llama_sampling_state_init(llama.params, llama.grammar);
+    llama.sampling_context = llama_sampling_context_init(llama.params, llama.grammar);
 
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 6a07c1c17..d5b19cf56 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -125,7 +125,7 @@ int main(int argc, char ** argv) {
         grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     }
 
-    llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar_tgt);
+    llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar_tgt);
 
     const auto t_dec_start = ggml_time_us();
 
@@ -136,7 +136,7 @@ int main(int argc, char ** argv) {
 
     while (true) {
         // sample from the target model
-        llama_token id = llama_sample_token(ctx_tgt, NULL, sampling_state, last_tokens, candidates, i_dft);
+        llama_token id = llama_sampling_sample(ctx_tgt, NULL, sampling_context, last_tokens, candidates, i_dft);
 
         // remember which tokens were sampled - used for repetition penalties during sampling
         last_tokens.erase(last_tokens.begin());
@@ -215,9 +215,9 @@ int main(int argc, char ** argv) {
                 }
 
                 // Note: Hardcoded to sequence id 0, if this ever supports parallel generation
                 // that will need to change.
-                auto it = sampling_state.sequence_states.find(0);
-                GGML_ASSERT(it != sampling_state.sequence_states.end());
-                // This is necessary because each sequence id in sequence_states
+                auto it = sampling_context.sequence_contexts.find(0);
+                GGML_ASSERT(it != sampling_context.sequence_contexts.end());
+                // This is necessary because each sequence id in sequence_contexts
                 // uses a copy of the original grammar.
                 grammar_dft = llama_grammar_copy(it->second.grammar);
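// ---------------------------------------------------------------------------
// Usage sketch (illustration only, not part of the patch): how the renamed
// helpers fit together in a caller, modeled on the parallel.cpp call sites
// above. Model/context creation, batching, and the decode loop are assumed to
// already exist as in the examples; `last_tokens` and `candidates` are the
// caller's repetition-penalty window and candidate buffer, and the helper name
// `sample_for_seq` is hypothetical.
// ---------------------------------------------------------------------------
#include "common.h"
#include "sampling.h"

#include <vector>

// sample one token for sequence `seq_id`, reading logits at batch index `i_batch`
static llama_token sample_for_seq(
        struct llama_context * ctx,
        llama_sampling_context & sampling_context,
        std::vector<llama_token> & last_tokens,
        std::vector<llama_token_data> & candidates,
        int i_batch,
        llama_seq_id seq_id) {
    // per-sequence state (mirostat mu, grammar copy) is created lazily inside
    // llama_sampling_sample via llama_sampling_get_sequence_context
    const llama_token id = llama_sampling_sample(
            ctx, NULL, sampling_context, last_tokens, candidates, i_batch, seq_id);

    // keep the repetition-penalty window up to date, as the examples do
    last_tokens.erase(last_tokens.begin());
    last_tokens.push_back(id);

    return id;
}

// When a sequence finishes, the caller releases its per-sequence context so the
// sequence id can be reused (see the note in sampling.h):
//
//     llama_sampling_context_reset(sampling_context, seq_id);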