rename n_ctx to kv_size

This commit is contained in:
Pierrick HYMBERT 2024-02-18 20:59:26 +01:00 committed by Georgi Gerganov
parent ef96e8b1f7
commit 606873401c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
48 changed files with 403 additions and 393 deletions

View file

@ -325,7 +325,7 @@ static void process_logits(
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
const int kv_size = llama_kv_size(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@ -336,17 +336,17 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
if (from_chunk > 0) {
if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) {
if (size_t((from_chunk + 2)*kv_size) >= tokens.size()) {
fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk);
return false;
}
fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx);
tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx);
fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk * kv_size);
tokens.erase(tokens.begin(), tokens.begin() + from_chunk * kv_size);
}
if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
n_ctx);
if (int(tokens.size()) < 2*kv_size) {
fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n", __func__, 2 * kv_size,
kv_size);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return false;
}
@ -359,7 +359,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
prob_history.resize(tokens.size());
}
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk_max = tokens.size() / kv_size;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
@ -373,16 +373,16 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
const int num_batches = (kv_size + n_batch - 1) / n_batch;
std::vector<float> logits;
if (compute_ppl && num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
logits.reserve((size_t)kv_size * n_vocab);
}
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
const int end = start + n_ctx;
const int start = i * kv_size;
const int end = start + kv_size;
std::vector<float> logits;
@ -431,11 +431,11 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
}
if (compute_ppl) {
const int first = n_ctx/2;
const int first = kv_size / 2;
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, kv_size - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1;
count += kv_size - first - 1;
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
fflush(stdout);
@ -553,7 +553,7 @@ int main(int argc, char ** argv) {
}
params.logits_all = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);
params.n_batch = std::min(params.n_batch, params.kv_size);
print_build_info();
@ -593,9 +593,9 @@ int main(int argc, char ** argv) {
}
const int n_ctx_train = llama_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) {
if (params.kv_size > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
__func__, n_ctx_train, params.kv_size);
}
// print system information