perplexity : remove HellaSwag restruction for n_batch
This commit is contained in:
parent
64d173bc9c
commit
9df62c25f7
1 changed files with 32 additions and 5 deletions
|
@ -564,8 +564,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
GGML_ASSERT(params.n_batch >= n_ctx && "HellaSwag currently requires n_batch >= n_ctx");
|
||||
const int n_batch = params.n_batch;
|
||||
|
||||
const int max_tasks_per_batch = params.n_parallel;
|
||||
const int max_seq = 4*max_tasks_per_batch;
|
||||
|
@ -573,6 +572,34 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
|
||||
|
||||
std::vector<float> tok_logits(n_vocab);
|
||||
std::vector<float> batch_logits(n_ctx*n_vocab);
|
||||
|
||||
auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
if (ret != 0) {
|
||||
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
|
||||
return false;
|
||||
}
|
||||
|
||||
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
for (size_t i0 = 0; i0 < hs_task_count; i0++) {
|
||||
int n_cur = 0;
|
||||
|
@ -622,7 +649,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// decode all tasks [i0, i1)
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
if (!decode_helper(ctx, batch, n_batch)) {
|
||||
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
@ -631,7 +658,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
for (size_t i = i0; i < i1; ++i) {
|
||||
auto & hs_cur = hs_data[i];
|
||||
|
||||
std::memcpy(tok_logits.data(), llama_get_logits_ith(ctx, hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
|
||||
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
|
||||
|
||||
const auto first_probs = softmax(tok_logits);
|
||||
|
||||
|
@ -643,7 +670,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
|
||||
// Calculate the logprobs over the ending
|
||||
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
|
||||
std::memcpy(tok_logits.data(), llama_get_logits_ith(ctx, hs_cur.i_batch + li++), n_vocab*sizeof(float));
|
||||
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
|
||||
|
||||
const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue