llama : extend llama_kv_cache API

This commit is contained in:
Georgi Gerganov 2023-09-18 15:53:03 +03:00
parent 6952a460b9
commit 4d76d762ef
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 84 additions and 32 deletions

View file

@ -79,7 +79,9 @@ static void write_logfile(
static std::vector<float> softmax(const std::vector<float>& logits) {
std::vector<float> probs(logits.size());
float max_logit = logits[0];
for (float v : logits) max_logit = std::max(max_logit, v);
for (float v : logits) {
max_logit = std::max(max_logit, v);
}
double sum_exp = 0.0;
for (size_t i = 0; i < logits.size(); i++) {
// Subtract the maximum logit value from the current logit value for numerical stability
@ -88,15 +90,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
sum_exp += exp_logit;
probs[i] = exp_logit;
}
for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
for (size_t i = 0; i < probs.size(); i++) {
probs[i] /= sum_exp;
}
return probs;
}
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
float max_logit = logits[0];
for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
for (int i = 1; i < n_vocab; ++i) {
max_logit = std::max(max_logit, logits[i]);
}
double sum_exp = 0.0;
for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
for (int i = 0; i < n_vocab; ++i) {
sum_exp += expf(logits[i] - max_logit);
}
return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
}
@ -107,7 +115,8 @@ static void process_logits(
std::mutex mutex;
int counter = 0;
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
double local_nll = 0, local_nll2 = 0;
double local_nll = 0;
double local_nll2 = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
int i = counter++;
@ -125,10 +134,13 @@ static void process_logits(
prob_history[i] = results.prob;
}
};
for (auto & w : workers) w = std::thread(compute);
for (auto & w : workers) {
w = std::thread(compute);
}
compute();
for (auto & w : workers) w.join();
for (auto & w : workers) {
w.join();
}
}
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
@ -151,8 +163,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {std::move(tokens), 0., {}, {}};
}
std::vector<float> logit_history;
std::vector<float> prob_history;
std::vector<float> logit_history;
std::vector<float> prob_history;
logit_history.resize(tokens.size());
prob_history.resize(tokens.size());
@ -194,6 +206,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
@ -319,6 +334,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
@ -549,6 +567,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
query_embd.resize(32);
}
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);