llama : move random seed generation to the samplers (#9398)
* llama_sampler_penalties : clamp penalty_last_n to zero
This commit is contained in:
parent
00ba2ff781
commit
49006c67b4
10 changed files with 92 additions and 34 deletions
|
@ -8,6 +8,7 @@
|
|||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <cfloat>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
|
@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||
cur_p->size = k;
|
||||
}
|
||||
|
||||
static uint32_t get_rng_seed(uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
// use system clock if std::random_device is not a true RNG
|
||||
static bool is_rd_prng = std::random_device().entropy() == 0;
|
||||
if (is_rd_prng) {
|
||||
return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
|
||||
}
|
||||
std::random_device rd;
|
||||
return rd();
|
||||
}
|
||||
return seed;
|
||||
}
|
||||
|
||||
// llama_sampler API
|
||||
|
||||
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
||||
|
@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
|
|||
|
||||
struct llama_sampler_dist {
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
|
|||
|
||||
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||
ctx->rng = std::mt19937(ctx->seed);
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
||||
|
@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
|
|||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_dist_i,
|
||||
/* .ctx = */ new llama_sampler_dist {
|
||||
/* .seed = */ seed,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .rng = */ std::mt19937(seed_cur),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
|
|||
const int32_t n_vocab;
|
||||
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
|
||||
const float tau;
|
||||
const float eta;
|
||||
|
@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
|
|||
static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
|
||||
ctx->mu = 2.0f*ctx->tau;
|
||||
ctx->rng = std::mt19937(ctx->seed);
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
}
|
||||
|
||||
static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
||||
|
@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
|
|||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_mirostat_i,
|
||||
/* .ctx = */ new llama_sampler_mirostat {
|
||||
/* .n_vocab = */ n_vocab,
|
||||
/* .seed = */ seed,
|
||||
/* .tau = */ tau,
|
||||
/* .eta = */ eta,
|
||||
/* .m = */ m,
|
||||
/* .mu = */ 2.0f*tau,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .n_vocab = */ n_vocab,
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .tau = */ tau,
|
||||
/* .eta = */ eta,
|
||||
/* .m = */ m,
|
||||
/* .mu = */ 2.0f*tau,
|
||||
/* .rng = */ std::mt19937(seed_cur),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
|
|||
|
||||
struct llama_sampler_mirostat_v2 {
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
|
||||
const float tau;
|
||||
const float eta;
|
||||
|
@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
|
|||
static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
|
||||
ctx->mu = 2.0f*ctx->tau;
|
||||
ctx->rng = std::mt19937(ctx->seed);
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
|
||||
|
@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
|||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
||||
/* .ctx = */ new llama_sampler_mirostat_v2 {
|
||||
/* .seed = */ seed,
|
||||
/* .tau = */ tau,
|
||||
/* .eta = */ eta,
|
||||
/* .mu = */ 2.0f*tau,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .tau = */ tau,
|
||||
/* .eta = */ eta,
|
||||
/* .mu = */ 2.0f*tau,
|
||||
/* .rng = */ std::mt19937(seed_cur),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|||
ignore_eos = false;
|
||||
}
|
||||
|
||||
penalty_last_n = std::max(penalty_last_n, 0);
|
||||
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_penalties_i,
|
||||
/* .ctx = */ new llama_sampler_penalties {
|
||||
|
@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
|
||||
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
|
||||
|
@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
|||
},
|
||||
};
|
||||
}
|
||||
|
||||
// utils
|
||||
|
||||
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
||||
if (smpl->iface == &llama_sampler_dist_i) {
|
||||
return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
|
||||
}
|
||||
|
||||
if (smpl->iface == &llama_sampler_mirostat_i) {
|
||||
return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
|
||||
}
|
||||
|
||||
if (smpl->iface == &llama_sampler_mirostat_v2_i) {
|
||||
return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
|
||||
}
|
||||
|
||||
if (smpl->iface == &llama_sampler_chain_i) {
|
||||
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
||||
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
||||
const uint32_t seed = llama_sampler_get_seed(*it);
|
||||
if (seed != LLAMA_DEFAULT_SEED) {
|
||||
return seed;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return LLAMA_DEFAULT_SEED;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue