completed top nsigma sampler implementation
This commit is contained in:
parent
ddc3c2208a
commit
da038d8715
5 changed files with 112 additions and 79 deletions
|
@ -301,6 +301,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||
cur_p->size = k;
|
||||
}
|
||||
|
||||
|
||||
static uint32_t get_rng_seed(uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
// use system clock if std::random_device is not a true RNG
|
||||
|
@ -1657,35 +1658,65 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
|
|||
|
||||
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
llama_sampler_top_n_sigma_impl(cur_p, ctx->n);
|
||||
// 1. Find max logit: M
|
||||
// 2. Find standard deviation of logits: sig
|
||||
// 3. Create a mask where m[i] = 1 if ith logit >= M - n (sig), else m[i] = 0
|
||||
// 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf
|
||||
// 5. p = softmax(l)
|
||||
|
||||
// find max logit and calculate mean
|
||||
int32_t max = cur_p->data[0].logit;
|
||||
int32_t logits_sum = 0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if(cur_p->data[i].logit > max){
|
||||
max = cur_p->data[i].logit;
|
||||
}
|
||||
logits_sum += cur_p->data[i].logit;
|
||||
}
|
||||
int32_t mean = logits_sum/cur_p->size;
|
||||
|
||||
// calculate standard deviation
|
||||
int32_t acc = 0;
|
||||
for(size_t i = 0; i < cur_p->size; ++i){
|
||||
acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean);
|
||||
}
|
||||
int32_t std = sqrt(acc/cur_p->size);
|
||||
|
||||
//apply mask
|
||||
for(size_t i = 0; i < cur_p->size; ++i){
|
||||
if(cur_p->data[i].logit < max - (ctx->n * std)) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
}
|
||||
|
||||
// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
|
||||
// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
|
||||
// return llama_sampler_init_top_k(ctx->k);
|
||||
// }
|
||||
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl){
|
||||
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
return llama_sampler_init_top_n_sigma(ctx->n);
|
||||
}
|
||||
|
||||
// static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
||||
// delete (llama_sampler_top_k *) smpl->ctx;
|
||||
// }
|
||||
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
}
|
||||
|
||||
// static struct llama_sampler_i llama_sampler_top_k_i = {
|
||||
// /* .name = */ llama_sampler_top_k_name,
|
||||
// /* .accept = */ nullptr,
|
||||
// /* .apply = */ llama_sampler_top_k_apply,
|
||||
// /* .reset = */ nullptr,
|
||||
// /* .clone = */ llama_sampler_top_k_clone,
|
||||
// /* .free = */ llama_sampler_top_k_free,
|
||||
// };
|
||||
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
||||
/* .name = */ llama_sampler_top_n_sigma_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
||||
/* .free = */ llama_sampler_top_n_sigma_free,
|
||||
};
|
||||
|
||||
// struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
||||
// return new llama_sampler {
|
||||
// /* .iface = */ &llama_sampler_top_k_i,
|
||||
// /* .ctx = */ new llama_sampler_top_k {
|
||||
// /* .k = */ k,
|
||||
// },
|
||||
// };
|
||||
// }
|
||||
struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
||||
/* .ctx = */ new llama_sampler_top_n_sigma {
|
||||
/* .n = */ n,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// DRY
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue