perplexity : add BOS for each chunk
parent cad6ff5d36
commit 7f33230a40
1 changed file with 34 additions and 22 deletions
|
@@ -25,46 +25,56 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
+    // BOS tokens will be added for each chunk before eval
     auto tokens = ::llama_tokenize(ctx, params.prompt, true);

     int count = 0;

-    int seq_count = tokens.size() / params.n_ctx;
-    int n_vocab = llama_n_vocab(ctx);
+    const int n_chunk = tokens.size() / params.n_ctx;
+    const int n_vocab = llama_n_vocab(ctx);
+    const int n_batch = params.n_batch;

     double nll = 0.0;
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

-    for (int i = 0; i < seq_count; ++i) {
+    for (int i = 0; i < n_chunk; ++i) {
         const int start = i * params.n_ctx;
         const int end = start + params.n_ctx;

-        std::vector<float> logits;
-        const int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
+        const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
+
+        std::vector<float> logits;

-        const auto start_t = std::chrono::high_resolution_clock::now();
+        const auto t_start = std::chrono::high_resolution_clock::now();

         for (int j = 0; j < num_batches; ++j) {
-            const int batch_start = start + j * params.n_batch;
-            const int batch_size = std::min(end - batch_start, params.n_batch);
+            const int batch_start = start + j * n_batch;
+            const int batch_size = std::min(end - batch_start, n_batch);

-            // TODO: not perfect since this can be in the middle of a word, but it is better than nothing
-            tokens[batch_start] = llama_token_bos();
+            // save original token and restore it after eval
+            const auto token_org = tokens[batch_start];

-            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+            if (j == 0) {
+                tokens[batch_start] = llama_token_bos();
+            }
+
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }

+            tokens[batch_start] = token_org;
+
             const auto batch_logits = llama_get_logits(ctx);
             logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
         }

-        const auto end_t = std::chrono::high_resolution_clock::now();
+        const auto t_end = std::chrono::high_resolution_clock::now();
+
         if (i == 0) {
-            const float seconds = std::chrono::duration<float>(end_t - start_t).count();
-            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, seconds);
-            int total_seconds = (int)(seconds * seq_count);
+            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
@@ -74,7 +84,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {

         // We get the logits for all the tokens in the context window (params.n_ctx)
         // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half the window (so the model always has
+        // calculate the perplexity over the last half of the window (so the model always has
         // some context to predict the token).
         //
         // We rely on the fact that attention in the forward pass only looks at previous
@@ -86,10 +96,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // process the entire prompt.
         for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
             // Calculate probability of next token, given the previous ones.
-            std::vector<float> tok_logits(
-                logits.begin() + j * n_vocab,
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
                 logits.begin() + (j + 1) * n_vocab);
-            float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
+            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
             nll += -std::log(prob);
             ++count;
         }
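The diff boils down to two things: the first token of each chunk is temporarily overwritten with a BOS token before evaluation and restored afterwards, and perplexity is accumulated as negative log-likelihood over the second half of each chunk, with the final value conventionally reported as exp(nll/count). The following self-contained sketch illustrates both ideas on dummy data; it is not the llama.cpp code. eval_chunk(), BOS_ID and N_VOCAB are made-up stand-ins for llama_eval()/llama_get_logits(), llama_token_bos() and llama_n_vocab().

// Standalone sketch (assumptions noted above), compilable on its own.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static const int BOS_ID  = 1;  // placeholder for llama_token_bos()
static const int N_VOCAB = 4;  // placeholder for llama_n_vocab(ctx)

// placeholder for llama_eval() + llama_get_logits(): one row of fake logits per input token
static std::vector<float> eval_chunk(const std::vector<int> & toks) {
    std::vector<float> logits(toks.size() * N_VOCAB);
    for (size_t i = 0; i < toks.size(); ++i) {
        for (int v = 0; v < N_VOCAB; ++v) {
            logits[i*N_VOCAB + v] = (v == toks[i] % N_VOCAB) ? 2.0f : 0.0f; // favor one token
        }
    }
    return logits;
}

// numerically stable softmax over one row of logits (same role as the softmax() helper in the diff)
static std::vector<float> softmax(const std::vector<float> & x) {
    std::vector<float> p(x.size());
    const float m = *std::max_element(x.begin(), x.end());
    double sum = 0.0;
    for (size_t i = 0; i < x.size(); ++i) { p[i] = std::exp(x[i] - m); sum += p[i]; }
    for (float & v : p) v /= (float) sum;
    return p;
}

int main() {
    std::vector<int> tokens = {3, 2, 1, 0, 2, 3, 1, 2};  // stand-in for the tokenized text
    const int n_ctx   = 4;
    const int n_chunk = (int) tokens.size() / n_ctx;

    double nll = 0.0;
    int count  = 0;

    for (int i = 0; i < n_chunk; ++i) {
        const int start = i * n_ctx;

        // save the original token, put BOS at the start of the chunk, eval, then restore
        const int token_org = tokens[start];
        tokens[start] = BOS_ID;
        const std::vector<int> chunk(tokens.begin() + start, tokens.begin() + start + n_ctx);
        const std::vector<float> logits = eval_chunk(chunk);
        tokens[start] = token_org;

        // accumulate NLL over the second half of the chunk, as in the loop above
        for (int j = n_ctx/2; j < n_ctx - 1; ++j) {
            const std::vector<float> tok_logits(
                logits.begin() + (j + 0) * N_VOCAB,
                logits.begin() + (j + 1) * N_VOCAB);
            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
            nll += -std::log(prob);
            ++count;
        }
    }

    // perplexity = exp(average negative log-likelihood)
    printf("perplexity: %.4f [%d]\n", std::exp(nll / count), count);
    return 0;
}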