perplexity : clean-up

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-01-18 13:51:02 +02:00
parent baa5279d02
commit 4351c4632d

View file

@ -591,23 +591,27 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
GGML_ASSERT(params.n_batch >= n_ctx && "HellaSwag currently requires n_batch >= n_ctx");
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch; const int max_seq = 4*max_tasks_per_batch;
llama_batch batch = llama_batch_init(n_batch, 0, max_seq); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
std::vector<std::vector<int>> ending_tokens(4);
std::vector<float> tok_logits(n_vocab); std::vector<float> tok_logits(n_vocab);
for (size_t i0 = 0; i0 < hs_task_count; i0++) { for (size_t i0 = 0; i0 < hs_task_count; i0++) {
int n_cur = 0; int n_cur = 0;
size_t i1 = i0; size_t i1 = i0;
size_t i_batch = 0; size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
llama_batch_clear(batch); llama_batch_clear(batch);
// batch as much tasks as possible into the available context // batch as much tasks as possible into the available context
// each task has 4 unique sequence ids
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and all ending tokens of each sequence
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) { while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
auto & hs_cur = hs_data[i1]; auto & hs_cur = hs_data[i1];
@ -649,6 +653,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
return; return;
} }
// evaluate the computed tasks
for (size_t i = i0; i < i1; ++i) { for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i]; auto & hs_cur = hs_data[i];