diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index d7a18e70e..8349b3162 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1069,9 +1069,8 @@ struct llama_sampler_xtc { const uint32_t seed; uint32_t seed_cur; - float chance; - std::mt19937 rng; + std::mt19937 rng; }; static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { @@ -1079,7 +1078,7 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/ } static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_xtc *) smpl->ctx; + auto * ctx = (llama_sampler_xtc *) smpl->ctx; if (ctx->probability <= 0.0f || ctx->threshold <= 0.0f @@ -1089,8 +1088,10 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data || cur_p->size <= 2) { return; } - // chance is calculated on init and on each reset - if (ctx->chance > ctx->probability) return; + + std::uniform_real_distribution<float> distance(0.0f, 1.0f); + float chance = distance(ctx->rng); + if (chance > ctx->probability) return; // in case it's not sorted/recalculated yet llama_sampler_softmax_impl(cur_p); @@ -1142,9 +1143,6 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_xtc *) smpl->ctx; ctx->seed_cur = get_rng_seed(ctx->seed); ctx->rng.seed(ctx->seed_cur); - - std::uniform_real_distribution<> distance(0.0, 1.0); - ctx->chance = distance(ctx->rng); } static struct llama_sampler_i llama_sampler_xtc_i = { @@ -1158,9 +1156,6 @@ static struct llama_sampler_i llama_sampler_xtc_i = { struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) { auto seed_cur = get_rng_seed(seed); - std::uniform_real_distribution<> distance(0.0, 1.0); - auto rng = std::mt19937(seed_cur); - float chance = distance(rng); return new llama_sampler { /* .iface = */ &llama_sampler_xtc_i, /* .ctx = */ new llama_sampler_xtc { @@ 
-1170,8 +1165,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz /* .min_keep = */ min_keep, /* .seed = */ seed, /* .seed_cur = */ seed_cur, - /* .chance = */ chance, - /* .rng = */ rng, + /* .rng = */ std::mt19937(seed_cur), }, }; }