llama : extend llama_kv_cache API
This commit is contained in:
parent
6952a460b9
commit
4d76d762ef
4 changed files with 84 additions and 32 deletions
|
@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
|
|||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
|
||||
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
|
|
|
@ -79,7 +79,9 @@ static void write_logfile(
|
|||
static std::vector<float> softmax(const std::vector<float>& logits) {
|
||||
std::vector<float> probs(logits.size());
|
||||
float max_logit = logits[0];
|
||||
for (float v : logits) max_logit = std::max(max_logit, v);
|
||||
for (float v : logits) {
|
||||
max_logit = std::max(max_logit, v);
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (size_t i = 0; i < logits.size(); i++) {
|
||||
// Subtract the maximum logit value from the current logit value for numerical stability
|
||||
|
@ -88,15 +90,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
|
|||
sum_exp += exp_logit;
|
||||
probs[i] = exp_logit;
|
||||
}
|
||||
for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
|
||||
for (size_t i = 0; i < probs.size(); i++) {
|
||||
probs[i] /= sum_exp;
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
|
||||
float max_logit = logits[0];
|
||||
for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
|
||||
for (int i = 1; i < n_vocab; ++i) {
|
||||
max_logit = std::max(max_logit, logits[i]);
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
sum_exp += expf(logits[i] - max_logit);
|
||||
}
|
||||
return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
|
||||
}
|
||||
|
||||
|
@ -107,7 +115,8 @@ static void process_logits(
|
|||
std::mutex mutex;
|
||||
int counter = 0;
|
||||
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
|
||||
double local_nll = 0, local_nll2 = 0;
|
||||
double local_nll = 0;
|
||||
double local_nll2 = 0;
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
int i = counter++;
|
||||
|
@ -125,10 +134,13 @@ static void process_logits(
|
|||
prob_history[i] = results.prob;
|
||||
}
|
||||
};
|
||||
for (auto & w : workers) w = std::thread(compute);
|
||||
for (auto & w : workers) {
|
||||
w = std::thread(compute);
|
||||
}
|
||||
compute();
|
||||
for (auto & w : workers) w.join();
|
||||
|
||||
for (auto & w : workers) {
|
||||
w.join();
|
||||
}
|
||||
}
|
||||
|
||||
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
|
||||
|
@ -151,8 +163,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||
return {std::move(tokens), 0., {}, {}};
|
||||
}
|
||||
|
||||
std::vector<float> logit_history;
|
||||
std::vector<float> prob_history;
|
||||
std::vector<float> logit_history;
|
||||
std::vector<float> prob_history;
|
||||
|
||||
logit_history.resize(tokens.size());
|
||||
prob_history.resize(tokens.size());
|
||||
|
@ -194,6 +206,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_keep_seq(ctx, -1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
@ -319,6 +334,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_keep_seq(ctx, -1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
@ -549,6 +567,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
query_embd.resize(32);
|
||||
}
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_keep_seq(ctx, -1);
|
||||
|
||||
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
|
||||
if (logits.empty()) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue