YAML result logging + preset script (#2657)

parent 75fafcbccc
commit 6b73ef1201

8 changed files with 700 additions and 42 deletions
@@ -3,16 +3,79 @@
 #include "build-info.h"
 
 #include <cmath>
+#include <cstdio>
+#include <cstring>
 #include <ctime>
 #include <sstream>
-#include <cstring>
 #include <thread>
 #include <mutex>
+#include <vector>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+struct results_perplexity {
+    std::vector<llama_token> tokens;
+    double ppl_value;
+    std::vector<float> logits;
+    std::vector<float> probs;
+};
+
+struct results_log_softmax {
+    double log_softmax;
+    float logit;
+    float prob;
+};
+
+void write_logfile(const llama_context * ctx, const gpt_params & params,
+                   const llama_model * model, const struct results_perplexity & results) {
+
+    if (params.logdir.empty()) {
+        return;
+    }
+
+    if (params.hellaswag) {
+        fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
+        return;
+    }
+
+    const std::string timestamp = get_sortable_timestamp();
+
+    const bool success = create_directory_with_parents(params.logdir);
+    if (!success) {
+        fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
+                __func__, params.logdir.c_str());
+        return;
+    }
+
+    const std::string logfile_path = params.logdir + timestamp + ".yml";
+    FILE * logfile = fopen(logfile_path.c_str(), "w");
+
+    if (logfile == NULL) {
+        fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
+        return;
+    }
+
+    fprintf(logfile, "binary: main\n");
+    char model_desc[128];
+    llama_model_desc(model, model_desc, sizeof(model_desc));
+    dump_non_result_info_yaml(logfile, params, ctx, timestamp, results.tokens, model_desc);
+
+    fprintf(logfile, "\n");
+    fprintf(logfile, "######################\n");
+    fprintf(logfile, "# Perplexity Results #\n");
+    fprintf(logfile, "######################\n");
+    fprintf(logfile, "\n");
+
+    dump_vector_float_yaml(logfile, "logits", results.logits);
+    fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
+    dump_vector_float_yaml(logfile, "probs", results.probs);
+
+    llama_dump_timing_info_yaml(logfile, ctx);
+    fclose(logfile);
+}
+
 std::vector<float> softmax(const std::vector<float>& logits) {
     std::vector<float> probs(logits.size());
     float max_logit = logits[0];
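The YAML writer above calls helpers that are not defined in this hunk (dump_non_result_info_yaml, dump_vector_float_yaml, llama_dump_timing_info_yaml); these appear to be added to the shared common code elsewhere in this commit, while the hunks shown here come from the perplexity example. As a rough, purely illustrative idea of the kind of output involved, a minimal helper for dumping a float vector as a YAML flow-style list could look like this (hypothetical name, not the commit's implementation):

// Illustrative sketch only, not the helper added by this commit.
// Dumps a float vector as a YAML flow-style list, e.g. "probs: [0.1, 0.2]".
#include <cstdio>
#include <vector>

static void dump_floats_yaml_sketch(FILE * f, const char * name, const std::vector<float> & v) {
    fprintf(f, "%s: [", name);
    for (size_t i = 0; i < v.size(); ++i) {
        fprintf(f, "%s%g", i > 0 ? ", " : "", v[i]);
    }
    fprintf(f, "]\n");
}

The real helpers may format values differently; the point is only that each logged field becomes one YAML key with a scalar or list value.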
@@ -29,20 +92,20 @@ std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }
 
-float log_softmax(int n_vocab, const float * logits, int tok) {
+results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
     float max_logit = logits[0];
     for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
     double sum_exp = 0.0;
     for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
-    return logits[tok] - max_logit - log(sum_exp);
+    return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
 }
 
-void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread>& workers,
-                    double& nll, double& nll2) {
+void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
+                    double & nll, double & nll2, float * logit_history, float * prob_history) {
 
     std::mutex mutex;
     int counter = 0;
-    auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
+    auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
         double local_nll = 0, local_nll2 = 0;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
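For reference, log_softmax above is the numerically stable log-softmax: subtracting the maximum logit before exponentiating keeps expf from overflowing without changing the result. With x_i = logits[i] and t = tok:

\[
\log p_t = x_t - x_{\max} - \log\!\sum_i e^{\,x_i - x_{\max}},
\qquad
p_t = \frac{e^{\,x_t - x_{\max}}}{\sum_i e^{\,x_i - x_{\max}}}
\]

which is exactly the {log_softmax, logit, prob} triple that the new results_log_softmax struct carries, with logit being the raw x_t.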
@@ -52,34 +115,43 @@ void process_logits(int n_vocab, const float * logits, const int * tokens, int n
                 break;
             }
             lock.unlock();
-            double v = -log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const double v = -results.log_softmax;
             local_nll += v;
             local_nll2 += v*v;
+
+            logit_history[i] = results.logit;
+            prob_history[i]  = results.prob;
         }
     };
-    for (auto& w : workers) w = std::thread(compute);
+    for (auto & w : workers) w = std::thread(compute);
     compute();
-    for (auto& w : workers) w.join();
+    for (auto & w : workers) w.join();
 
 }
 
-void perplexity_v2(llama_context * ctx, const gpt_params & params) {
+results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
 
-    if (params.ppl_stride <= 0) {
-        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
-        return;
-    }
-
     const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
     const bool add_bos = is_spm;
 
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
-    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<float> logit_history;
+    std::vector<float> prob_history;
+
+    logit_history.resize(tokens.size());
+    prob_history.resize(tokens.size());
+
+    if (params.ppl_stride <= 0) {
+        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
+        return {tokens, -1, logit_history, prob_history};
+    }
+
     const int calc_chunk = params.n_ctx;
 
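A reader skimming the diff may miss how process_logits shares work: a counter protected by a mutex hands out token indices, each thread computes the log-softmax for its index outside the lock, and the calling thread joins in via compute(). A self-contained sketch of that pattern, with hypothetical names and the per-item work elided:

#include <mutex>
#include <thread>
#include <vector>

// Sketch of the counter-based work sharing used by process_logits.
// Each worker locks the mutex only long enough to claim the next index;
// the actual work for that index happens outside the lock.
static void parallel_for_sketch(int n_items, int n_extra_threads) {
    std::mutex mutex;
    int counter = 0;
    auto compute = [&]() {
        while (true) {
            int i;
            {
                std::lock_guard<std::mutex> lock(mutex);
                if (counter >= n_items) {
                    break;
                }
                i = counter++;
            }
            // ... process item i here, e.g. accumulate -log p ...
            (void) i;
        }
    };
    std::vector<std::thread> workers(n_extra_threads);
    for (auto & w : workers) w = std::thread(compute);
    compute();                     // the calling thread participates as well
    for (auto & w : workers) w.join();
}

The real function also keeps per-thread local_nll/local_nll2 accumulators and presumably folds them back into nll and nll2 under the lock before breaking out; that part of the loop falls outside the hunks shown here.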
@@ -88,7 +160,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
     if (int(tokens.size()) <= calc_chunk) {
         fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
                 tokens.size(), params.n_ctx, params.ppl_stride);
-        return;
+        return {tokens, -1, logit_history, prob_history};
     }
 
     const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
@@ -120,7 +192,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
             //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                 //fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
+                return {tokens, -1, logit_history, prob_history};
             }
 
             // save original token and restore it after eval
@@ -161,6 +233,8 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
                                           logits.begin() + (j + 1) * n_vocab);
 
             const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+            logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
+            prob_history[start + j + 1]  = prob;
 
             nll += -std::log(prob);
             ++count;
@@ -174,12 +248,14 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
         fflush(stdout);
     }
     printf("\n");
+
+    return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-void perplexity(llama_context * ctx, const gpt_params & params) {
+results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+
     if (params.ppl_stride > 0) {
-        perplexity_v2(ctx, params);
-        return;
+        return perplexity_v2(ctx, params);
     }
 
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
|||
auto tim1 = std::chrono::high_resolution_clock::now();
|
||||
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
|
||||
|
||||
auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
|
||||
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
|
||||
|
||||
auto tim2 = std::chrono::high_resolution_clock::now();
|
||||
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
|
||||
|
||||
std::vector<float> logit_history;
|
||||
logit_history.resize(tokens.size());
|
||||
|
||||
std::vector<float> prob_history;
|
||||
prob_history.resize(tokens.size());
|
||||
|
||||
const int n_chunk_max = tokens.size() / params.n_ctx;
|
||||
|
||||
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
|
||||
|
@@ -236,7 +318,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
 
             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
+                return {tokens, -1, logit_history, prob_history};
             }
 
             // restore the original token in case it was set to BOS
@@ -272,7 +354,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
        const int first = std::min(512, params.n_ctx/2);
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
+        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
+                       workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
         count += params.n_ctx - first - 1;
 
         // perplexity is e^(average negative log-likelihood)
@@ -287,16 +370,19 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         fflush(stdout);
     }
     printf("\n");
+
     nll2 /= count;
     nll /= count;
+    const double ppl = exp(nll);
     nll2 -= nll * nll;
     if (nll2 > 0) {
         nll2 = sqrt(nll2/(count-1));
-        double ppl = exp(nll);
         printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
     } else {
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
+
+    return {tokens, ppl, logit_history, prob_history};
 }
 
 std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
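To make the final estimate in the hunk above explicit: with per-token negative log-likelihoods \(\ell_i = -\log p_i\) and \(n = \texttt{count}\), the code computes

\[
\mathrm{PPL} = e^{\bar{\ell}}, \qquad
\bar{\ell} = \frac{1}{n}\sum_{i=1}^{n}\ell_i, \qquad
\hat{\sigma}_{\bar{\ell}} = \sqrt{\frac{\overline{\ell^{2}} - \bar{\ell}^{\,2}}{n-1}},
\]

and the printed "+/-" value is \(\mathrm{PPL}\cdot\hat{\sigma}_{\bar{\ell}}\), i.e. the standard error of the mean log-likelihood propagated through the exponential. Hoisting ppl out of the if branch (now const double ppl = exp(nll)) is what lets the function return it in results_perplexity on both branches.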
@@ -604,13 +690,16 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
+    struct results_perplexity results;
     if (params.hellaswag) {
         hellaswag_score(ctx, params);
     } else {
-        perplexity(ctx, params);
+        results = perplexity(ctx, params);
     }
 
     llama_print_timings(ctx);
+    write_logfile(ctx, params, model, results);
+
     llama_free(ctx);
     llama_free_model(model);
 
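Logging remains opt-in in main: write_logfile returns immediately when params.logdir is empty, and otherwise writes a single file named from params.logdir, a sortable timestamp, and the extension ".yml". HellaSwag runs only get a warning, since result logging is not implemented for that path. The logdir field itself is presumably populated by a new command-line option added to the shared argument parsing elsewhere in this commit, outside the hunks shown here.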