Fixed RNG to be reproduceable

Thanks to @slaren for directions
This commit is contained in:
MaggotHATE 2024-10-04 22:38:12 +05:00 committed by GitHub
parent f2a2a618a2
commit 4f8e55b170
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 42 additions and 12 deletions

View file

@ -185,7 +185,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case GPT_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep));
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep, params.seed));
break;
case GPT_SAMPLER_TYPE_TFS_Z:
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));

View file

@ -1062,10 +1062,16 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
// xtc
struct llama_sampler_xtc {
const float probability;
const float threshold;
const float threshold_max;
const size_t min_keep;
const float probability;
const float threshold;
const float threshold_max;
const size_t min_keep;
const uint32_t seed;
uint32_t seed_cur;
float chance;
std::mt19937 rng;
};
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@ -1084,10 +1090,8 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|| ctx->min_keep <= 2) {
return;
}
std::random_device rd;
float chance = (float)(rd()%100 - 1)/100;
if (chance > ctx->probability) return;
// chance is calculated on init and on each reset
if (ctx->chance > ctx->probability) return;
// in case it's not sorted/recalculated yet
llama_sampler_softmax_impl(cur_p);
@ -1117,23 +1121,45 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
return llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep);
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep, ctx->seed);
// copy the state
{
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
result_ctx->rng = ctx->rng;
}
return result;
}
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
delete (llama_sampler_xtc *) smpl->ctx;
}
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
std::uniform_real_distribution<> distance(0.0, 1.0);
ctx->chance = distance(ctx->rng);
}
static struct llama_sampler_i llama_sampler_xtc_i = {
/* .name = */ llama_sampler_xtc_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sample_xtc_apply,
/* .reset = */ nullptr,
/* .reset = */ llama_sampler_xtc_reset,
/* .clone = */ llama_sampler_xtc_clone,
/* .free = */ llama_sampler_xtc_free,
};
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep) {
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
std::uniform_real_distribution<> distance(0.0, 1.0);
auto rng = std::mt19937(seed_cur);
float chance = distance(rng);
return new llama_sampler {
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
@ -1141,6 +1167,10 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz
/* .threshold = */ t,
/* .threshold_max = */ t_max,
/* .min_keep = */ min_keep,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .chance = */ chance,
/* .rng = */ rng,
},
};
}