calc ppl on sakurallm prompt format correctly

This commit is contained in:
Reinforce-II 2024-02-09 21:46:27 +08:00
parent 7c777fcd5d
commit 2db0ca34d3
2 changed files with 107 additions and 241 deletions

View file

@ -323,139 +323,87 @@ static void process_logits(
}
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
(void)from_chunk;
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
if (from_chunk > 0) {
if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) {
fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk);
return false;
}
fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx);
tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx);
}
if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return false;
}
std::vector<float> logit_history;
std::vector<float> prob_history;
if (compute_ppl) {
logit_history.resize(tokens.size());
prob_history.resize(tokens.size());
}
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
double nll = 0.0;
double nll2 = 0.0;
fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
std::vector<llama_token> tokens;
std::vector<float> logit_history;
std::vector<float> prob_history;
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
if (compute_ppl && num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
size_t c_begin = 0;
while (true) {
const char* s_begin = "<|im_start|>system\n";
const char* s_assistant = "<|im_start|>assistant\n";
c_begin = params.prompt.find(s_begin, c_begin);
if (c_begin == std::string::npos) {
break;
}
size_t c_assistant = params.prompt.find(s_assistant, c_begin);
if (c_assistant == std::string::npos) {
break;
}
c_assistant += strlen(s_assistant);
size_t next_c_begin = params.prompt.find(s_begin, c_assistant);
auto s_prompt = params.prompt.substr(c_begin, c_assistant - c_begin);
auto s_response = params.prompt.substr(c_assistant, next_c_begin - c_assistant);
c_begin += 1;
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
const int end = start + n_ctx;
std::vector<float> logits;
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
std::vector<llama_token> s_tokens_prompt = ::llama_tokenize(ctx, s_prompt, false);
std::vector<llama_token> s_tokens_response = ::llama_tokenize(ctx, s_response, false);
std::vector<llama_token> s_tokens = s_tokens_prompt;
s_tokens.insert(s_tokens.end(), s_tokens_response.begin(), s_tokens_response.end());
std::vector<float> s_logits;
std::vector<float> s_logit_history(s_tokens.size(), 0);
std::vector<float> s_prob_history(s_tokens.size(), 0);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
for (int j = 0; j < (int(s_tokens.size()) + n_batch - 1) / n_batch; ++j) {
const int batch_start = j * n_batch;
const int batch_size = std::min((int)s_tokens.size() - batch_start, n_batch);
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
if (llama_decode(ctx, llama_batch_get_one(s_tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
if (compute_ppl && num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
const auto* batch_logits = llama_get_logits(ctx);
s_logits.insert(s_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
const auto t_end = std::chrono::high_resolution_clock::now();
const int first = s_tokens_prompt.size();
const float* all_logits = s_logits.data();
process_logits(n_vocab, all_logits + first * n_vocab, s_tokens.data() + first, s_tokens_response.size() - 1,
workers, nll, nll2, s_logit_history.data() + first, s_prob_history.data() + first);
count += s_tokens_response.size() - 1;
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
if (compute_ppl) {
const int first = n_ctx/2;
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1;
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
printf(" %.4lf,", std::exp(nll / count));
fflush(stdout);
logits.clear();
tokens.insert(tokens.end(), s_tokens.begin(), s_tokens.end());
logit_history.insert(logit_history.end(), s_logit_history.begin(), s_logit_history.end());
prob_history.insert(prob_history.end(), s_prob_history.begin(), s_prob_history.end());
}
}
printf("\n");
if (compute_ppl) {
nll2 /= count;
nll /= count;
const double ppl = exp(nll);
nll2 -= nll * nll;
if (nll2 > 0) {
nll2 = sqrt(nll2/(count-1));
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
} else {
printf("Unexpected negative standard deviation of log(prob)\n");
nll2 = sqrt(nll2 / (count - 1));
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2 * ppl);
}
else {
printf("Unexpected negative standard deviation of log(prob)\n");
}
return true;

View file

@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, std::exp(nll / count), logit_history, prob_history};
}
static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
static results_perplexity perplexity(llama_context* ctx, const gpt_params& params) {
if (params.ppl_stride > 0) {
return perplexity_v2(ctx, params);
}
@ -452,170 +452,88 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
std::ofstream logits_stream;
if (!params.logits_file.empty()) {
logits_stream.open(params.logits_file.c_str(), std::ios::binary);
if (!logits_stream.is_open()) {
fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
return {};
}
fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
logits_stream.write("_logits_", 8);
logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
}
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
std::vector<float> logit_history;
logit_history.resize(tokens.size());
std::vector<float> prob_history;
prob_history.resize(tokens.size());
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
double nll = 0.0;
double nll2 = 0.0;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<llama_token> tokens;
std::vector<float> logit_history;
std::vector<float> prob_history;
std::vector<float> logits;
if (num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
size_t c_begin = 0;
while (true) {
const char* s_begin = "<|im_start|>system\n";
const char* s_assistant = "<|im_start|>assistant\n";
c_begin = params.prompt.find(s_begin, c_begin);
if (c_begin == std::string::npos) {
break;
}
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
std::vector<uint16_t> log_probs;
if (!params.logits_file.empty()) {
logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
const int nv = 2*((n_vocab + 1)/2) + 4;
log_probs.resize(n_ctx * nv);
size_t c_assistant = params.prompt.find(s_assistant, c_begin);
if (c_assistant == std::string::npos) {
break;
}
c_assistant += strlen(s_assistant);
size_t next_c_begin = params.prompt.find(s_begin, c_assistant);
auto s_prompt = params.prompt.substr(c_begin, c_assistant - c_begin);
auto s_response = params.prompt.substr(c_assistant, next_c_begin - c_assistant);
c_begin += 1;
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
const int end = start + n_ctx;
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
std::vector<llama_token> s_tokens_prompt = ::llama_tokenize(ctx, s_prompt, false);
std::vector<llama_token> s_tokens_response = ::llama_tokenize(ctx, s_response, false);
std::vector<llama_token> s_tokens = s_tokens_prompt;
s_tokens.insert(s_tokens.end(), s_tokens_response.begin(), s_tokens_response.end());
std::vector<float> s_logits;
std::vector<float> s_logit_history(s_tokens.size(), 0);
std::vector<float> s_prob_history(s_tokens.size(), 0);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
for (int j = 0; j < (int(s_tokens.size()) + n_batch - 1) / n_batch; ++j) {
const int batch_start = j * n_batch;
const int batch_size = std::min((int)s_tokens.size() - batch_start, n_batch);
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
if (llama_decode(ctx, llama_batch_get_one(s_tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
return { tokens, -1, logit_history, prob_history };
}
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
if (num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
const auto* batch_logits = llama_get_logits(ctx);
s_logits.insert(s_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
const auto t_end = std::chrono::high_resolution_clock::now();
const int first = s_tokens_prompt.size();
const float* all_logits = s_logits.data();
process_logits(n_vocab, all_logits + first * n_vocab, s_tokens.data() + first, s_tokens_response.size() - 1,
workers, nll, nll2, s_logit_history.data() + first, s_prob_history.data() + first);
count += s_tokens_response.size() - 1;
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
// We get the logits for all the tokens in the context window (params.n_ctx)
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
// calculate the perplexity over the last half of the window (so the model always has
// some context to predict the token).
//
// We rely on the fact that attention in the forward pass only looks at previous
// tokens here, so the logits returned for each token are an accurate representation
// of what the model would have predicted at that point.
//
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
const int first = n_ctx/2;
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
if (!params.logits_file.empty()) {
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, log_probs, nll, nll2);
} else {
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
}
count += n_ctx - first - 1;
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
} else {
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
printf(" %.4lf,", std::exp(nll / count));
fflush(stdout);
logits.clear();
tokens.insert(tokens.end(), s_tokens.begin(), s_tokens.end());
logit_history.insert(logit_history.end(), s_logit_history.begin(), s_logit_history.end());
prob_history.insert(prob_history.end(), s_prob_history.begin(), s_prob_history.end());
}
printf("\n");
nll2 /= count;
nll /= count;
const double ppl = exp(nll);
nll2 -= nll * nll;
if (nll2 > 0) {
nll2 = sqrt(nll2/(count-1));
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
} else {
nll2 = sqrt(nll2 / (count - 1));
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2 * ppl);
}
else {
printf("Unexpected negative standard deviation of log(prob)\n");
}
return {tokens, ppl, logit_history, prob_history};
return { tokens, ppl, logit_history, prob_history };
}
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {