rename n_ctx to kv_size

Pierrick HYMBERT 2024-02-18 20:59:26 +01:00 committed by Georgi Gerganov
parent ef96e8b1f7
commit 606873401c
48 changed files with 403 additions and 393 deletions


@@ -320,11 +320,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
-const int n_ctx = llama_n_ctx(ctx);
+const int kv_size = llama_kv_size(ctx);
-if (int(tokens.size()) < 2*n_ctx) {
-fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
-n_ctx);
+if (int(tokens.size()) < 2*kv_size) {
+fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n", __func__, 2 * kv_size,
+kv_size);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
@@ -340,13 +340,13 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, -1, logit_history, prob_history};
}
-const int calc_chunk = n_ctx;
+const int calc_chunk = kv_size;
fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
if (int(tokens.size()) <= calc_chunk) {
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
tokens.size(), n_ctx, params.ppl_stride);
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n", __func__,
tokens.size(), kv_size, params.ppl_stride);
return {tokens, -1, logit_history, prob_history};
}
@@ -414,8 +414,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
-//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
-for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
+//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.kv_size - params.ppl_stride + start, params.kv_size + start);
+for (int j = kv_size - params.ppl_stride - 1; j < kv_size - 1; ++j) {
// Calculate probability of next token, given the previous ones.
const std::vector<float> tok_logits(
@@ -453,7 +453,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// BOS tokens will be added for each chunk before eval
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-const int n_ctx = llama_n_ctx(ctx);
+const int kv_size = llama_kv_size(ctx);
std::ofstream logits_stream;
if (!params.logits_file.empty()) {
@@ -464,7 +464,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
logits_stream.write("_logits_", 8);
-logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
+logits_stream.write(reinterpret_cast<const char *>(&kv_size), sizeof(kv_size));
}
auto tim1 = std::chrono::high_resolution_clock::now();
@@ -475,9 +475,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
-if (int(tokens.size()) < 2*n_ctx) {
-fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
-n_ctx);
+if (int(tokens.size()) < 2*kv_size) {
+fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n", __func__, 2 * kv_size,
+kv_size);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
@@ -488,7 +488,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<float> prob_history;
prob_history.resize(tokens.size());
-const int n_chunk_max = tokens.size() / n_ctx;
+const int n_chunk_max = tokens.size() / kv_size;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -498,11 +498,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double nll = 0.0;
double nll2 = 0.0;
-const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+const int num_batches = (kv_size + n_batch - 1) / n_batch;
std::vector<float> logits;
if (num_batches > 1) {
-logits.reserve((size_t)n_ctx * n_vocab);
+logits.reserve((size_t)kv_size * n_vocab);
}
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
@@ -513,14 +513,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
if (!params.logits_file.empty()) {
logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
-logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
+logits_stream.write((const char *)tokens.data(), n_chunk * kv_size * sizeof(tokens[0]));
const int nv = 2*((n_vocab + 1)/2) + 4;
-log_probs.resize(n_ctx * nv);
+log_probs.resize(kv_size * nv);
}
for (int i = 0; i < n_chunk; ++i) {
-const int start = i * n_ctx;
-const int end = start + n_ctx;
+const int start = i * kv_size;
+const int end = start + kv_size;
const auto t_start = std::chrono::high_resolution_clock::now();
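For reference, the logits file produced here has a fixed layout: the 8-byte "_logits_" magic and kv_size (written in the earlier hunk), then n_vocab, n_chunk and the n_chunk*kv_size evaluation tokens written above, followed by the packed log-probabilities. A stand-alone sketch of a header reader (not part of this commit; llama_token is assumed to be int32_t):

// Stand-alone sketch (not part of this commit): reading back the header of the
// logits file, assuming the field order shown in this diff.
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 2) { return 1; }
    std::ifstream in(argv[1], std::ios::binary);
    char magic[8];
    in.read(magic, 8);                                    // "_logits_"
    uint32_t kv_size; int n_vocab, n_chunk;
    in.read(reinterpret_cast<char *>(&kv_size), sizeof(kv_size));
    in.read(reinterpret_cast<char *>(&n_vocab), sizeof(n_vocab));
    in.read(reinterpret_cast<char *>(&n_chunk), sizeof(n_chunk));
    std::vector<int32_t> tokens((size_t)n_chunk * kv_size);   // llama_token assumed int32_t
    in.read(reinterpret_cast<char *>(tokens.data()), tokens.size() * sizeof(tokens[0]));
    if (in.fail()) { return 1; }
    printf("kv_size=%u n_vocab=%d n_chunk=%d\n", kv_size, n_vocab, n_chunk);
    return 0;
}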
@@ -566,7 +566,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
-// We get the logits for all the tokens in the context window (params.n_ctx)
+// We get the logits for all the tokens in the context window (params.kv_size)
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
// calculate the perplexity over the last half of the window (so the model always has
// some context to predict the token).
@@ -578,16 +578,16 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
-const int first = n_ctx/2;
+const int first = kv_size/2;
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
if (!params.logits_file.empty()) {
-process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, kv_size - 1 - first,
workers, log_probs, nll, nll2);
} else {
-process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, kv_size - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
}
-count += n_ctx - first - 1;
+count += kv_size - first - 1;
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
@@ -596,7 +596,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
printf("%8d %.4lf %4lf %4lf\n", i*kv_size, std::exp(nll / count), av, av2);
}
fflush(stdout);
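The comment block above (window of 512 tokens, scoring only the last 256) reduces to a short computation. A stand-alone sketch (not part of this commit), assuming the log-probability of each next token in one chunk is already available:

// Stand-alone sketch (not part of this commit): per-chunk perplexity over the
// last half of the window, so every scored token has at least kv_size/2 tokens
// of context.
#include <cmath>
#include <vector>

double chunk_ppl(const std::vector<double> & logprob_next, int kv_size) {
    const int first = kv_size / 2;          // skip the first half of the window
    double nll = 0.0;
    int count = 0;
    // positions first .. kv_size-2 predict tokens first+1 .. kv_size-1
    for (int j = first; j < kv_size - 1; ++j) {
        nll += -logprob_next[j];
        ++count;                            // ends at kv_size - first - 1
    }
    return std::exp(nll / count);           // perplexity = e^(mean NLL)
}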
@@ -805,16 +805,16 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
double acc = 0.0f;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-const int n_ctx = llama_n_ctx(ctx);
+const int kv_size = llama_kv_size(ctx);
const int n_batch = params.n_batch;
const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;
-llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+llama_batch batch = llama_batch_init(kv_size, 0, max_seq);
std::vector<float> tok_logits(n_vocab);
-std::vector<float> batch_logits(n_vocab*n_ctx);
+std::vector<float> batch_logits(n_vocab*kv_size);
std::vector<std::pair<size_t, llama_token>> eval_pairs;
std::vector<float> eval_results;
@@ -832,7 +832,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
// each task has 4 unique seuqnce ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and from all ending tokens of each sequence
-while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
+while (n_cur + (int) hs_data[i1].required_tokens <= kv_size) {
auto & hs_cur = hs_data[i1];
const int s0 = 4*(i1 - i0);
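The loop above packs as many tasks as still fit into the kv_size-sized KV cache, reserving four sequence ids per task. A stand-alone sketch of that packing policy (not part of this commit; Task and pack_tasks are hypothetical stand-ins for hs_data and the surrounding loop):

// Stand-alone sketch (not part of this commit): greedy packing of tasks into
// one batch, four sequence ids per task, bounded by kv_size and max_seq.
#include <cstddef>
#include <vector>

struct Task { int required_tokens; };   // hypothetical stand-in for an hs_data entry

// returns one past the last task index that still fits into this batch
std::size_t pack_tasks(const std::vector<Task> & tasks, std::size_t i0, int kv_size, int max_seq) {
    int n_cur = 0;
    std::size_t i1 = i0;
    while (i1 < tasks.size() && n_cur + tasks[i1].required_tokens <= kv_size) {
        const int s0 = 4 * (int)(i1 - i0);   // first of the 4 sequence ids for this task
        if (s0 + 4 > max_seq) {
            break;                           // out of sequence ids for this batch
        }
        n_cur += tasks[i1].required_tokens;
        ++i1;
    }
    return i1;
}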
@@ -1082,16 +1082,16 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-const int n_ctx = llama_n_ctx(ctx);
+const int kv_size = llama_kv_size(ctx);
const int n_batch = params.n_batch;
const int max_tasks_per_batch = 128;
const int max_seq = 2*max_tasks_per_batch;
-llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+llama_batch batch = llama_batch_init(kv_size, 0, max_seq);
std::vector<float> tok_logits(n_vocab);
-std::vector<float> batch_logits(n_vocab*n_ctx);
+std::vector<float> batch_logits(n_vocab*kv_size);
std::vector<std::pair<size_t, llama_token>> eval_pairs;
std::vector<float> eval_results;
@@ -1108,7 +1108,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
llama_batch_clear(batch);
-while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+while (n_cur + (int) data[i1].required_tokens <= kv_size) {
const int s0 = 2*(i1 - i0);
if (s0 + 2 > max_seq) {
break;
@@ -1434,16 +1434,16 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
printf("\ntask\tacc_norm\n");
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-const int n_ctx = llama_n_ctx(ctx);
+const int kv_size = llama_kv_size(ctx);
const int n_batch = params.n_batch;
const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;
-llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+llama_batch batch = llama_batch_init(kv_size, 0, max_seq);
std::vector<float> tok_logits(n_vocab);
-std::vector<float> batch_logits(n_vocab*n_ctx);
+std::vector<float> batch_logits(n_vocab*kv_size);
std::vector<std::pair<size_t, llama_token>> eval_pairs;
std::vector<float> eval_results;
@@ -1467,7 +1467,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and from all ending tokens of each sequence
int s0 = 0;
-while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
+while (n_cur + (int) tasks[i1].required_tokens <= kv_size) {
auto& cur_task = tasks[i1];
int num_answers = cur_task.seq_tokens.size();
@@ -1620,11 +1620,11 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
}
}
-uint32_t n_ctx;
-in.read((char *)&n_ctx, sizeof(n_ctx));
-if (n_ctx > llama_n_ctx(ctx)) {
-fprintf(stderr, "%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n",
-__func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
+uint32_t kv_size;
+in.read((char *)&kv_size, sizeof(kv_size));
+if (kv_size > llama_kv_size(ctx)) {
+fprintf(stderr, "%s: %s has been computed with %u, while the current KV Cache size is %d. Increase it with -kv and retry\n",
+__func__, params.logits_file.c_str(), kv_size, params.kv_size);
}
int n_vocab, n_chunk;
@@ -1638,22 +1638,22 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
fprintf(stderr, "%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
}
-std::vector<llama_token> tokens(n_ctx * n_chunk);
+std::vector<llama_token> tokens(kv_size * n_chunk);
if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
fprintf(stderr, "%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
return;
}
const int n_batch = params.n_batch;
-const int num_batches = (n_ctx + n_batch - 1)/n_batch;
+const int num_batches = (kv_size + n_batch - 1)/n_batch;
const int nv = 2*((n_vocab + 1)/2) + 4;
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
-std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
+std::vector<uint16_t> log_probs_uint16(size_t(kv_size - 1 - kv_size/2) * nv);
+std::vector<float> kld_values(size_t(kv_size - 1 - kv_size/2)*n_chunk);
std::vector<float> logits;
if (num_batches > 1) {
-logits.reserve(n_ctx * n_vocab);
+logits.reserve(kv_size * n_vocab);
}
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -1672,8 +1672,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
auto kld_ptr = kld_values.data();
for (int i = 0; i < n_chunk; ++i) {
-const int start = i * n_ctx;
-const int end = start + n_ctx;
+const int start = i * kv_size;
+const int end = start + kv_size;
const auto t_start = std::chrono::high_resolution_clock::now();
@@ -1726,11 +1726,11 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence Same top\n");
}
-const int first = n_ctx/2;
+const int first = kv_size/2;
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, kv_size - 1 - first,
workers, log_probs_uint16, kld, kld_ptr);
-kld_ptr += n_ctx - 1 - first;
+kld_ptr += kv_size - 1 - first;
auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
@@ -1788,12 +1788,12 @@ int main(int argc, char ** argv) {
}
params.logits_all = true;
-params.n_batch = std::min(params.n_batch, params.n_ctx);
+params.n_batch = std::min(params.n_batch, params.kv_size);
if (params.ppl_stride > 0) {
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
params.n_ctx, params.n_ctx + params.ppl_stride/2);
params.n_ctx += params.ppl_stride/2;
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting KV size from %d to %d\n",
params.kv_size, params.kv_size + params.ppl_stride / 2);
params.kv_size += params.ppl_stride/2;
}
print_build_info();
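For clarity, the two parameter adjustments above with illustrative numbers (not part of this commit):

// Stand-alone sketch (not part of this commit): n_batch is clamped to the KV
// cache size, then the strided-perplexity path enlarges kv_size by half a stride.
#include <algorithm>
#include <cstdio>

int main() {
    int kv_size    = 512;   // illustrative values, not defaults
    int n_batch    = 2048;
    int ppl_stride = 32;
    n_batch = std::min(n_batch, kv_size);   // batches never exceed the KV cache
    if (ppl_stride > 0) {
        kv_size += ppl_stride / 2;          // 512 -> 528
    }
    printf("kv_size=%d n_batch=%d\n", kv_size, n_batch);
    return 0;
}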
@@ -1823,9 +1823,9 @@ int main(int argc, char ** argv) {
}
const int n_ctx_train = llama_n_ctx_train(model);
-if (params.n_ctx > n_ctx_train) {
+if (params.kv_size > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
-__func__, n_ctx_train, params.n_ctx);
+__func__, n_ctx_train, params.kv_size);
}
// print system information