kl-divergence: be able to save all logits to a file

This commit is contained in:
Iwan Kawrakow 2024-01-22 06:48:56 +02:00
parent 7dcbe39d36
commit fe5e95bb02
3 changed files with 103 additions and 2 deletions

View file

@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.logdir.back() != DIRECTORY_SEPARATOR) {
params.logdir += DIRECTORY_SEPARATOR;
}
} else if (arg == "--save-all-logits") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.logits_file = argv[i];
} else if (arg == "--perplexity" || arg == "--all-logits") {
params.logits_all = true;
} else if (arg == "--ppl-stride") {

View file

@ -91,6 +91,7 @@ struct gpt_params {
std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
std::string logits_file = ""; // file for saving *all* logits
std::vector<llama_model_kv_override> kv_overrides;

View file

@ -112,6 +112,43 @@ static results_log_softmax log_softmax(int n_vocab, const float * logits, int to
return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
}
static inline int nearest_int(float fval) {
//assert(fval <= 4194303.f);
float val = fval + 12582912.f;
int i; memcpy(&i, &val, sizeof(int));
return (i & 0x007fffff) - 0x00400000;
}
static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
float max_logit = logits[0];
float min_logit = logits[0];
for (int i = 1; i < n_vocab; ++i) {
max_logit = std::max(max_logit, logits[i]);
min_logit = std::min(min_logit, logits[i]);
}
min_logit = std::max(min_logit, max_logit - 16);
double sum_exp = 0.0;
for (int i = 0; i < n_vocab; ++i) {
sum_exp += expf(logits[i] - max_logit);
}
const float log_sum_exp = log(sum_exp);
const float min_log_prob = min_logit - max_logit - log_sum_exp;
const float scale = (max_logit - min_logit)/65535.f;
float * d = (float *)log_prob;
d[0] = scale;
d[1] = min_log_prob;
log_prob += 4;
if (scale) {
const float inv_scale = 1/scale;
for (int i = 0; i < n_vocab; ++i) {
log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
}
} else {
std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
}
return max_logit + log_sum_exp - logits[tok];
}
static void process_logits(
int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
double & nll, double & nll2, float * logit_history, float * prob_history
@ -147,6 +184,37 @@ static void process_logits(
}
}
static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
std::mutex mutex;
const int nv = 2*((n_vocab + 1)/2) + 4;
int counter = 0;
auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
double local_nll = 0;
double local_nll2 = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
int i = counter++;
if (i >= n_token) {
nll += local_nll; nll2 += local_nll2;
break;
}
lock.unlock();
const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
local_nll += v;
local_nll2 += v*v;
}
};
for (auto & w : workers) {
w = std::thread(compute);
}
compute();
for (auto & w : workers) {
w.join();
}
out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
}
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
@ -294,6 +362,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
std::ofstream logits_stream;
if (!params.logits_file.empty()) {
logits_stream.open(params.logits_file.c_str());
if (!logits_stream.is_open()) {
fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
return {};
}
fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
logits_stream.write("_logits_", 8);
logits_stream.write((const char *)&n_ctx, sizeof(n_ctx));
}
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@ -336,6 +416,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
std::vector<uint16_t> log_probs;
if (!params.logits_file.empty()) {
logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
const int nv = 2*((n_vocab + 1)/2) + 4;
log_probs.resize(n_ctx * nv);
}
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
const int end = start + n_ctx;
@ -398,8 +487,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// process the entire prompt.
const int first = n_ctx/2;
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
if (!params.logits_file.empty()) {
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, log_probs, nll, nll2);
} else {
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
}
count += n_ctx - first - 1;
// perplexity is e^(average negative log-likelihood)