diff --git a/common/sampling.cpp b/common/sampling.cpp index 78cb03885..8ce419459 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -21,36 +21,36 @@ llama_sampling_context llama_sampling_context_init( // Note: Creates the context if it doesn't exist, so this always return something. llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & sampling_ctx, + llama_sampling_context & ctx_sampling, const llama_seq_id seq) { - const auto it = sampling_ctx.sequence_contexts.find(seq); - if (it != sampling_ctx.sequence_contexts.end()) { + const auto it = ctx_sampling.sequence_contexts.find(seq); + if (it != ctx_sampling.sequence_contexts.end()) { return it->second; } llama_sampler_sequence_context new_ctx = { - 2.0f * sampling_ctx.params.mirostat_tau, - sampling_ctx.grammar != NULL ? llama_grammar_copy(sampling_ctx.grammar) : NULL, + 2.0f * ctx_sampling.params.mirostat_tau, + ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL, }; - return sampling_ctx.sequence_contexts.insert({seq, new_ctx}).first->second; + return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second; } bool llama_sampling_context_reset( - llama_sampling_context & sampling_ctx, + llama_sampling_context & ctx_sampling, const llama_seq_id seq) { - const auto it = sampling_ctx.sequence_contexts.find(seq); - if (it == sampling_ctx.sequence_contexts.end()) return false; + const auto it = ctx_sampling.sequence_contexts.find(seq); + if (it == ctx_sampling.sequence_contexts.end()) return false; if (it->second.grammar != NULL) { llama_grammar_free(it->second.grammar); it->second.grammar = NULL; } - sampling_ctx.sequence_contexts.erase(it); + ctx_sampling.sequence_contexts.erase(it); return true; } llama_token llama_sampling_sample( struct llama_context * ctx, struct llama_context * ctx_guidance, - struct llama_sampling_context & sampling_ctx, + struct llama_sampling_context & ctx_sampling, const std::vector & last_tokens, std::vector & candidates, const int idx, @@ -58,7 +58,7 @@ llama_token llama_sampling_sample( const int n_ctx = llama_n_ctx(ctx); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - const llama_sampling_params & params = sampling_ctx.params; + const llama_sampling_params & params = ctx_sampling.params; const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const float top_p = params.top_p; @@ -115,10 +115,10 @@ llama_token llama_sampling_sample( } } - llama_sampler_sequence_context & seq_ctx = llama_sampling_get_sequence_context(sampling_ctx, seq); + llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq); - if (seq_ctx.grammar != NULL) { - llama_sample_grammar(ctx, &cur_p, seq_ctx.grammar); + if (ctx_seq.grammar != NULL) { + llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar); } if (temp <= 0) { @@ -128,10 +128,10 @@ llama_token llama_sampling_sample( if (mirostat == 1) { const int mirostat_m = 100; llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &seq_ctx.mirostat_mu); + id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu); } else if (mirostat == 2) { llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &seq_ctx.mirostat_mu); + id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu); } else { // Temperature sampling size_t min_keep = std::max(1, params.n_probs); @@ -158,8 +158,8 @@ llama_token llama_sampling_sample( } } - if (seq_ctx.grammar != NULL) { - llama_grammar_accept_token(ctx, seq_ctx.grammar, id); + if (ctx_seq.grammar != NULL) { + llama_grammar_accept_token(ctx, ctx_seq.grammar, id); } return id; diff --git a/common/sampling.h b/common/sampling.h index 96891640f..0aab5d03c 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -67,16 +67,16 @@ llama_sampling_context llama_sampling_context_init( // Fetches the sampler context for the specified sequence id (defaults to 0). // If the context for that sequence id doesn't already exist, it will be created with -// default values based on the parameters in the sampling_ctx argument. +// default values based on the parameters in the ctx_sampling argument. llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & sampling_ctx, + llama_sampling_context & ctx_sampling, const llama_seq_id seq = 0); // Reset the sampler context for the supplied sequence id (defaults to 0). // This is necessary to reuse a sequence id or free memory used by sequences // that are no longer required. bool llama_sampling_context_reset( - llama_sampling_context & sampling_ctx, + llama_sampling_context & ctx_sampling, const llama_seq_id seq = 0); // this is a common sampling function used across the examples for convenience @@ -86,7 +86,7 @@ bool llama_sampling_context_reset( // // required: // - ctx: context to use for sampling -// - sampling_ctx: sampling-specific context +// - ctx_sampling: sampling-specific context // // optional: // - ctx_guidance: context to use for classifier-free guidance, ignore if NULL @@ -101,7 +101,7 @@ bool llama_sampling_context_reset( llama_token llama_sampling_sample( struct llama_context * ctx, struct llama_context * ctx_guidance, - struct llama_sampling_context & sampling_ctx, + struct llama_sampling_context & ctx_sampling, const std::vector & last_tokens, std::vector & candidates, const int idx = 0, diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index cd9e6b142..525a3beee 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -423,7 +423,7 @@ int main(int argc, char ** argv) { const int n_vocab = llama_n_vocab(model); - llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar); std::vector candidates; candidates.reserve(n_vocab); @@ -542,7 +542,7 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(ctx, ctx_guidance, sampling_context, last_tokens, candidates); + const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates); last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index eaa99459a..b39a67d97 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -470,7 +470,7 @@ int main(int argc, char ** argv) { const int n_vocab = llama_n_vocab(model); - llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar); std::vector candidates; candidates.reserve(n_vocab); @@ -627,7 +627,7 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(ctx, ctx_guidance, sampling_context, last_tokens, candidates); + const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates); last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 50025a71c..cdb198a7c 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -125,7 +125,7 @@ int main(int argc, char ** argv) { params.logits_all = true; std::tie(model, ctx) = llama_init_from_gpt_params(params); - llama_sampling_context sampling_context = llama_sampling_context_init(params, NULL); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL); // load the prompts from an external file if there are any if (params.prompt.empty()) { @@ -341,7 +341,7 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sampling_sample(ctx, NULL, sampling_context, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); + const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -386,7 +386,7 @@ int main(int argc, char ** argv) { n_total_prompt += client.n_prompt; n_total_gen += client.n_decoded; - llama_sampling_context_reset(sampling_context, client.seq_id); + llama_sampling_context_reset(ctx_sampling, client.seq_id); client.seq_id = -1; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 50e344a78..e906a17bf 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -200,7 +200,7 @@ struct llama_server_context llama_model *model = nullptr; llama_context *ctx = nullptr; gpt_params params; - llama_sampling_context sampling_context; + llama_sampling_context ctx_sampling; int n_ctx; grammar_parser::parse_state parsed_grammar; @@ -255,7 +255,7 @@ struct llama_server_context if (grammar != nullptr) { llama_grammar_free(grammar); grammar = nullptr; - sampling_context = llama_sampling_context_init(params, NULL); + ctx_sampling = llama_sampling_context_init(params, NULL); } } @@ -341,7 +341,7 @@ struct llama_server_context grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - sampling_context = llama_sampling_context_init(params, grammar); + ctx_sampling = llama_sampling_context_init(params, grammar); return true; } @@ -542,7 +542,7 @@ struct llama_server_context std::vector candidates; candidates.reserve(llama_n_vocab(model)); - result.tok = llama_sampling_sample(ctx, NULL, sampling_context, last_n_tokens, candidates); + result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates); llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; @@ -1210,7 +1210,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla } } - llama.sampling_context = llama_sampling_context_init(llama.params, llama.grammar); + llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar); LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index d5b19cf56..018dbf9a2 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -125,7 +125,7 @@ int main(int argc, char ** argv) { grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar_tgt); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt); const auto t_dec_start = ggml_time_us(); @@ -136,7 +136,7 @@ int main(int argc, char ** argv) { while (true) { // sample from the target model - llama_token id = llama_sampling_sample(ctx_tgt, NULL, sampling_context, last_tokens, candidates, i_dft); + llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); @@ -215,8 +215,8 @@ int main(int argc, char ** argv) { } // Note: Hardcoded to sequence id 0, if this ever supports parallel generation // that will need to change. - auto it = sampling_context.sequence_contexts.find(0); - GGML_ASSERT(it != sampling_context.sequence_contexts.end()); + auto it = ctx_sampling.sequence_contexts.find(0); + GGML_ASSERT(it != ctx_sampling.sequence_contexts.end()); // This is necessary because each sequence id in sequence_contexts // uses a copy of the original grammar. grammar_dft = llama_grammar_copy(it->second.grammar);