diff --git a/llama.cpp b/llama.cpp index 582e82260..920689d3f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8001,11 +8001,71 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; - if (k == (int) candidates->size) { - std::sort(candidates->data, candidates->data + candidates->size, comp); - } else { - std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + constexpr int nbuckets = 100; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + + std::vector tmp_tokens(candidates->size); + std::vector bucket_idx(candidates->size); + std::vector histo(nbuckets, 0); + std::vector bucket_ptrs; + bucket_ptrs.reserve(nbuckets); + + for (int i = 0; i < (int)candidates->size; ++i) { + const float val = candidates->data[i].logit; + int ib = nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(0, std::min(nbuckets-1, ib)); + bucket_idx[i] = ib; + ++histo[ib]; } + auto ptr = tmp_tokens.data(); + int nhave = 0; + int ib = nbuckets - 1; + for ( ; ib >= 0; --ib) { + nhave += histo[ib]; + bucket_ptrs.push_back(ptr); + if (nhave >= k) break; + ptr += histo[ib]; + } + for (int i = 0; i < (int)candidates->size; ++i) { + int j = bucket_idx[i]; + if (j >= ib) { + *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; + } + } + + ptr = tmp_tokens.data(); + int ndone = 0; + for (int j = nbuckets-1; j > ib; --j) { + std::sort(ptr, ptr + histo[j], comp); + ptr += histo[j]; + ndone += histo[j]; + } + std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); + + std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); + + //std::vector buckets[nbuckets]; + + //for (size_t i = 0; i < candidates->size; ++i) { + // const float val = candidates->data[i].logit; + // int ib = nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + // ib = std::max(ib, 0); + // ib = std::min(ib, nbuckets-1); + // buckets[ib].push_back(candidates->data[i]); + //} + + //int nsorted = 0; + //for (int ib = nbuckets-1; ib >= 0; --ib) { + // std::sort(buckets[ib].begin(), buckets[ib].end(), comp); + // memcpy(candidates->data + nsorted, buckets[ib].data(), buckets[ib].size()*sizeof(llama_token_data)); + + // nsorted += buckets[ib].size(); + + // if (nsorted >= k) { + // break; + // } + //} candidates->sorted = true; } candidates->size = k;