diff --git a/llama.cpp b/llama.cpp index 920689d3f..3fef34d18 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8001,71 +8001,56 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; - constexpr int nbuckets = 100; - constexpr float bucket_low = -10.0f; - constexpr float bucket_high = 10.0f; + if (k < 200) { + std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + } else { + constexpr int nbuckets = 128; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); + constexpr float bucker_inter = -bucket_low * bucket_scale; - std::vector tmp_tokens(candidates->size); - std::vector bucket_idx(candidates->size); - std::vector histo(nbuckets, 0); - std::vector bucket_ptrs; - bucket_ptrs.reserve(nbuckets); + std::vector tmp_tokens(candidates->size); + std::vector bucket_idx(candidates->size); + std::vector histo(nbuckets, 0); + std::vector bucket_ptrs; + bucket_ptrs.reserve(nbuckets); - for (int i = 0; i < (int)candidates->size; ++i) { - const float val = candidates->data[i].logit; - int ib = nbuckets * (val - bucket_low) / (bucket_high - bucket_low); - ib = std::max(0, std::min(nbuckets-1, ib)); - bucket_idx[i] = ib; - ++histo[ib]; - } - auto ptr = tmp_tokens.data(); - int nhave = 0; - int ib = nbuckets - 1; - for ( ; ib >= 0; --ib) { - nhave += histo[ib]; - bucket_ptrs.push_back(ptr); - if (nhave >= k) break; - ptr += histo[ib]; - } - for (int i = 0; i < (int)candidates->size; ++i) { - int j = bucket_idx[i]; - if (j >= ib) { - *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; + for (int i = 0; i < (int)candidates->size; ++i) { + const float val = candidates->data[i].logit; + int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(0, std::min(nbuckets-1, ib)); + bucket_idx[i] = ib; + ++histo[ib]; } + auto ptr = tmp_tokens.data(); + int nhave = 0; + int ib = nbuckets - 1; + for ( ; ib >= 0; --ib) { + nhave += histo[ib]; + bucket_ptrs.push_back(ptr); + if (nhave >= k) break; + ptr += histo[ib]; + } + for (int i = 0; i < (int)candidates->size; ++i) { + int j = bucket_idx[i]; + if (j >= ib) { + *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; + } + } + + ptr = tmp_tokens.data(); + int ndone = 0; + for (int j = nbuckets-1; j > ib; --j) { + std::sort(ptr, ptr + histo[j], comp); + ptr += histo[j]; + ndone += histo[j]; + } + std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); + + std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); + } - - ptr = tmp_tokens.data(); - int ndone = 0; - for (int j = nbuckets-1; j > ib; --j) { - std::sort(ptr, ptr + histo[j], comp); - ptr += histo[j]; - ndone += histo[j]; - } - std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); - - std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); - - //std::vector buckets[nbuckets]; - - //for (size_t i = 0; i < candidates->size; ++i) { - // const float val = candidates->data[i].logit; - // int ib = nbuckets * (val - bucket_low) / (bucket_high - bucket_low); - // ib = std::max(ib, 0); - // ib = std::min(ib, nbuckets-1); - // buckets[ib].push_back(candidates->data[i]); - //} - - //int nsorted = 0; - //for (int ib = nbuckets-1; ib >= 0; --ib) { - // std::sort(buckets[ib].begin(), buckets[ib].end(), comp); - // memcpy(candidates->data + nsorted, buckets[ib].data(), buckets[ib].size()*sizeof(llama_token_data)); - - // nsorted += buckets[ib].size(); - - // if (nsorted >= k) { - // break; - // } - //} candidates->sorted = true; } candidates->size = k;