sampling : one sequence per sampling context

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-12 20:35:01 +03:00
parent 370359e5ba
commit 5261aee8d8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 28 additions and 86 deletions

View file

@ -1,50 +1,14 @@
#include "sampling.h"
llama_sampling_context::~llama_sampling_context() {
for (auto & it : sequence_contexts) {
if (it.second.grammar != NULL) {
llama_grammar_free(it.second.grammar);
it.second.grammar = NULL;
}
}
}
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar) {
llama_sampling_context result;
llama_sampling_context result;
result.params = params.sampling_params;
result.grammar = grammar;
return result;
}
result.params = params.sampling_params;
result.grammar = grammar;
// Note: Creates the context if it doesn't exist, so this always return something.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it != ctx_sampling.sequence_contexts.end()) {
return it->second;
}
llama_sampler_sequence_context new_ctx = {
2.0f * ctx_sampling.params.mirostat_tau,
ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
};
return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
}
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it == ctx_sampling.sequence_contexts.end()) return false;
if (it->second.grammar != NULL) {
llama_grammar_free(it->second.grammar);
it->second.grammar = NULL;
}
ctx_sampling.sequence_contexts.erase(it);
return true;
return result;
}
llama_token llama_sampling_sample(
@ -53,8 +17,7 @@ llama_token llama_sampling_sample(
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx,
llama_seq_id seq) {
const int idx) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
@ -115,10 +78,8 @@ llama_token llama_sampling_sample(
}
}
llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
if (ctx_seq.grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
if (ctx_sampling.grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
}
if (temp <= 0) {
@ -128,10 +89,10 @@ llama_token llama_sampling_sample(
if (mirostat == 1) {
const int mirostat_m = 100;
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
} else if (mirostat == 2) {
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
} else {
// Temperature sampling
size_t min_keep = std::max(1, params.n_probs);
@ -158,8 +119,8 @@ llama_token llama_sampling_sample(
}
}
if (ctx_seq.grammar != NULL) {
llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
if (ctx_sampling.grammar != NULL) {
llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
}
return id;

View file

@ -34,27 +34,14 @@ typedef struct llama_sampling_params {
} llama_sampling_params;
// per-sequence sampler context
typedef struct llama_sampler_sequence_context {
float mirostat_mu; // mirostat sampler state
llama_grammar * grammar;
} llama_sampler_sequence_context;
// general sampler context
typedef struct llama_sampling_context {
~llama_sampling_context();
// parameters that will be used for sampling and when creating
// new llama_sampler_sequence_context instances
// parameters that will be used for sampling
llama_sampling_params params;
// map of sequence ids to sampler contexts
std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
// mirostat sampler state
float mirostat_mu;
// when non-NULL, new instances of llama_sampler_sequence_context
// will get a copy of the grammar here
// note: only the pointer is stored here, it is not a copy of
// the grammar and shouldn't be freed
llama_grammar * grammar;
} llama_sampling_context;
@ -65,13 +52,6 @@ llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar = NULL);
// Fetches the sampler context for the specified sequence id (defaults to 0).
// If the context for that sequence id doesn't already exist, it will be created with
// default values based on the parameters in the ctx_sampling argument.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
// Reset the sampler context for the supplied sequence id (defaults to 0).
// This is necessary to reuse a sequence id or free memory used by sequences
// that are no longer required.
@ -104,5 +84,4 @@ llama_token llama_sampling_sample(
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx = 0,
llama_seq_id seq = 0);
const int idx = 0);