Fixed broken randomization
Thanks to @slaren for explanation
This commit is contained in:
parent
899e0732ee
commit
74f657cc24
1 changed files with 7 additions and 13 deletions
|
@ -1069,9 +1069,8 @@ struct llama_sampler_xtc {
|
|||
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
float chance;
|
||||
|
||||
std::mt19937 rng;
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
||||
|
@ -1079,7 +1078,7 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/
|
|||
}
|
||||
|
||||
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
const auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||
|
||||
if (ctx->probability <= 0.0f
|
||||
|| ctx->threshold <= 0.0f
|
||||
|
@ -1089,8 +1088,10 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
|||
|| cur_p->size <= 2) {
|
||||
return;
|
||||
}
|
||||
// chance is calculated on init and on each reset
|
||||
if (ctx->chance > ctx->probability) return;
|
||||
|
||||
std::uniform_real_distribution<float> distance(0.0f, 1.0f);
|
||||
float chance = distance(ctx->rng);
|
||||
if (chance > ctx->probability) return;
|
||||
|
||||
// in case it's not sorted/recalculated yet
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
@ -1142,9 +1143,6 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
|||
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
|
||||
std::uniform_real_distribution<> distance(0.0, 1.0);
|
||||
ctx->chance = distance(ctx->rng);
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||
|
@ -1158,9 +1156,6 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
|
|||
|
||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
std::uniform_real_distribution<> distance(0.0, 1.0);
|
||||
auto rng = std::mt19937(seed_cur);
|
||||
float chance = distance(rng);
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_xtc_i,
|
||||
/* .ctx = */ new llama_sampler_xtc {
|
||||
|
@ -1170,8 +1165,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz
|
|||
/* .min_keep = */ min_keep,
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .chance = */ chance,
|
||||
/* .rng = */ rng,
|
||||
/* .rng = */ std::mt19937(seed_cur),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue