perplexity : no need for decode_helper
ggml-ci
This commit is contained in:
parent
4351c4632d
commit
af309010ab
1 changed files with 1 additions and 29 deletions
|
@ -442,33 +442,6 @@ static std::vector<float> hellaswag_evaluate_tokens(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// decode in batches of ctx_params.n_batch tokens
|
|
||||||
static bool decode_helper(llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
|
||||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
|
||||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
|
||||||
|
|
||||||
llama_batch batch_view = {
|
|
||||||
n_tokens,
|
|
||||||
batch.token + i,
|
|
||||||
nullptr,
|
|
||||||
batch.pos + i,
|
|
||||||
batch.n_seq_id + i,
|
|
||||||
batch.seq_id + i,
|
|
||||||
batch.logits + i,
|
|
||||||
0, 0, 0, // unused
|
|
||||||
};
|
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
|
||||||
if (ret != 0) {
|
|
||||||
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
// Calculates hellaswag score (acc_norm) from prompt
|
// Calculates hellaswag score (acc_norm) from prompt
|
||||||
//
|
//
|
||||||
|
@ -589,7 +562,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
const int n_ctx = llama_n_ctx(ctx);
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
const int n_batch = params.n_batch;
|
|
||||||
|
|
||||||
GGML_ASSERT(params.n_batch >= n_ctx && "HellaSwag currently requires n_batch >= n_ctx");
|
GGML_ASSERT(params.n_batch >= n_ctx && "HellaSwag currently requires n_batch >= n_ctx");
|
||||||
|
|
||||||
|
@ -648,7 +620,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
|
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
if (!decode_helper(ctx, batch, n_batch)) {
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue