Compute perplexity over prompt
This commit is contained in:
parent
d3f202d57b
commit
e94bd9c7b9
3 changed files with 67 additions and 15 deletions
62
main.cpp
62
main.cpp
|
@ -547,7 +547,7 @@ bool llama_eval(
|
||||||
static void * buf = malloc(buf_size);
|
static void * buf = malloc(buf_size);
|
||||||
|
|
||||||
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
||||||
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
|
const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
|
||||||
//fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
//fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||||
|
|
||||||
// reallocate
|
// reallocate
|
||||||
|
@ -747,6 +747,49 @@ bool llama_eval(
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<double> softmax(const std::vector<float>& logits) {
|
||||||
|
std::vector<double> probs(logits.size());
|
||||||
|
float max_logit = logits[0];
|
||||||
|
for (float v : logits) max_logit = std::max(max_logit, v);
|
||||||
|
double sum_exp = 0.0;
|
||||||
|
for (size_t i = 0; i < logits.size(); i++) {
|
||||||
|
// Subtract the maximum logit value from the current logit value for numerical stability
|
||||||
|
float logit = logits[i] - max_logit;
|
||||||
|
double exp_logit = std::exp(logit);
|
||||||
|
sum_exp += exp_logit;
|
||||||
|
probs[i] = exp_logit;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_params ¶ms, size_t mem_per_token) {
|
||||||
|
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||||
|
// Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
|
||||||
|
// Output: `perplexity: 13.5106 [114/114]`
|
||||||
|
std::vector<gpt_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
|
||||||
|
|
||||||
|
double nll = 0.0;
|
||||||
|
int seq_count = tokens.size() / params.n_ctx;
|
||||||
|
for (int i = 0; i < seq_count; ++i) {
|
||||||
|
int start = i * params.n_ctx;
|
||||||
|
int end = start + params.n_ctx - 1;
|
||||||
|
std::vector<gpt_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
|
||||||
|
std::vector<float> logits;
|
||||||
|
if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token)) {
|
||||||
|
fprintf(stderr, "Failed to predict\n");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Calculate probability of next token, given the previous ones.
|
||||||
|
double prob = softmax(logits)[tokens[end]];
|
||||||
|
nll += -std::log(prob);
|
||||||
|
// perplexity is e^(average negative log-likelihood)
|
||||||
|
printf("perplexity: %.4lf [%d/%d] \r", std::exp(nll / (i + 1)), i + 1, seq_count);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
|
||||||
static bool is_interacting = false;
|
static bool is_interacting = false;
|
||||||
|
|
||||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||||
|
@ -830,13 +873,22 @@ int main(int argc, char ** argv) {
|
||||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<float> logits;
|
||||||
|
|
||||||
|
// determine the required inference memory per token:
|
||||||
|
size_t mem_per_token = 0;
|
||||||
|
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||||
|
|
||||||
|
if (params.perplexity) {
|
||||||
|
perplexity(vocab, model, params, mem_per_token);
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
int n_past = 0;
|
int n_past = 0;
|
||||||
|
|
||||||
int64_t t_sample_us = 0;
|
int64_t t_sample_us = 0;
|
||||||
int64_t t_predict_us = 0;
|
int64_t t_predict_us = 0;
|
||||||
|
|
||||||
std::vector<float> logits;
|
|
||||||
|
|
||||||
// Add a space in front of the first character to match OG llama tokenizer behavior
|
// Add a space in front of the first character to match OG llama tokenizer behavior
|
||||||
params.prompt.insert(0, 1, ' ');
|
params.prompt.insert(0, 1, ' ');
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
|
@ -881,10 +933,6 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
std::vector<gpt_vocab::id> embd;
|
std::vector<gpt_vocab::id> embd;
|
||||||
|
|
||||||
// determine the required inference memory per token:
|
|
||||||
size_t mem_per_token = 0;
|
|
||||||
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
|
||||||
|
|
||||||
int last_n_size = params.repeat_last_n;
|
int last_n_size = params.repeat_last_n;
|
||||||
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
|
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
|
||||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||||
|
|
10
utils.cpp
10
utils.cpp
|
@ -44,7 +44,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
std::copy(std::istreambuf_iterator<char>(file),
|
std::copy(std::istreambuf_iterator<char>(file),
|
||||||
std::istreambuf_iterator<char>(),
|
std::istreambuf_iterator<char>(),
|
||||||
back_inserter(params.prompt));
|
back_inserter(params.prompt));
|
||||||
|
|
||||||
} else if (arg == "-n" || arg == "--n_predict") {
|
} else if (arg == "-n" || arg == "--n_predict") {
|
||||||
params.n_predict = std::stoi(argv[++i]);
|
params.n_predict = std::stoi(argv[++i]);
|
||||||
} else if (arg == "--top_k") {
|
} else if (arg == "--top_k") {
|
||||||
|
@ -72,6 +71,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
params.use_color = true;
|
params.use_color = true;
|
||||||
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
||||||
params.antiprompt = argv[++i];
|
params.antiprompt = argv[++i];
|
||||||
|
} else if (arg == "--perplexity") {
|
||||||
|
params.perplexity = true;
|
||||||
} else if (arg == "-h" || arg == "--help") {
|
} else if (arg == "-h" || arg == "--help") {
|
||||||
gpt_print_usage(argc, argv, params);
|
gpt_print_usage(argc, argv, params);
|
||||||
exit(0);
|
exit(0);
|
||||||
|
@ -109,6 +110,7 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
|
||||||
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
|
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||||
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
||||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||||
|
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
|
||||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
|
@ -322,9 +324,9 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
|
||||||
while (i > 0) {
|
while (i > 0) {
|
||||||
gpt_vocab::id token_id = prev[i];
|
gpt_vocab::id token_id = prev[i];
|
||||||
if (token_id == 0) {
|
if (token_id == 0) {
|
||||||
// TODO: Return error or something more meaningful
|
// TODO: Return error or something more meaningful
|
||||||
printf("failed to tokenize string!\n");
|
printf("failed to tokenize string at %d!\n", i);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
res.push_back(token_id);
|
res.push_back(token_id);
|
||||||
auto token = (*vocab.id_to_token.find(token_id)).second;
|
auto token = (*vocab.id_to_token.find(token_id)).second;
|
||||||
|
|
2
utils.h
2
utils.h
|
@ -35,6 +35,8 @@ struct gpt_params {
|
||||||
bool interactive = false; // interactive mode
|
bool interactive = false; // interactive mode
|
||||||
bool interactive_start = false; // reverse prompt immediately
|
bool interactive_start = false; // reverse prompt immediately
|
||||||
std::string antiprompt = ""; // string upon seeing which more user input is prompted
|
std::string antiprompt = ""; // string upon seeing which more user input is prompted
|
||||||
|
|
||||||
|
bool perplexity = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue