kl-divergence: be able to save all logits to a file
parent 7dcbe39d36
commit fe5e95bb02
3 changed files with 103 additions and 2 deletions
common/common.cpp

@@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
+        } else if (arg == "--save-all-logits") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.logits_file = argv[i];
         } else if (arg == "--perplexity" || arg == "--all-logits") {
             params.logits_all = true;
         } else if (arg == "--ppl-stride") {
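Based on the parsing above, the new flag takes a file path argument; a hypothetical invocation (paths are illustrative, mirroring the example used elsewhere in this diff):

    ./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw --save-all-logits logits.bin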
common/common.h

@@ -91,6 +91,7 @@ struct gpt_params {
     std::string input_suffix = "";       // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
     std::string logdir = "";             // directory in which to save YAML log files
+    std::string logits_file = "";        // file for saving *all* logits
 
     std::vector<llama_model_kv_override> kv_overrides;
 
examples/perplexity/perplexity.cpp

@@ -112,6 +112,43 @@ static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
     return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
 }
 
+static inline int nearest_int(float fval) {
+    //assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
+    float max_logit = logits[0];
+    float min_logit = logits[0];
+    for (int i = 1; i < n_vocab; ++i) {
+        max_logit = std::max(max_logit, logits[i]);
+        min_logit = std::min(min_logit, logits[i]);
+    }
+    min_logit = std::max(min_logit, max_logit - 16);
+    double sum_exp = 0.0;
+    for (int i = 0; i < n_vocab; ++i) {
+        sum_exp += expf(logits[i] - max_logit);
+    }
+    const float log_sum_exp = log(sum_exp);
+    const float min_log_prob = min_logit - max_logit - log_sum_exp;
+    const float scale = (max_logit - min_logit)/65535.f;
+    float * d = (float *)log_prob;
+    d[0] = scale;
+    d[1] = min_log_prob;
+    log_prob += 4;
+    if (scale) {
+        const float inv_scale = 1/scale;
+        for (int i = 0; i < n_vocab; ++i) {
+            log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
+        }
+    } else {
+        std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
+    }
+    return max_logit + log_sum_exp - logits[tok];
+}
+
 static void process_logits(
     int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
     double & nll, double & nll2, float * logit_history, float * prob_history
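The nearest_int() helper added above rounds without a libm call by exploiting the IEEE-754 float layout: adding 12582912.f (1.5 * 2^23) pushes the value into a range where the 23-bit mantissa holds the rounded integer directly, which a mask then extracts. A minimal standalone sketch of the same trick (nearest_int_demo is a local copy for illustration; note that exact .5 ties round to even, unlike lroundf):

    #include <cmath>
    #include <cstdio>
    #include <cstring>

    // Same magic-number rounding as nearest_int() above. Valid only for
    // |fval| well below 2^22, which holds here because quantized values
    // are confined to [0, 65535].
    static int nearest_int_demo(float fval) {
        float val = fval + 12582912.f;  // 1.5f * 2^23
        int i;
        memcpy(&i, &val, sizeof(int)); // reinterpret the float's bit pattern
        return (i & 0x007fffff) - 0x00400000;
    }

    int main() {
        for (float f : {0.4f, 1.6f, 1234.49f, 65534.6f}) {
            printf("%10.2f -> %6d (lroundf: %6ld)\n", f, nearest_int_demo(f), lroundf(f));
        }
        return 0;
    }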
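The new log_softmax() overload quantizes a whole row of log-probabilities to 16 bits each: it writes a 4-uint16 header holding two floats (the quantization scale and min_log_prob, the log-probability assigned to the weakest kept token), then n_vocab quantized values q_i = round((logit_i - min_logit)/scale), and returns the negative log-likelihood of the target token. Decoding therefore reduces to min_log_prob + scale*q_i. A minimal decoder sketch (decode_log_probs is a hypothetical helper, not part of this commit):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Inverse of the quantizing log_softmax() above: each row is a
    // 4-uint16 header (two floats: scale, min_log_prob) followed by
    // n_vocab quantized values, so log p_i ~= min_log_prob + scale*q_i.
    static std::vector<float> decode_log_probs(const uint16_t * row, int n_vocab) {
        float scale, min_log_prob;
        memcpy(&scale,        row + 0, sizeof(float));
        memcpy(&min_log_prob, row + 2, sizeof(float));
        const uint16_t * q = row + 4;
        std::vector<float> log_p(n_vocab);
        for (int i = 0; i < n_vocab; ++i) {
            log_p[i] = min_log_prob + scale*q[i];
        }
        return log_p;
    }

Tokens whose logit falls below max_logit - 16 are clamped to q = 0, i.e. to min_log_prob, which bounds the quantization step while costing essentially nothing for such improbable tokens.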
@@ -147,6 +184,37 @@ static void process_logits(
     }
 }
 
+static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
+        std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
+    std::mutex mutex;
+    const int nv = 2*((n_vocab + 1)/2) + 4;
+    int counter = 0;
+    auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
+        double local_nll  = 0;
+        double local_nll2 = 0;
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int i = counter++;
+            if (i >= n_token) {
+                nll += local_nll; nll2 += local_nll2;
+                break;
+            }
+            lock.unlock();
+            const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+            local_nll += v;
+            local_nll2 += v*v;
+        }
+    };
+    for (auto & w : workers) {
+        w = std::thread(compute);
+    }
+    compute();
+    for (auto & w : workers) {
+        w.join();
+    }
+    out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
+}
+
 static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
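In the writer above, each thread pulls row indices from a mutex-guarded counter, so the workers self-balance and fold their local nll sums into the shared totals only under the lock. A note on the row stride: nv = 2*((n_vocab + 1)/2) + 4 rounds n_vocab up to an even count of uint16 values, keeping every row 4-byte aligned, and reserves 4 uint16 for the two-float header. A quick worked example of the resulting file growth (a vocabulary of 32000 and a context of 512 are assumptions, typical for a LLaMA-style model and this tool's default):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int n_vocab = 32000;                    // assumed LLaMA-style vocabulary
        const int n_ctx   = 512;                      // assumed default context
        const int nv      = 2*((n_vocab + 1)/2) + 4;  // 32004 uint16 per row
        const size_t row_bytes   = nv*sizeof(uint16_t); // 64008 bytes per token position
        const size_t chunk_rows  = n_ctx - 1 - n_ctx/2; // rows actually written per chunk (255)
        const size_t chunk_bytes = chunk_rows*row_bytes;
        printf("row: %zu bytes, chunk: %.1f MiB\n", row_bytes, chunk_bytes/(1024.0*1024.0));
        return 0;
    }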
@@ -294,6 +362,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
 
+    std::ofstream logits_stream;
+    if (!params.logits_file.empty()) {
+        logits_stream.open(params.logits_file.c_str(), std::ios::binary);
+        if (!logits_stream.is_open()) {
+            fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
+            return {};
+        }
+        fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
+        logits_stream.write("_logits_", 8);
+        logits_stream.write((const char *)&n_ctx, sizeof(n_ctx));
+    }
+
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
@@ -336,6 +416,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
+    std::vector<uint16_t> log_probs;
+    if (!params.logits_file.empty()) {
+        logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
+        logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
+        logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
+        const int nv = 2*((n_vocab + 1)/2) + 4;
+        log_probs.resize(n_ctx * nv);
+    }
+
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * n_ctx;
         const int end = start + n_ctx;
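With this second header write, the complete file layout is: the 8-byte magic "_logits_", then n_ctx, n_vocab and n_chunk as 32-bit ints, then all n_chunk*n_ctx tokens, then n_chunk groups of (n_ctx - 1 - n_ctx/2) quantized rows of nv uint16 each. A minimal reader sketch, assuming llama_token is a 32-bit int and that reader and writer share endianness (not part of this commit):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>
    #include <fstream>
    #include <vector>

    int main(int argc, char ** argv) {
        if (argc < 2) { fprintf(stderr, "usage: %s logits.bin\n", argv[0]); return 1; }
        std::ifstream in(argv[1], std::ios::binary);
        char magic[8];
        in.read(magic, 8);
        if (!in || memcmp(magic, "_logits_", 8) != 0) { fprintf(stderr, "bad magic\n"); return 1; }
        int n_ctx, n_vocab, n_chunk;
        in.read((char *)&n_ctx,   sizeof(n_ctx));
        in.read((char *)&n_vocab, sizeof(n_vocab));
        in.read((char *)&n_chunk, sizeof(n_chunk));
        std::vector<int32_t> tokens((size_t)n_chunk*n_ctx);   // llama_token assumed 32-bit
        in.read((char *)tokens.data(), tokens.size()*sizeof(int32_t));
        const int nv   = 2*((n_vocab + 1)/2) + 4;             // row stride, matching the writer
        const int rows = n_ctx - 1 - n_ctx/2;                 // rows stored per chunk
        std::vector<uint16_t> row(nv);
        in.read((char *)row.data(), nv*sizeof(uint16_t));     // first row of the first chunk
        printf("n_ctx=%d n_vocab=%d n_chunk=%d rows/chunk=%d\n", n_ctx, n_vocab, n_chunk, rows);
        return 0;
    }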
@@ -398,8 +487,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
         // process the entire prompt.
         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
+        if (!params.logits_file.empty()) {
+            process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+                    workers, log_probs, nll, nll2);
+        } else {
+            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
+        }
         count += n_ctx - first - 1;
 
         // perplexity is e^(average negative log-likelihood)
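Per the commit title, the saved rows are meant as the base distribution for a later KL-divergence comparison: with log p_i decoded from the file and log q_i computed by a second model, KL(p||q) = sum_i p_i*(log p_i - log q_i). A sketch of that reduction over one decoded row (hypothetical follow-up code, not part of this commit):

    #include <cmath>
    #include <vector>

    // KL(p||q) = sum_i p_i * (log p_i - log q_i), both rows in log space.
    static double kl_divergence(const std::vector<float> & log_p, const std::vector<float> & log_q) {
        double kld = 0.0;
        for (size_t i = 0; i < log_p.size(); ++i) {
            kld += exp((double)log_p[i]) * ((double)log_p[i] - (double)log_q[i]);
        }
        return kld;
    }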