imatrix: add --no-ppl option to skip PPL calculations altogether
This commit is contained in:
parent
cdeac23ef5
commit
3aa56562c0
1 changed file with 39 additions and 25 deletions
|
@ -248,7 +248,7 @@ static void process_logits(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) {
|
||||||
|
|
||||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||||
const int n_ctx = llama_n_ctx(ctx);
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
|
@ -269,10 +269,12 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> logit_history;
|
std::vector<float> logit_history;
|
||||||
logit_history.resize(tokens.size());
|
|
||||||
|
|
||||||
std::vector<float> prob_history;
|
std::vector<float> prob_history;
|
||||||
|
|
||||||
|
if (compute_ppl) {
|
||||||
|
logit_history.resize(tokens.size());
|
||||||
prob_history.resize(tokens.size());
|
prob_history.resize(tokens.size());
|
||||||
|
}
|
||||||
|
|
||||||
const int n_chunk_max = tokens.size() / n_ctx;
|
const int n_chunk_max = tokens.size() / n_ctx;
|
||||||
|
|
||||||
|
@ -291,7 +293,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
|
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
|
||||||
|
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
if (num_batches > 1) {
|
if (compute_ppl && num_batches > 1) {
|
||||||
logits.reserve((size_t)n_ctx * n_vocab);
|
logits.reserve((size_t)n_ctx * n_vocab);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,7 +328,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
// restore the original token in case it was set to BOS
|
// restore the original token in case it was set to BOS
|
||||||
tokens[batch_start] = token_org;
|
tokens[batch_start] = token_org;
|
||||||
|
|
||||||
if (num_batches > 1) {
|
if (compute_ppl && num_batches > 1) {
|
||||||
const auto * batch_logits = llama_get_logits(ctx);
|
const auto * batch_logits = llama_get_logits(ctx);
|
||||||
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
||||||
}
|
}
|
||||||
|
@ -345,6 +347,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (compute_ppl) {
|
||||||
const int first = n_ctx/2;
|
const int first = n_ctx/2;
|
||||||
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
||||||
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||||
|
@ -356,8 +359,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
|
|
||||||
logits.clear();
|
logits.clear();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
if (compute_ppl) {
|
||||||
nll2 /= count;
|
nll2 /= count;
|
||||||
nll /= count;
|
nll /= count;
|
||||||
const double ppl = exp(nll);
|
const double ppl = exp(nll);
|
||||||
|
@ -368,6 +373,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
} else {
|
} else {
|
||||||
printf("Unexpected negative standard deviation of log(prob)\n");
|
printf("Unexpected negative standard deviation of log(prob)\n");
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -375,6 +381,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
|
|
||||||
StatParams sparams;
|
StatParams sparams;
|
||||||
|
bool compute_ppl = true;
|
||||||
std::vector<char*> args;
|
std::vector<char*> args;
|
||||||
args.push_back(argv[0]);
|
args.push_back(argv[0]);
|
||||||
int iarg = 1;
|
int iarg = 1;
|
||||||
|
@ -391,13 +398,20 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
else if (arg == "--verbosity") {
|
else if (arg == "--verbosity") {
|
||||||
sparams.verbosity = std::stoi(argv[++iarg]);
|
sparams.verbosity = std::stoi(argv[++iarg]);
|
||||||
|
} else if (arg == "--no-ppl") {
|
||||||
|
compute_ppl = false;
|
||||||
} else {
|
} else {
|
||||||
args.push_back(argv[iarg]);
|
args.push_back(argv[iarg]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (iarg < argc) {
|
if (iarg < argc) {
|
||||||
|
std::string arg{argv[iarg]};
|
||||||
|
if (arg == "--no-ppl") {
|
||||||
|
compute_ppl = false;
|
||||||
|
} else {
|
||||||
args.push_back(argv[iarg]);
|
args.push_back(argv[iarg]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
params.n_batch = 512;
|
params.n_batch = 512;
|
||||||
|
@ -458,7 +472,7 @@ int main(int argc, char ** argv) {
|
||||||
fprintf(stderr, "%s\n", get_system_info(params).c_str());
|
fprintf(stderr, "%s\n", get_system_info(params).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OK = compute_imatrix(ctx, params);
|
bool OK = compute_imatrix(ctx, params, compute_ppl);
|
||||||
if (!OK) {
|
if (!OK) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue