Fix mirostat state when using multiple sequences

parent db3abcc114
commit 0e6db6fec1

3 changed files with 31 additions and 10 deletions
common/common.cpp

@@ -940,10 +940,11 @@ llama_token llama_sample_token(
         struct llama_context * ctx,
         struct llama_context * ctx_guidance,
         struct llama_grammar * grammar,
-        const struct gpt_params & params,
+        struct gpt_params & params,
         const std::vector<llama_token> & last_tokens,
         std::vector<llama_token_data> & candidates,
-        int idx) {
+        const int idx,
+        llama_seq_id seq) {
     const int n_ctx = llama_n_ctx(ctx);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
@@ -1011,15 +1012,23 @@ llama_token llama_sample_token(
             // Greedy sampling
             id = llama_sample_token_greedy(ctx, &cur_p);
         } else {
+            float * mirostat_mu = NULL;
+            if (mirostat > 0) {
+                seq = std::max(0, seq); // Deal with people passing -1 or something.
+                auto mu_it = params.sampler_state.find(seq);
+                if (mu_it == params.sampler_state.end()) {
+                    const llama_sampler_state new_state = { 2.0f * mirostat_tau };
+                    mu_it = params.sampler_state.insert({seq, new_state}).first;
+                }
+                mirostat_mu = &mu_it->second.mirostat_mu;
+            }
             if (mirostat == 1) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
                 const int mirostat_m = 100;
                 llama_sample_temp(ctx, &cur_p, temp);
-                id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu);
             } else if (mirostat == 2) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
                 llama_sample_temp(ctx, &cur_p, temp);
-                id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_mu);
             } else {
                 // Temperature sampling
                 size_t min_keep = std::max(1, params.n_probs);
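The hunk above replaces the old function-local static float mirostat_mu (a single value silently shared by every sequence) with a per-sequence entry in params.sampler_state, created on first use with mu = 2 * mirostat_tau. The following is a minimal standalone sketch of that lookup pattern; SamplerState, get_mirostat_mu, and the plain int sequence ids are illustrative stand-ins, not names from the patch.

#include <algorithm>
#include <cstdio>
#include <unordered_map>

// Stand-in for llama_sampler_state from the patch.
struct SamplerState {
    float mirostat_mu; // per-sequence mirostat sampler state
};

// Mirrors the lookup in llama_sample_token: find the state for `seq`,
// creating it with mu = 2 * tau on first use, and return a stable pointer.
static float * get_mirostat_mu(std::unordered_map<int, SamplerState> & state,
                               int seq, float mirostat_tau) {
    seq = std::max(0, seq); // tolerate callers passing -1, as the patch does
    auto it = state.find(seq);
    if (it == state.end()) {
        const SamplerState new_state = { 2.0f * mirostat_tau };
        it = state.insert({seq, new_state}).first;
    }
    // pointers into unordered_map values stay valid across later inserts
    return &it->second.mirostat_mu;
}

int main() {
    std::unordered_map<int, SamplerState> state;
    const float tau = 5.0f;

    float * mu0 = get_mirostat_mu(state, 0, tau);
    *mu0 -= 1.0f; // pretend the sampler updated sequence 0's mu

    float * mu1 = get_mirostat_mu(state, 1, tau); // sequence 1 gets fresh state

    printf("mu0 = %.1f, mu1 = %.1f\n", *mu0, *mu1); // mu0 = 9.0, mu1 = 10.0
    return 0;
}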
common/common.h

@@ -33,6 +33,10 @@
 //
 int32_t get_num_physical_cores();
 
+typedef struct llama_sampler_state {
+    float mirostat_mu; // mirostat sampler state
+} llama_sampler_state;
+
 struct gpt_params {
     uint32_t seed = -1; // RNG seed
     int32_t n_threads = get_num_physical_cores();
@@ -54,6 +58,9 @@ struct gpt_params {
     float rope_freq_base = 0.0f; // RoPE base frequency
     float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
 
+    // per sequence sampler state
+    std::unordered_map<llama_seq_id, llama_sampler_state> sampler_state;
+
     // sampling parameters
     int32_t top_k = 40; // <= 0 to use vocab size
     float top_p = 0.95f; // 1.0 = disabled
@@ -186,6 +193,9 @@ std::string llama_detokenize_bpe(
 
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
+// Note: When using multiple sequences, it is the caller's responsibility to delete
+// the item in params.sampler_state when a sequence ends and samplers that rely on
+// state are being used.
 //
 // required:
 //  - ctx: context to use for sampling
@@ -196,6 +206,7 @@ std::string llama_detokenize_bpe(
 //  - grammar: grammar to use for sampling, ignore if NULL
 //  - last_tokens: needed for repetition penalty, ignore if empty
 //  - idx: sample from llama_get_logits_ith(ctx, idx)
+//  - seq: sequence id to associate sampler state with (currently only used by mirostat)
 //
 // returns:
 //  - token: sampled token
@@ -205,10 +216,11 @@ llama_token llama_sample_token(
         struct llama_context * ctx,
         struct llama_context * ctx_guidance,
         struct llama_grammar * grammar,
-        const struct gpt_params & params,
+        struct gpt_params & params,
         const std::vector<llama_token> & last_tokens,
         std::vector<llama_token_data> & candidates,
-        int idx = 0);
+        const int idx = 0,
+        llama_seq_id seq = 0);
 
 //
 // YAML utils
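The header comment above leaves cleanup to the caller: when a sequence finishes, its entry in params.sampler_state must be erased, or a later sequence reusing the same id would inherit stale mirostat state. A small sketch of that contract, again with stand-in types rather than the real llama.cpp API:

#include <cstdio>
#include <unordered_map>

struct SamplerState {
    float mirostat_mu;
};

int main() {
    // stand-in for gpt_params::sampler_state
    std::unordered_map<int, SamplerState> sampler_state;

    // while sequence 7 is alive, sampling lazily creates its entry
    // (mu starts at 2 * tau; tau = 5 here)
    sampler_state.insert({7, SamplerState{ 2.0f * 5.0f }});

    // when the sequence ends, the caller drops the entry, mirroring
    // params.sampler_state.erase(client.seq_id) in parallel.cpp below;
    // skipping this would hand stale mu to a new sequence reusing id 7
    sampler_state.erase(7);

    printf("entries after cleanup: %zu\n", sampler_state.size()); // prints 0
    return 0;
}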
examples/parallel/parallel.cpp

@@ -339,7 +339,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
+            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -384,7 +384,7 @@ int main(int argc, char ** argv) {
 
                 n_total_prompt += client.n_prompt;
                 n_total_gen    += client.n_decoded;
-
+                params.sampler_state.erase(client.seq_id);
                 client.seq_id = -1;
             }
 