Fix mirostat state when using multiple sequences

This commit is contained in:
KerfuffleV2 2023-10-08 03:52:45 -06:00
parent db3abcc114
commit 0e6db6fec1
3 changed files with 31 additions and 10 deletions

View file

@@ -940,10 +940,11 @@ llama_token llama_sample_token(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_context * ctx_guidance, struct llama_context * ctx_guidance,
struct llama_grammar * grammar, struct llama_grammar * grammar,
const struct gpt_params & params, struct gpt_params & params,
const std::vector<llama_token> & last_tokens, const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates, std::vector<llama_token_data> & candidates,
int idx) { const int idx,
llama_seq_id seq) {
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -1011,15 +1012,23 @@ llama_token llama_sample_token(
// Greedy sampling // Greedy sampling
id = llama_sample_token_greedy(ctx, &cur_p); id = llama_sample_token_greedy(ctx, &cur_p);
} else { } else {
float * mirostat_mu = NULL;
if (mirostat > 0) {
seq = std::max(0, seq); // Deal with people passing -1 or something.
auto mu_it = params.sampler_state.find(seq);
if (mu_it == params.sampler_state.end()) {
const llama_sampler_state new_state = { 2.0f * mirostat_tau };
mu_it = params.sampler_state.insert({seq, new_state}).first;
}
mirostat_mu = &mu_it->second.mirostat_mu;
}
if (mirostat == 1) { if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100; const int mirostat_m = 100;
llama_sample_temp(ctx, &cur_p, temp); llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu);
} else if (mirostat == 2) { } else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temp(ctx, &cur_p, temp); llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_mu);
} else { } else {
// Temperature sampling // Temperature sampling
size_t min_keep = std::max(1, params.n_probs); size_t min_keep = std::max(1, params.n_probs);

View file

@@ -33,6 +33,10 @@
// //
int32_t get_num_physical_cores(); int32_t get_num_physical_cores();
typedef struct llama_sampler_state {
float mirostat_mu; // mirostat sampler state
} llama_sampler_state;
struct gpt_params { struct gpt_params {
uint32_t seed = -1; // RNG seed uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores(); int32_t n_threads = get_num_physical_cores();
@@ -54,6 +58,9 @@ struct gpt_params {
float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
// per sequence sampler state
std::unordered_map<llama_seq_id, llama_sampler_state> sampler_state;
// sampling parameters // sampling parameters
int32_t top_k = 40; // <= 0 to use vocab size int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled float top_p = 0.95f; // 1.0 = disabled
@@ -186,6 +193,9 @@ std::string llama_detokenize_bpe(
// this is a common sampling function used across the examples for convenience // this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function // it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to delete
// the item in params.sampler_state when a sequence ends and samplers that rely on
// state are being used.
// //
// required: // required:
// - ctx: context to use for sampling // - ctx: context to use for sampling
@@ -196,6 +206,7 @@ std::string llama_detokenize_bpe(
// - grammar: grammar to use for sampling, ignore if NULL // - grammar: grammar to use for sampling, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty // - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits_ith(ctx, idx) // - idx: sample from llama_get_logits_ith(ctx, idx)
// - seq: sequence id to associate sampler state with (currently only used by mirostat)
// //
// returns: // returns:
// - token: sampled token // - token: sampled token
@@ -205,10 +216,11 @@ llama_token llama_sample_token(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_context * ctx_guidance, struct llama_context * ctx_guidance,
struct llama_grammar * grammar, struct llama_grammar * grammar,
const struct gpt_params & params, struct gpt_params & params,
const std::vector<llama_token> & last_tokens, const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates, std::vector<llama_token_data> & candidates,
int idx = 0); const int idx = 0,
llama_seq_id seq = 0);
// //
// YAML utils // YAML utils

View file

@@ -339,7 +339,7 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n", //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i); const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
if (client.n_decoded == 1) { if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients // start measuring generation time after the first token to make sure all concurrent clients
@@ -384,7 +384,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt; n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded; n_total_gen += client.n_decoded;
params.sampler_state.erase(client.seq_id);
client.seq_id = -1; client.seq_id = -1;
} }