avoid relying on 'logits_all == true' in perplexity_v2
parent cbb5dd7b12
commit b0c6ad778d
1 changed file with 22 additions and 10 deletions
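The change drops the dependency on creating the context with logits_all == true: instead of pushing tokens through llama_batch_get_one(), perplexity_v2 now builds a llama_batch by hand and marks every position as needing logits. The following is a minimal sketch of that pattern, not code from the commit itself; it uses the field names as they appear in this revision (batch.output is the per-token logits flag; other versions of the API name it batch.logits), and decode_with_full_logits is a hypothetical helper name:

#include "llama.h"

// Sketch only: evaluate a slice of tokens and request logits for every
// position, instead of relying on the context being created with
// logits_all == true. Mirrors the batch setup introduced by this commit.
static bool decode_with_full_logits(llama_context * ctx, const llama_token * tokens, int n_tokens, int n_past) {
    llama_batch batch = llama_batch_init(n_tokens, 0, 1);

    for (int k = 0; k < n_tokens; ++k) {
        batch.token [k] = tokens[k];
        batch.output[k] = 1; // request logits for this position
    }
    batch.n_tokens = n_tokens;

    // as in the diff: leave the per-token arrays unset so llama_decode
    // derives positions and the sequence id from the all_* fields
    batch.pos        = nullptr;
    batch.n_seq_id   = nullptr;
    batch.seq_id     = nullptr;
    batch.all_pos_0  = n_past; // position of the first token in this batch
    batch.all_pos_1  = 1;      // position increment per token
    batch.all_seq_id = 0;      // single sequence

    const bool ok = llama_decode(ctx, batch) == 0;

    llama_batch_free(batch);

    return ok;
}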
|
@@ -367,17 +367,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         return {tokens, -1, logit_history, prob_history};
     }
 
-    const int calc_chunk = n_ctx;
-
-    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
+    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), n_ctx);
 
-    if (int(tokens.size()) <= calc_chunk) {
+    if (int(tokens.size()) <= n_ctx) {
         fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
                 tokens.size(), n_ctx, params.ppl_stride);
         return {tokens, -1, logit_history, prob_history};
     }
 
-    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
+    const int n_chunk_max = (tokens.size() - n_ctx + params.ppl_stride - 1) / params.ppl_stride;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -386,13 +384,13 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     int count = 0;
     double nll = 0.0;
 
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * params.ppl_stride;
-        const int end = start + calc_chunk;
-
-        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
+        const int end = start + n_ctx;
         //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
 
         std::vector<float> logits;
@@ -406,13 +404,27 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);
 
+            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            for (int k = 0; k < batch_size; ++k) {
+                const int idx = batch_start + k;
+                batch.token [k] = tokens[idx];
+                batch.output [k] = 1;
+            }
+            batch.n_tokens = batch_size;
+            batch.pos = nullptr;
+            batch.n_seq_id = nullptr;
+            batch.seq_id = nullptr;
+            batch.all_pos_0 = j*n_batch;
+            batch.all_pos_1 = 1;
+            batch.all_seq_id = 0;
+
             //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-            // TODO: use llama_batch.output instead of relying on logits_all == true
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 //fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
            }
 
+            llama_batch_free(batch);
             // save original token and restore it after eval
             const auto token_org = tokens[batch_start];
 
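Since every position in the batch now requests output, the per-token rows can be read back with the standard accessor. A hedged snippet, not part of the commit, assuming llama_get_logits and the n_vocab/k names from the code above:

// logits are laid out as one n_vocab-sized row per token that requested output;
// with all positions flagged, row k corresponds to batch position k
const float * all_logits = llama_get_logits(ctx);
const float * logits_k   = all_logits + k * n_vocab;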