Fixed broken randomization

Thanks to @slaren for explanation
This commit is contained in:
MaggotHATE 2024-10-04 23:47:19 +05:00 committed by GitHub
parent 899e0732ee
commit 74f657cc24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1069,9 +1069,8 @@ struct llama_sampler_xtc {
const uint32_t seed; const uint32_t seed;
uint32_t seed_cur; uint32_t seed_cur;
float chance;
std::mt19937 rng; std::mt19937 rng;
}; };
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@ -1079,7 +1078,7 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/
} }
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_xtc *) smpl->ctx; auto * ctx = (llama_sampler_xtc *) smpl->ctx;
if (ctx->probability <= 0.0f if (ctx->probability <= 0.0f
|| ctx->threshold <= 0.0f || ctx->threshold <= 0.0f
@ -1089,8 +1088,10 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|| cur_p->size <= 2) { || cur_p->size <= 2) {
return; return;
} }
// chance is calculated on init and on each reset
if (ctx->chance > ctx->probability) return; std::uniform_real_distribution<float> distance(0.0f, 1.0f);
float chance = distance(ctx->rng);
if (chance > ctx->probability) return;
// in case it's not sorted/recalculated yet // in case it's not sorted/recalculated yet
llama_sampler_softmax_impl(cur_p); llama_sampler_softmax_impl(cur_p);
@ -1142,9 +1143,6 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_xtc *) smpl->ctx; auto * ctx = (llama_sampler_xtc *) smpl->ctx;
ctx->seed_cur = get_rng_seed(ctx->seed); ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur); ctx->rng.seed(ctx->seed_cur);
std::uniform_real_distribution<> distance(0.0, 1.0);
ctx->chance = distance(ctx->rng);
} }
static struct llama_sampler_i llama_sampler_xtc_i = { static struct llama_sampler_i llama_sampler_xtc_i = {
@ -1158,9 +1156,6 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) { struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
std::uniform_real_distribution<> distance(0.0, 1.0);
auto rng = std::mt19937(seed_cur);
float chance = distance(rng);
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_xtc_i, /* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc { /* .ctx = */ new llama_sampler_xtc {
@ -1170,8 +1165,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz
/* .min_keep = */ min_keep, /* .min_keep = */ min_keep,
/* .seed = */ seed, /* .seed = */ seed,
/* .seed_cur = */ seed_cur, /* .seed_cur = */ seed_cur,
/* .chance = */ chance, /* .rng = */ std::mt19937(seed_cur),
/* .rng = */ rng,
}, },
}; };
} }