Added tests and fixed the top-n-sigma sampler implementation
This commit is contained in:
parent
8fb681bf9a
commit
54ef105c85
2 changed files with 28 additions and 13 deletions
|
@ -1655,36 +1655,32 @@ struct llama_sampler_top_n_sigma {
|
|||
// Name callback for the top-n-sigma sampler.
// The sampler argument is ignored; the same constant string "top-n-sigma" is returned on every call.
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "top-n-sigma";
    return sampler_name;
}
|
||||
#include <iostream>
|
||||
|
||||
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
// 1. Find max logit: M
|
||||
// 2. Find standard deviation of logits: sig
|
||||
// 3. Create a mask where m[i] = 1 if the i-th logit >= M - n * sig, else m[i] = 0
|
||||
// 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf
|
||||
// 5. p = softmax(l)
|
||||
|
||||
// find max logit and calculate mean
|
||||
int32_t max = cur_p->data[0].logit;
|
||||
int32_t logits_sum = 0;
|
||||
float max = cur_p->data[0].logit;
|
||||
float logits_sum = 0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if(cur_p->data[i].logit > max){
|
||||
max = cur_p->data[i].logit;
|
||||
}
|
||||
logits_sum += cur_p->data[i].logit;
|
||||
}
|
||||
int32_t mean = logits_sum/cur_p->size;
|
||||
float mean = (float)logits_sum/cur_p->size;
|
||||
|
||||
// calculate standard deviation
|
||||
int32_t acc = 0;
|
||||
float acc = 0;
|
||||
for(size_t i = 0; i < cur_p->size; ++i){
|
||||
acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean);
|
||||
acc += pow(cur_p->data[i].logit - mean, 2);
|
||||
}
|
||||
int32_t std = sqrt(acc/cur_p->size);
|
||||
|
||||
float std = sqrt((float)acc/cur_p->size);
|
||||
|
||||
//apply mask
|
||||
for(size_t i = 0; i < cur_p->size; ++i){
|
||||
if(cur_p->data[i].logit < max - (ctx->n * std)) {
|
||||
if(cur_p->data[i].logit < max - ((float)ctx->n * std)) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue