rename n_ctx to kv_size
parent ef96e8b1f7
commit 606873401c

48 changed files with 403 additions and 393 deletions
@@ -325,7 +325,7 @@ static void process_logits(
 static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    const int kv_size = llama_kv_size(ctx);
 
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
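This first hunk sets the pattern for the rest of the commit: the value formerly read through `llama_n_ctx` is now read through the renamed `llama_kv_size` accessor, and locals named `n_ctx` become `kv_size`. A minimal sketch of a call site before and after; the helper name is invented for illustration and `ctx` is assumed to be an initialized `llama_context`:

```cpp
#include "llama.h"

// Hypothetical helper, not part of the commit: returns the number of
// token slots in the KV cache, i.e. the value this commit renames.
static int window_tokens(llama_context * ctx) {
    // before: const int n_ctx   = llama_n_ctx(ctx);
    // after:  the same value under a name that says what it measures
    const int kv_size = llama_kv_size(ctx);
    return kv_size;
}
```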
@@ -336,17 +336,17 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
     fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
 
     if (from_chunk > 0) {
-        if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) {
+        if (size_t((from_chunk + 2)*kv_size) >= tokens.size()) {
             fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk);
             return false;
         }
-        fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx);
-        tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx);
+        fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk * kv_size);
+        tokens.erase(tokens.begin(), tokens.begin() + from_chunk * kv_size);
     }
 
-    if (int(tokens.size()) < 2*n_ctx) {
-        fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
-                n_ctx);
+    if (int(tokens.size()) < 2*kv_size) {
+        fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n", __func__, 2 * kv_size,
+                kv_size);
         fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
         return false;
     }
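The `(from_chunk + 2) * kv_size` threshold follows from the `2*kv_size` check below it: erasing `from_chunk` windows of `kv_size` tokens each must still leave at least two full windows. A worked example with made-up numbers (nothing here is taken from the commit):

```cpp
#include <cstdio>

int main() {
    const int kv_size    = 512;  // hypothetical KV-cache size
    const int from_chunk = 3;    // hypothetical number of chunks to skip

    const int erased    = from_chunk * kv_size;        // 1536 tokens dropped
    const int threshold = (from_chunk + 2) * kv_size;  // 2560 tokens

    // The guard uses >=, so a file that tokenizes to exactly 2560 tokens
    // is still rejected: strictly more than `threshold` tokens are needed
    // for 2 * kv_size = 1024 tokens to remain after the erase.
    printf("skip erases %d tokens; the input must exceed %d tokens\n",
           erased, threshold);
    return 0;
}
```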
@@ -359,7 +359,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
         prob_history.resize(tokens.size());
     }
 
-    const int n_chunk_max = tokens.size() / n_ctx;
+    const int n_chunk_max = tokens.size() / kv_size;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
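Note that `tokens.size() / kv_size` is integer division, so a trailing partial window is silently dropped: with hypothetical numbers, 2600 tokens at `kv_size = 512` give `n_chunk_max = 5` and leave the last 40 tokens unused.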
@@ -373,16 +373,16 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
-    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    const int num_batches = (kv_size + n_batch - 1) / n_batch;
 
     std::vector<float> logits;
     if (compute_ppl && num_batches > 1) {
-        logits.reserve((size_t)n_ctx * n_vocab);
+        logits.reserve((size_t)kv_size * n_vocab);
     }
 
     for (int i = 0; i < n_chunk; ++i) {
-        const int start = i * n_ctx;
-        const int end = start + n_ctx;
+        const int start = i * kv_size;
+        const int end = start + kv_size;
 
         std::vector<float> logits;
 
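Unlike the floor division above, `(kv_size + n_batch - 1) / n_batch` is the standard integer ceiling-division idiom: each `kv_size` window is decoded in `n_batch`-sized pieces, and a partial final batch still needs a slot. A small self-contained check with hypothetical numbers:

```cpp
#include <cassert>

// Integer ceiling division, the idiom used for num_batches above.
static int ceil_div(int a, int b) {
    return (a + b - 1) / b;
}

int main() {
    // kv_size = 512, n_batch = 200: 512/200 floors to 2, but three
    // batches (200 + 200 + 112 tokens) are actually needed.
    assert(ceil_div(512, 200) == 3);
    // An exact multiple adds no extra batch.
    assert(ceil_div(512, 256) == 2);
    return 0;
}
```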
@@ -431,11 +431,11 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
         }
 
         if (compute_ppl) {
-            const int first = n_ctx/2;
+            const int first = kv_size / 2;
             const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, kv_size - 1 - first,
                     workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-            count += n_ctx - first - 1;
+            count += kv_size - first - 1;
 
             printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
             fflush(stdout);
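The arithmetic here scores only the second half of each window: `first = kv_size / 2` guarantees every scored token has at least half a window of context, and the last position is excluded because it has no next token to predict, leaving `kv_size - 1 - first` logits per chunk. The printed value is the running perplexity, `exp(nll / count)`. A sketch with invented numbers:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const int kv_size = 512;                  // hypothetical window size
    const int first   = kv_size / 2;          // skip the low-context half
    const int scored  = kv_size - 1 - first;  // 255 scored positions

    const double nll   = 180.0;   // made-up accumulated -log p(token)
    const int    count = scored;  // totals after one chunk

    printf("scored %d positions; running ppl = %.4f\n",
           count, std::exp(nll / count));
    return 0;
}
```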
@@ -553,7 +553,7 @@ int main(int argc, char ** argv) {
     }
 
     params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+    params.n_batch = std::min(params.n_batch, params.kv_size);
 
     print_build_info();
 
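The clamp keeps a single decode batch from exceeding the KV cache: with hypothetical values `n_batch = 2048` and `kv_size = 512`, the effective batch size becomes 512.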
@@ -593,9 +593,9 @@ int main(int argc, char ** argv) {
     }
 
     const int n_ctx_train = llama_n_ctx_train(model);
-    if (params.n_ctx > n_ctx_train) {
+    if (params.kv_size > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
-                __func__, n_ctx_train, params.n_ctx);
+                __func__, n_ctx_train, params.kv_size);
     }
 
     // print system information
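With the rename, the guard now compares the requested KV-cache size against the model's training context. A standalone reproduction of the check with hypothetical values (the format string is the one from the hunk; `__func__` is replaced by a literal):

```cpp
#include <cstdio>

int main() {
    const int n_ctx_train = 4096;  // hypothetical training context
    const int kv_size     = 8192;  // hypothetical requested KV-cache size

    if (kv_size > n_ctx_train) {
        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                "main", n_ctx_train, kv_size);
    }
    return 0;
}
```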