Fixed broken randomization
Thanks to @slaren for explanation
This commit is contained in:
parent
899e0732ee
commit
74f657cc24
1 changed files with 7 additions and 13 deletions
|
@ -1069,9 +1069,8 @@ struct llama_sampler_xtc {
|
||||||
|
|
||||||
const uint32_t seed;
|
const uint32_t seed;
|
||||||
uint32_t seed_cur;
|
uint32_t seed_cur;
|
||||||
float chance;
|
|
||||||
|
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
||||||
|
@ -1079,7 +1078,7 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
const auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||||
|
|
||||||
if (ctx->probability <= 0.0f
|
if (ctx->probability <= 0.0f
|
||||||
|| ctx->threshold <= 0.0f
|
|| ctx->threshold <= 0.0f
|
||||||
|
@ -1089,8 +1088,10 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
||||||
|| cur_p->size <= 2) {
|
|| cur_p->size <= 2) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// chance is calculated on init and on each reset
|
|
||||||
if (ctx->chance > ctx->probability) return;
|
std::uniform_real_distribution<float> distance(0.0f, 1.0f);
|
||||||
|
float chance = distance(ctx->rng);
|
||||||
|
if (chance > ctx->probability) return;
|
||||||
|
|
||||||
// in case it's not sorted/recalculated yet
|
// in case it's not sorted/recalculated yet
|
||||||
llama_sampler_softmax_impl(cur_p);
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
@ -1142,9 +1143,6 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
||||||
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||||
ctx->rng.seed(ctx->seed_cur);
|
ctx->rng.seed(ctx->seed_cur);
|
||||||
|
|
||||||
std::uniform_real_distribution<> distance(0.0, 1.0);
|
|
||||||
ctx->chance = distance(ctx->rng);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_xtc_i = {
|
static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||||
|
@ -1158,9 +1156,6 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
|
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
|
||||||
auto seed_cur = get_rng_seed(seed);
|
auto seed_cur = get_rng_seed(seed);
|
||||||
std::uniform_real_distribution<> distance(0.0, 1.0);
|
|
||||||
auto rng = std::mt19937(seed_cur);
|
|
||||||
float chance = distance(rng);
|
|
||||||
return new llama_sampler {
|
return new llama_sampler {
|
||||||
/* .iface = */ &llama_sampler_xtc_i,
|
/* .iface = */ &llama_sampler_xtc_i,
|
||||||
/* .ctx = */ new llama_sampler_xtc {
|
/* .ctx = */ new llama_sampler_xtc {
|
||||||
|
@ -1170,8 +1165,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz
|
||||||
/* .min_keep = */ min_keep,
|
/* .min_keep = */ min_keep,
|
||||||
/* .seed = */ seed,
|
/* .seed = */ seed,
|
||||||
/* .seed_cur = */ seed_cur,
|
/* .seed_cur = */ seed_cur,
|
||||||
/* .chance = */ chance,
|
/* .rng = */ std::mt19937(seed_cur),
|
||||||
/* .rng = */ rng,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue