imatrix: speed up by avoiding unnecessary allocations and copies
This commit is contained in:
parent
97c1549808
commit
cdeac23ef5
1 changed file with 15 additions and 5 deletions
|
@ -288,12 +288,17 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
|
|
||||||
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
||||||
|
|
||||||
|
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
|
||||||
|
|
||||||
|
std::vector<float> logits;
|
||||||
|
if (num_batches > 1) {
|
||||||
|
logits.reserve((size_t)n_ctx * n_vocab);
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_chunk; ++i) {
|
for (int i = 0; i < n_chunk; ++i) {
|
||||||
const int start = i * n_ctx;
|
const int start = i * n_ctx;
|
||||||
const int end = start + n_ctx;
|
const int end = start + n_ctx;
|
||||||
|
|
||||||
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
|
|
||||||
|
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
|
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
@ -321,8 +326,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
// restore the original token in case it was set to BOS
|
// restore the original token in case it was set to BOS
|
||||||
tokens[batch_start] = token_org;
|
tokens[batch_start] = token_org;
|
||||||
|
|
||||||
const auto * batch_logits = llama_get_logits(ctx);
|
if (num_batches > 1) {
|
||||||
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
const auto * batch_logits = llama_get_logits(ctx);
|
||||||
|
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
@ -339,12 +346,15 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const int first = n_ctx/2;
|
const int first = n_ctx/2;
|
||||||
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
||||||
|
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||||
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
|
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
|
||||||
count += n_ctx - first - 1;
|
count += n_ctx - first - 1;
|
||||||
|
|
||||||
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
|
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
|
logits.clear();
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue