Fixed RNG to be reproduceable
Thanks to @slaren for directions
This commit is contained in:
parent
f2a2a618a2
commit
4f8e55b170
2 changed files with 42 additions and 12 deletions
|
@ -185,7 +185,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
|||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case GPT_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep, params.seed));
|
||||
break;
|
||||
case GPT_SAMPLER_TYPE_TFS_Z:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
||||
|
|
|
@ -1062,10 +1062,16 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
|||
// xtc
|
||||
|
||||
struct llama_sampler_xtc {
|
||||
const float probability;
|
||||
const float threshold;
|
||||
const float threshold_max;
|
||||
const size_t min_keep;
|
||||
const float probability;
|
||||
const float threshold;
|
||||
const float threshold_max;
|
||||
const size_t min_keep;
|
||||
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
float chance;
|
||||
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
||||
|
@ -1084,10 +1090,8 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
|||
|| ctx->min_keep <= 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
float chance = (float)(rd()%100 - 1)/100;
|
||||
if (chance > ctx->probability) return;
|
||||
// chance is calculated on init and on each reset
|
||||
if (ctx->chance > ctx->probability) return;
|
||||
|
||||
// in case it's not sorted/recalculated yet
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
@ -1117,23 +1121,45 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
|||
|
||||
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
||||
return llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep);
|
||||
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep, ctx->seed);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
|
||||
|
||||
result_ctx->rng = ctx->rng;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_xtc *) smpl->ctx;
|
||||
}
|
||||
|
||||
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
|
||||
std::uniform_real_distribution<> distance(0.0, 1.0);
|
||||
ctx->chance = distance(ctx->rng);
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||
/* .name = */ llama_sampler_xtc_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sample_xtc_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .reset = */ llama_sampler_xtc_reset,
|
||||
/* .clone = */ llama_sampler_xtc_clone,
|
||||
/* .free = */ llama_sampler_xtc_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep) {
|
||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
std::uniform_real_distribution<> distance(0.0, 1.0);
|
||||
auto rng = std::mt19937(seed_cur);
|
||||
float chance = distance(rng);
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_xtc_i,
|
||||
/* .ctx = */ new llama_sampler_xtc {
|
||||
|
@ -1141,6 +1167,10 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz
|
|||
/* .threshold = */ t,
|
||||
/* .threshold_max = */ t_max,
|
||||
/* .min_keep = */ min_keep,
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .chance = */ chance,
|
||||
/* .rng = */ rng,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue