llama : update llama_kv_self API

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-14 16:47:34 +02:00
parent fd05ab87aa
commit 17b363afd3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
30 changed files with 387 additions and 205 deletions

View file

@ -299,8 +299,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
@ -362,7 +360,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(kv);
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
@ -452,8 +450,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
@ -550,7 +546,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(kv);
llama_kv_self_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@ -745,8 +741,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
// Calculates hellaswag score (acc_norm) from prompt
//
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@ -929,7 +923,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
return;
}
llama_kv_cache_clear(kv);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1090,8 +1084,6 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
constexpr int k_min_trailing_ctx = 3;
auto data = load_winogrande_from_csv(params.prompt);
@ -1210,7 +1202,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
return;
}
llama_kv_cache_clear(kv);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1396,8 +1388,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
std::istringstream strstream(params.prompt);
uint32_t n_task;
strstream.read((char *)&n_task, sizeof(n_task));
@ -1584,7 +1574,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
return;
}
llama_kv_cache_clear(kv);
llama_kv_self_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1681,8 +1671,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
if (params.logits_file.empty()) {
LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
return;
@ -1776,7 +1764,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
// clear the KV cache
llama_kv_cache_clear(kv);
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);