Fixed broken randomization

Thanks to @slaren for explanation
This commit is contained in:
MaggotHATE 2024-10-04 23:47:19 +05:00 committed by GitHub
parent 899e0732ee
commit 74f657cc24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1069,9 +1069,8 @@ struct llama_sampler_xtc {
const uint32_t seed;
uint32_t seed_cur;
float chance;
std::mt19937 rng;
std::mt19937 rng;
};
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@ -1079,7 +1078,7 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/
}
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_xtc *) smpl->ctx;
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
if (ctx->probability <= 0.0f
|| ctx->threshold <= 0.0f
@ -1089,8 +1088,10 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|| cur_p->size <= 2) {
return;
}
// chance is calculated on init and on each reset
if (ctx->chance > ctx->probability) return;
std::uniform_real_distribution<float> distance(0.0f, 1.0f);
float chance = distance(ctx->rng);
if (chance > ctx->probability) return;
// in case it's not sorted/recalculated yet
llama_sampler_softmax_impl(cur_p);
@ -1142,9 +1143,6 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
std::uniform_real_distribution<> distance(0.0, 1.0);
ctx->chance = distance(ctx->rng);
}
static struct llama_sampler_i llama_sampler_xtc_i = {
@ -1158,9 +1156,6 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
std::uniform_real_distribution<> distance(0.0, 1.0);
auto rng = std::mt19937(seed_cur);
float chance = distance(rng);
return new llama_sampler {
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
@ -1170,8 +1165,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz
/* .min_keep = */ min_keep,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .chance = */ chance,
/* .rng = */ rng,
/* .rng = */ std::mt19937(seed_cur),
},
};
}