diff --git a/common/sampling.cpp b/common/sampling.cpp
index fd77e7bf6..0c35044e9 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -185,7 +185,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
                     llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                     break;
                 case GPT_SAMPLER_TYPE_XTC:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep));
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep, params.seed));
                     break;
                 case GPT_SAMPLER_TYPE_TFS_Z:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 9858f0a3d..4372b40c3 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1062,10 +1062,16 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
 // xtc
 
 struct llama_sampler_xtc {
-    const float  probability;
-    const float  threshold;
-    const float  threshold_max;
-    const size_t min_keep;
+    const float    probability;
+    const float    threshold;
+    const float    threshold_max;
+    const size_t   min_keep;
+
+    const uint32_t seed;     // seed given to llama_sampler_init_xtc
+    uint32_t       seed_cur; // get_rng_seed(seed), recomputed on reset
+    float          chance;   // drawn from rng on init/reset; apply returns early when chance > probability
+
+    std::mt19937   rng;
 };
 
 static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@@ -1084,10 +1090,8 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
         || ctx->min_keep <= 2) {
         return;
     }
-
-    std::random_device rd;
-    float chance = (float)(rd()%100 - 1)/100;
-    if (chance > ctx->probability) return;
+    // chance is calculated on init and on each reset
+    if (ctx->chance > ctx->probability) return;
 
     // in case it's not sorted/recalculated yet
     llama_sampler_softmax_impl(cur_p);
@@ -1117,23 +1121,51 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
 
 static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
-    return llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep);
+    auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep, ctx->seed);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_xtc *) result->ctx;
+
+        // init draws its own seed_cur/chance from a freshly seeded rng; overwrite them
+        // with the source's values so the clone takes the same decision in apply
+        // (they can differ whenever get_rng_seed() does not return the same value twice)
+        result_ctx->seed_cur = ctx->seed_cur;
+        result_ctx->chance   = ctx->chance;
+        result_ctx->rng      = ctx->rng;
+    }
+
+    return result;
 }
 
 static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
     delete (llama_sampler_xtc *) smpl->ctx;
 }
 
+// re-derive seed_cur from the seed, re-seed the rng and draw a new chance value
+static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+
+    std::uniform_real_distribution<> distance(0.0, 1.0);
+    ctx->chance = distance(ctx->rng);
+}
+
 static struct llama_sampler_i llama_sampler_xtc_i = {
     /* .name   = */ llama_sampler_xtc_name,
     /* .accept = */ nullptr,
     /* .apply  = */ llama_sample_xtc_apply,
-    /* .reset  = */ nullptr,
+    /* .reset  = */ llama_sampler_xtc_reset,
     /* .clone  = */ llama_sampler_xtc_clone,
     /* .free   = */ llama_sampler_xtc_free,
 };
 
-struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep) {
+struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
+    std::uniform_real_distribution<> distance(0.0, 1.0);
+    auto rng = std::mt19937(seed_cur);
+    float chance = distance(rng);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
@@ -1141,6 +1173,10 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz
             /* .threshold     = */ t,
             /* .threshold_max = */ t_max,
             /* .min_keep      = */ min_keep,
+            /* .seed          = */ seed,
+            /* .seed_cur      = */ seed_cur,
+            /* .chance        = */ chance,
+            /* .rng           = */ rng,
         },
     };
 }