Fixed RNG to be reproduceable
Thanks to @slaren for directions
This commit is contained in:
parent
f2a2a618a2
commit
4f8e55b170
2 changed files with 42 additions and 12 deletions
|
@ -185,7 +185,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case GPT_SAMPLER_TYPE_XTC:
|
case GPT_SAMPLER_TYPE_XTC:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep, params.seed));
|
||||||
break;
|
break;
|
||||||
case GPT_SAMPLER_TYPE_TFS_Z:
|
case GPT_SAMPLER_TYPE_TFS_Z:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
||||||
|
|
|
@ -1062,10 +1062,16 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
||||||
// xtc
|
// xtc
|
||||||
|
|
||||||
struct llama_sampler_xtc {
|
struct llama_sampler_xtc {
|
||||||
const float probability;
|
const float probability;
|
||||||
const float threshold;
|
const float threshold;
|
||||||
const float threshold_max;
|
const float threshold_max;
|
||||||
const size_t min_keep;
|
const size_t min_keep;
|
||||||
|
|
||||||
|
const uint32_t seed;
|
||||||
|
uint32_t seed_cur;
|
||||||
|
float chance;
|
||||||
|
|
||||||
|
std::mt19937 rng;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
||||||
|
@ -1084,10 +1090,8 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
||||||
|| ctx->min_keep <= 2) {
|
|| ctx->min_keep <= 2) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// chance is calculated on init and on each reset
|
||||||
std::random_device rd;
|
if (ctx->chance > ctx->probability) return;
|
||||||
float chance = (float)(rd()%100 - 1)/100;
|
|
||||||
if (chance > ctx->probability) return;
|
|
||||||
|
|
||||||
// in case it's not sorted/recalculated yet
|
// in case it's not sorted/recalculated yet
|
||||||
llama_sampler_softmax_impl(cur_p);
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
@ -1117,23 +1121,45 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
||||||
|
|
||||||
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
||||||
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
||||||
return llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep);
|
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep, ctx->seed);
|
||||||
|
|
||||||
|
// copy the state
|
||||||
|
{
|
||||||
|
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
|
||||||
|
|
||||||
|
result_ctx->rng = ctx->rng;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_xtc *) smpl->ctx;
|
delete (llama_sampler_xtc *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
||||||
|
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||||
|
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||||
|
ctx->rng.seed(ctx->seed_cur);
|
||||||
|
|
||||||
|
std::uniform_real_distribution<> distance(0.0, 1.0);
|
||||||
|
ctx->chance = distance(ctx->rng);
|
||||||
|
}
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_xtc_i = {
|
static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||||
/* .name = */ llama_sampler_xtc_name,
|
/* .name = */ llama_sampler_xtc_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
/* .apply = */ llama_sample_xtc_apply,
|
/* .apply = */ llama_sample_xtc_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ llama_sampler_xtc_reset,
|
||||||
/* .clone = */ llama_sampler_xtc_clone,
|
/* .clone = */ llama_sampler_xtc_clone,
|
||||||
/* .free = */ llama_sampler_xtc_free,
|
/* .free = */ llama_sampler_xtc_free,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep) {
|
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
|
||||||
|
auto seed_cur = get_rng_seed(seed);
|
||||||
|
std::uniform_real_distribution<> distance(0.0, 1.0);
|
||||||
|
auto rng = std::mt19937(seed_cur);
|
||||||
|
float chance = distance(rng);
|
||||||
return new llama_sampler {
|
return new llama_sampler {
|
||||||
/* .iface = */ &llama_sampler_xtc_i,
|
/* .iface = */ &llama_sampler_xtc_i,
|
||||||
/* .ctx = */ new llama_sampler_xtc {
|
/* .ctx = */ new llama_sampler_xtc {
|
||||||
|
@ -1141,6 +1167,10 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, siz
|
||||||
/* .threshold = */ t,
|
/* .threshold = */ t,
|
||||||
/* .threshold_max = */ t_max,
|
/* .threshold_max = */ t_max,
|
||||||
/* .min_keep = */ min_keep,
|
/* .min_keep = */ min_keep,
|
||||||
|
/* .seed = */ seed,
|
||||||
|
/* .seed_cur = */ seed_cur,
|
||||||
|
/* .chance = */ chance,
|
||||||
|
/* .rng = */ rng,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue